You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMEditing 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.backbones.sr_backbones.sr_resnet

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner import load_checkpoint

from mmedit.models.common import (PixelShufflePack, ResidualBlockNoBN,
                                  default_init_weights, make_layer)
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger

@BACKBONES.register_module()
class MSRResNet(nn.Module):
    """Modified SRResNet.

    A compact network derived from the SRResNet of "Photo-Realistic Single
    Image Super-Resolution Using a Generative Adversarial Network".

    Its residual blocks carry no BN, similar to EDSR. Only x2, x3 and x4
    upsampling scale factors are supported at the moment.

    The forward pass adds a bilinearly upsampled copy of the input to the
    output in place, so the input must be broadcastable to the output's
    shape (in practice ``in_channels`` should equal ``out_channels``).

    Args:
        in_channels (int): Channel number of inputs.
        out_channels (int): Channel number of outputs.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_blocks (int): Block number in the trunk network. Default: 16.
        upscale_factor (int): Upsampling factor. Support x2, x3 and x4.
            Default: 4.

    Raises:
        ValueError: If ``upscale_factor`` is not 2, 3 or 4.
    """

    _supported_upscale_factors = [2, 3, 4]

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_blocks=16,
                 upscale_factor=4):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.num_blocks = num_blocks
        self.upscale_factor = upscale_factor

        # Shallow feature extraction followed by the residual trunk.
        self.conv_first = nn.Conv2d(
            in_channels,
            mid_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True)
        self.trunk_net = make_layer(
            ResidualBlockNoBN, num_blocks, mid_channels=mid_channels)

        # upsampling: x2/x3 use a single pixel-shuffle stage, x4 chains two
        # x2 stages. Stages are registered as `upsample1`, `upsample2`.
        if upscale_factor in (2, 3):
            stage_scales = (upscale_factor, )
        elif upscale_factor == 4:
            stage_scales = (2, 2)
        else:
            raise ValueError(
                f'Unsupported scale factor {self.upscale_factor}. '
                f'Currently supported ones are '
                f'{self._supported_upscale_factors}.')
        for index, scale in enumerate(stage_scales, start=1):
            setattr(
                self, f'upsample{index}',
                PixelShufflePack(
                    mid_channels, mid_channels, scale, upsample_kernel=3))

        # Reconstruction convolutions applied at the upsampled resolution.
        self.conv_hr = nn.Conv2d(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True)
        self.conv_last = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True)

        # Bilinear upsampler used for the global residual path in `forward`.
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bilinear',
            align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        shallow = self.lrelu(self.conv_first(x))
        out = self.trunk_net(shallow)

        # Pick the pixel-shuffle stages that match the configured factor.
        if self.upscale_factor == 4:
            stages = (self.upsample1, self.upsample2)
        elif self.upscale_factor in (2, 3):
            stages = (self.upsample1, )
        else:
            stages = ()
        for stage in stages:
            out = stage(out)

        out = self.conv_hr(out)
        out = self.lrelu(out)
        out = self.conv_last(out)

        # Global residual: add the bilinearly upsampled input in place.
        out += self.img_upsampler(x)
        return out

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded and the first,
                HR and last convolutions are re-initialized instead.
                Defaults to None.
            strict (bool, optional): Whether strictly load the pretrained
                model. Defaults to True.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if pretrained is None:
            # Initialization methods like `kaiming_init` are for VGG-style
            # modules. For modules with residual paths, using smaller std is
            # better for stability and performance. There is a global
            # residual path in MSRResNet and empirically we use 0.1. See more
            # details in "ESRGAN: Enhanced Super-Resolution Generative
            # Adversarial Networks"
            for conv in (self.conv_first, self.conv_hr, self.conv_last):
                default_init_weights(conv, 0.1)
            return

        if not isinstance(pretrained, str):
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')

        logger = get_root_logger()
        load_checkpoint(self, pretrained, strict=strict, logger=logger)
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.