

Source code for mmedit.models.common.generation_model_utils

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, kaiming_init, normal_init, xavier_init
from torch.nn import init


def generation_init_weights(module, init_type='normal', init_gain=0.02):
    """Default initialization of network weights for image generation.

    By default, we use normal init, but xavier and kaiming might work
    better for some applications.

    Args:
        module (nn.Module): Module to be initialized.
        init_type (str): The name of an initialization method:
            normal | xavier | kaiming | orthogonal.
        init_gain (float): Scaling factor for normal, xavier and
            orthogonal.
    """

    def init_func(m):
        """Initialization function.

        Args:
            m (nn.Module): Module to be initialized.
        """
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1
                                     or classname.find('Linear') != -1):
            if init_type == 'normal':
                normal_init(m, 0.0, init_gain)
            elif init_type == 'xavier':
                xavier_init(m, gain=init_gain, distribution='normal')
            elif init_type == 'kaiming':
                kaiming_init(
                    m,
                    a=0,
                    mode='fan_in',
                    nonlinearity='leaky_relu',
                    distribution='normal')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight, gain=init_gain)
                init.constant_(m.bias.data, 0.0)
            else:
                raise NotImplementedError(
                    f"Initialization method '{init_type}' is not implemented")
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm Layer's weight is not a matrix;
            # only normal distribution applies.
            normal_init(m, 1.0, init_gain)

    module.apply(init_func)
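A minimal usage sketch (not part of the module source): `generation_init_weights` recursively applies the chosen initializer to every Conv/Linear layer (and every BatchNorm2d layer) of an arbitrary module. The toy network below is a hypothetical example.

from mmedit.models.common.generation_model_utils import generation_init_weights
import torch.nn as nn

# A toy conv head, used only for illustration.
net = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 3, kernel_size=3, padding=1))

# Conv weights are drawn from N(0, 0.02); BatchNorm2d weights from N(1.0, 0.02).
generation_init_weights(net, init_type='normal', init_gain=0.02)

# 'xavier', 'kaiming' and 'orthogonal' are also accepted.
generation_init_weights(net, init_type='kaiming')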
class GANImageBuffer:
    """This class implements an image buffer that stores previously
    generated images.

    This buffer allows us to update the discriminator using a history of
    generated images rather than the ones produced by the latest generator
    to reduce model oscillation.

    Args:
        buffer_size (int): The size of image buffer. If buffer_size = 0,
            no buffer will be created.
        buffer_ratio (float): The chance / possibility to use the images
            previously stored in the buffer.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        # create an empty buffer
        if self.buffer_size > 0:
            self.img_num = 0
            self.image_buffer = []
        self.buffer_ratio = buffer_ratio
    def query(self, images):
        """Query current image batch using a history of generated images.

        Args:
            images (Tensor): Current image batch without history information.
        """
        if self.buffer_size == 0:  # if the buffer size is 0, do nothing
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            # if the buffer is not full, keep inserting current images
            if self.img_num < self.buffer_size:
                self.img_num = self.img_num + 1
                self.image_buffer.append(image)
                return_images.append(image)
            else:
                use_buffer = np.random.random() < self.buffer_ratio
                # by self.buffer_ratio, the buffer will return a previously
                # stored image, and insert the current image into the buffer
                if use_buffer:
                    random_id = np.random.randint(0, self.buffer_size)
                    image_tmp = self.image_buffer[random_id].clone()
                    self.image_buffer[random_id] = image
                    return_images.append(image_tmp)
                # by (1 - self.buffer_ratio), the buffer will return the
                # current image
                else:
                    return_images.append(image)
        # collect all the images and return
        return_images = torch.cat(return_images, 0)
        return return_images
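A minimal sketch of how the buffer is typically used in a GAN training step (illustrative only; the batch shape and the commented-out `disc` call are assumptions, not part of this module): fake images from the generator are passed through `query`, and the returned mix of current and historical fakes is what the discriminator is updated on.

import torch

from mmedit.models.common.generation_model_utils import GANImageBuffer

buffer = GANImageBuffer(buffer_size=50, buffer_ratio=0.5)

# A hypothetical batch of generated images with shape (N, C, H, W).
fake_imgs = torch.randn(4, 3, 64, 64)

# With probability `buffer_ratio`, each image is swapped for a previously
# stored one and the current image takes its slot in the buffer.
fake_for_disc = buffer.query(fake_imgs)
assert fake_for_disc.shape == fake_imgs.shape

# The discriminator is then trained on the (partly historical) batch, e.g.:
# disc_pred = disc(fake_for_disc.detach())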
class UnetSkipConnectionBlock(nn.Module):
    """Construct a Unet submodule with skip connections, with the following
    structure: downsampling - `submodule` - upsampling.

    Args:
        outer_channels (int): Number of channels at the outer conv layer.
        inner_channels (int): Number of channels at the inner conv layer.
        in_channels (int): Number of channels in input images/features. If
            is None, equals to `outer_channels`. Default: None.
        submodule (UnetSkipConnectionBlock): Previously constructed submodule.
            Default: None.
        is_outermost (bool): Whether this module is the outermost module.
            Default: False.
        is_innermost (bool): Whether this module is the innermost module.
            Default: False.
        norm_cfg (dict): Config dict to build norm layer. Default:
            `dict(type='BN')`.
        use_dropout (bool): Whether to use dropout layers. Default: False.
    """

    def __init__(self,
                 outer_channels,
                 inner_channels,
                 in_channels=None,
                 submodule=None,
                 is_outermost=False,
                 is_innermost=False,
                 norm_cfg=dict(type='BN'),
                 use_dropout=False):
        super().__init__()
        # cannot be both outermost and innermost
        assert not (is_outermost and is_innermost), (
            "'is_outermost' and 'is_innermost' cannot be True "
            'at the same time.')
        self.is_outermost = is_outermost
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but "
                                            f'got {type(norm_cfg)}')
        assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"
        # We use norm layers in the unet skip connection block.
        # Only for IN, use bias since it does not have affine parameters.
        use_bias = norm_cfg['type'] == 'IN'

        kernel_size = 4
        stride = 2
        padding = 1
        if in_channels is None:
            in_channels = outer_channels
        down_conv_cfg = dict(type='Conv2d')
        down_norm_cfg = norm_cfg
        down_act_cfg = dict(type='LeakyReLU', negative_slope=0.2)
        up_conv_cfg = dict(type='Deconv')
        up_norm_cfg = norm_cfg
        up_act_cfg = dict(type='ReLU')
        up_in_channels = inner_channels * 2
        up_bias = use_bias
        middle = [submodule]
        upper = []

        if is_outermost:
            down_act_cfg = None
            down_norm_cfg = None
            up_bias = True
            up_norm_cfg = None
            upper = [nn.Tanh()]
        elif is_innermost:
            down_norm_cfg = None
            up_in_channels = inner_channels
            middle = []
        else:
            upper = [nn.Dropout(0.5)] if use_dropout else []

        down = [
            ConvModule(
                in_channels=in_channels,
                out_channels=inner_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=use_bias,
                conv_cfg=down_conv_cfg,
                norm_cfg=down_norm_cfg,
                act_cfg=down_act_cfg,
                order=('act', 'conv', 'norm'))
        ]
        up = [
            ConvModule(
                in_channels=up_in_channels,
                out_channels=outer_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=up_bias,
                conv_cfg=up_conv_cfg,
                norm_cfg=up_norm_cfg,
                act_cfg=up_act_cfg,
                order=('act', 'conv', 'norm'))
        ]

        model = down + middle + up + upper

        self.model = nn.Sequential(*model)
    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.is_outermost:
            return self.model(x)

        # add skip connections
        return torch.cat([x, self.model(x)], 1)
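An illustrative sketch of how the blocks nest (the channel numbers and input size are assumptions for the example, not a prescribed configuration): the innermost block is built first, each intermediate block wraps the previous one via `submodule`, and the outermost block maps between image channels and features.

import torch

from mmedit.models.common.generation_model_utils import UnetSkipConnectionBlock

# innermost level: 256 -> 512 -> 256, no submodule
block = UnetSkipConnectionBlock(256, 512, is_innermost=True)
# intermediate levels, each wrapping the previously built block
block = UnetSkipConnectionBlock(128, 256, submodule=block)
block = UnetSkipConnectionBlock(64, 128, submodule=block)
# outermost level: 3-channel input, 3-channel Tanh output
unet = UnetSkipConnectionBlock(
    3, 64, in_channels=3, submodule=block, is_outermost=True)

x = torch.randn(1, 3, 64, 64)
out = unet(x)  # (1, 3, 64, 64); inner levels concatenate skip connections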
class ResidualBlockWithDropout(nn.Module):
    """Define a Residual Block with dropout layers.

    Ref:
    Deep Residual Learning for Image Recognition

    A residual block is a conv block with skip connections. A dropout layer
    is added between two common conv modules.

    Args:
        channels (int): Number of channels in the conv layer.
        padding_mode (str): The name of padding layer:
            'reflect' | 'replicate' | 'zeros'.
        norm_cfg (dict): Config dict to build norm layer. Default:
            `dict(type='BN')`.
        use_dropout (bool): Whether to use dropout layers. Default: True.
    """

    def __init__(self,
                 channels,
                 padding_mode,
                 norm_cfg=dict(type='BN'),
                 use_dropout=True):
        super().__init__()
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but "
                                            f'got {type(norm_cfg)}')
        assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"
        # We use norm layers in the residual block with dropout layers.
        # Only for IN, use bias since it does not have affine parameters.
        use_bias = norm_cfg['type'] == 'IN'

        block = [
            ConvModule(
                in_channels=channels,
                out_channels=channels,
                kernel_size=3,
                padding=1,
                bias=use_bias,
                norm_cfg=norm_cfg,
                padding_mode=padding_mode)
        ]
        if use_dropout:
            block += [nn.Dropout(0.5)]
        block += [
            ConvModule(
                in_channels=channels,
                out_channels=channels,
                kernel_size=3,
                padding=1,
                bias=use_bias,
                norm_cfg=norm_cfg,
                act_cfg=None,
                padding_mode=padding_mode)
        ]

        self.block = nn.Sequential(*block)
    def forward(self, x):
        """Forward function. Add skip connections without final ReLU.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        out = x + self.block(x)
        return out
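A small sketch of the block in isolation (the channel count, input size and IN norm config are assumptions for the example): the output has the same shape as the input, so the block can be stacked, as in ResNet-style translation generators.

import torch

from mmedit.models.common.generation_model_utils import ResidualBlockWithDropout

block = ResidualBlockWithDropout(
    channels=64,
    padding_mode='reflect',
    norm_cfg=dict(type='IN'),
    use_dropout=False)

x = torch.randn(2, 64, 32, 32)
out = block(x)
assert out.shape == x.shape  # the skip connection preserves the shape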