Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMEditing 1.0 to benefit from the many new features and better performance brought by OpenMMLab 2.0. See the changelog, code, and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.backbones.generation_backbones.resnet_generator

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import load_checkpoint

from mmedit.models.common import (ResidualBlockWithDropout,
                                  generation_init_weights)
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger


@BACKBONES.register_module()
class ResnetGenerator(nn.Module):
    """Construct a Resnet-based generator that consists of residual blocks
    between a few downsampling/upsampling operations.

    The network is a single ``nn.Sequential`` made of, in order: one 7x7
    stem conv, two stride-2 downsampling convs, ``num_blocks`` residual
    blocks, two stride-2 transposed-conv upsampling layers and one 7x7
    output conv followed by Tanh.

    Args:
        in_channels (int): Number of channels in input images.
        out_channels (int): Number of channels in output images.
        base_channels (int): Number of channels produced by the first conv
            layer and consumed by the last conv layer. Default: 64.
        norm_cfg (dict): Config dict to build norm layer. Must contain the
            key ``'type'``. Default: `dict(type='IN')`.
        use_dropout (bool): Whether to use dropout layers. Default: False.
        num_blocks (int): Number of residual blocks. Must be non-negative.
            Default: 9.
        padding_mode (str): The name of padding layer in conv layers:
            'reflect' | 'replicate' | 'zeros'. Default: 'reflect'.
        init_cfg (dict | None): Config dict for initialization. If None,
            or if a key is missing, the defaults below are used.
            `type`: The name of our initialization method. Default: 'normal'.
            `gain`: Scaling factor for normal, xavier and orthogonal.
            Default: 0.02.

    Raises:
        AssertionError: If ``num_blocks`` is negative, ``norm_cfg`` is not
            a dict, or ``norm_cfg`` has no ``'type'`` key.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 base_channels=64,
                 norm_cfg=dict(type='IN'),
                 use_dropout=False,
                 num_blocks=9,
                 padding_mode='reflect',
                 init_cfg=dict(type='normal', gain=0.02)):
        super().__init__()
        assert num_blocks >= 0, ('Number of residual blocks must be '
                                 f'non-negative, but got {num_blocks}.')
        # NOTE: the two literals are concatenated, so the first one must end
        # with a space (the message used to read "butgot").
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but "
                                            f'got {type(norm_cfg)}')
        assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"
        # We use norm layers in the resnet generator.
        # Only for IN, use bias since it does not have affine parameters.
        use_bias = norm_cfg['type'] == 'IN'

        # stem: 7x7 conv from image channels to base_channels
        model = [
            ConvModule(
                in_channels=in_channels,
                out_channels=base_channels,
                kernel_size=7,
                padding=3,
                bias=use_bias,
                norm_cfg=norm_cfg,
                padding_mode=padding_mode)
        ]

        num_down = 2
        # add downsampling layers (each one doubles the channel count)
        for i in range(num_down):
            multiple = 2**i
            model.append(
                ConvModule(
                    in_channels=base_channels * multiple,
                    out_channels=base_channels * multiple * 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=use_bias,
                    norm_cfg=norm_cfg))

        # add residual blocks at the widest channel count
        multiple = 2**num_down
        model.extend(
            ResidualBlockWithDropout(
                base_channels * multiple,
                padding_mode=padding_mode,
                norm_cfg=norm_cfg,
                use_dropout=use_dropout) for _ in range(num_blocks))

        # add upsampling layers (each one halves the channel count)
        for i in range(num_down):
            multiple = 2**(num_down - i)
            model.append(
                ConvModule(
                    in_channels=base_channels * multiple,
                    out_channels=base_channels * multiple // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=use_bias,
                    conv_cfg=dict(type='Deconv', output_padding=1),
                    norm_cfg=norm_cfg))

        # output: 7x7 conv to image channels, no norm, Tanh activation
        model.append(
            ConvModule(
                in_channels=base_channels,
                out_channels=out_channels,
                kernel_size=7,
                padding=3,
                bias=True,
                norm_cfg=None,
                act_cfg=dict(type='Tanh'),
                padding_mode=padding_mode))

        self.model = nn.Sequential(*model)
        # Stored for later use by ``init_weights``.
        self.init_type = 'normal' if init_cfg is None else init_cfg.get(
            'type', 'normal')
        self.init_gain = 0.02 if init_cfg is None else init_cfg.get(
            'gain', 0.02)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        return self.model(x)

    def init_weights(self, pretrained=None, strict=True):
        """Initialize weights for the model.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded and the weights
                are initialized with ``generation_init_weights`` using the
                stored ``init_type`` and ``init_gain``. Default: None.
            strict (bool, optional): Forwarded to ``load_checkpoint`` as its
                ``strict`` argument; only used when ``pretrained`` is a str.
                Default: True.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is None:
            generation_init_weights(
                self, init_type=self.init_type, init_gain=self.init_gain)
        else:
            raise TypeError("'pretrained' must be a str or None. "
                            f'But received {type(pretrained)}.')
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.