
Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend you upgrade to MMEditing 1.0 to enjoy the new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.components.discriminators.patch_disc

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import load_checkpoint

from mmedit.models.common import generation_init_weights
from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger


@COMPONENTS.register_module()
class PatchDiscriminator(nn.Module):
    """A PatchGAN discriminator.

    Args:
        in_channels (int): Number of channels in input images.
        base_channels (int): Number of channels at the first conv layer.
            Default: 64.
        num_conv (int): Number of stacked intermediate convs (excluding input
            and output conv). Default: 3.
        norm_cfg (dict): Config dict to build norm layer. Default:
            `dict(type='BN')`.
        init_cfg (dict): Config dict for initialization.
            `type`: The name of our initialization method. Default: 'normal'.
            `gain`: Scaling factor for normal, xavier and orthogonal.
                Default: 0.02.
    """

    def __init__(self,
                 in_channels,
                 base_channels=64,
                 num_conv=3,
                 norm_cfg=dict(type='BN'),
                 init_cfg=dict(type='normal', gain=0.02)):
        super().__init__()
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but "
                                            f'got {type(norm_cfg)}')
        assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"
        # We use norm layers in the patch discriminator.
        # Only for IN, use bias since it does not have affine parameters.
        use_bias = norm_cfg['type'] == 'IN'

        kernel_size = 4
        padding = 1

        # input layer
        sequence = [
            ConvModule(
                in_channels=in_channels,
                out_channels=base_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=padding,
                bias=True,
                norm_cfg=None,
                act_cfg=dict(type='LeakyReLU', negative_slope=0.2))
        ]

        # stacked intermediate layers,
        # gradually increasing the number of filters
        multiple_now = 1
        multiple_prev = 1
        for n in range(1, num_conv):
            multiple_prev = multiple_now
            multiple_now = min(2**n, 8)
            sequence += [
                ConvModule(
                    in_channels=base_channels * multiple_prev,
                    out_channels=base_channels * multiple_now,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=padding,
                    bias=use_bias,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='LeakyReLU', negative_slope=0.2))
            ]
        # one more intermediate conv, but with stride 1
        multiple_prev = multiple_now
        multiple_now = min(2**num_conv, 8)
        sequence += [
            ConvModule(
                in_channels=base_channels * multiple_prev,
                out_channels=base_channels * multiple_now,
                kernel_size=kernel_size,
                stride=1,
                padding=padding,
                bias=use_bias,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='LeakyReLU', negative_slope=0.2))
        ]

        # output one-channel prediction map
        sequence += [
            build_conv_layer(
                dict(type='Conv2d'),
                base_channels * multiple_now,
                1,
                kernel_size=kernel_size,
                stride=1,
                padding=padding)
        ]

        self.model = nn.Sequential(*sequence)
        self.init_type = 'normal' if init_cfg is None else init_cfg.get(
            'type', 'normal')
        self.init_gain = 0.02 if init_cfg is None else init_cfg.get(
            'gain', 0.02)
    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        return self.model(x)
    def init_weights(self, pretrained=None):
        """Initialize weights for the model.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Default: None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            generation_init_weights(
                self, init_type=self.init_type, init_gain=self.init_gain)
        else:
            raise TypeError("'pretrained' must be a str or None. "
                            f'But received {type(pretrained)}.')
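
For readers who want to try the module, below is a minimal usage sketch, not part of the source above. It assumes a working MMEditing 0.x installation and imports PatchDiscriminator from the module path in this page's title; the 256x256 input size is an arbitrary example. With the defaults (base_channels=64, num_conv=3), the three stride-2 convolutions reduce a 256x256 input to 32x32, and the two stride-1 4x4 convolutions trim it to a 30x30 one-channel map, which matches the classic 70x70-receptive-field PatchGAN.

import torch

from mmedit.models.components.discriminators.patch_disc import \
    PatchDiscriminator

# Direct instantiation with the documented defaults.
disc = PatchDiscriminator(in_channels=3)
disc.init_weights()  # pretrained=None, so generation_init_weights is applied

x = torch.randn(1, 3, 256, 256)  # (n, c, h, w)
out = disc(x)
print(out.shape)  # torch.Size([1, 1, 30, 30])

# The class is registered in COMPONENTS, so it can equally be built from a
# config dict, as MMEditing models usually are.
from mmcv.utils import build_from_cfg

from mmedit.models.registry import COMPONENTS

disc_from_cfg = build_from_cfg(
    dict(type='PatchDiscriminator', in_channels=3), COMPONENTS)

Each value in the 30x30 output map scores one overlapping patch of the input independently; averaging the GAN loss over the map is what makes the discriminator "patch"-based rather than a single whole-image classifier.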