You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend that you upgrade to MMEditing 1.0 to enjoy the rich new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.common.partial_conv

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import CONV_LAYERS

@CONV_LAYERS.register_module(name='PConv')
class PartialConv2d(nn.Conv2d):
    """Implementation for partial convolution.

    Image Inpainting for Irregular Holes Using Partial Convolutions
    [https://arxiv.org/abs/1804.07723]

    The layer convolves only the valid (unmasked) part of the input,
    rescales the response by the ratio between the mask-kernel size and the
    mask sum under each window, and produces an updated mask.

    Args:
        multi_channel (bool): If True, the mask is multi-channel. Otherwise,
            the mask is single-channel. Default: False.
        eps (float): Constant added to the convolved mask before it is used
            as a divisor. It needs to be changed for mixed precision
            training: use 1e-6 instead of 1e-8. Default: 1e-8.
    """

    def __init__(self, *args, multi_channel=False, eps=1e-8, **kwargs):
        super().__init__(*args, **kwargs)

        # whether the mask is multi-channel or not
        self.multi_channel = multi_channel
        self.eps = eps

        if self.multi_channel:
            out_channels, in_channels = self.out_channels, self.in_channels
        else:
            out_channels, in_channels = 1, 1

        # All-ones kernel: convolving the mask with it sums the mask values
        # under every sliding window. Registered as a buffer so it follows
        # the module across devices without being a trainable parameter.
        self.register_buffer(
            'weight_mask_updater',
            torch.ones(out_channels, in_channels, self.kernel_size[0],
                       self.kernel_size[1]))

        # Number of elements of the mask kernel per output channel
        # (in_channels * kernel_h * kernel_w), stored as a Python scalar.
        self.mask_kernel_numel = np.prod(
            self.weight_mask_updater.shape[1:4]).item()

    def forward(self, input, mask=None, return_mask=True):
        """Forward function for partial conv2d.

        Args:
            input (torch.Tensor): Tensor with shape of (n, c, h, w).
            mask (torch.Tensor): Tensor with shape of (n, c, h, w) or
                (n, 1, h, w). If mask is not given, the function will
                work as standard conv2d. Default: None.
            return_mask (bool): If True and mask is not None, the updated
                mask will be returned. Default: True.

        Returns:
            torch.Tensor | tuple[torch.Tensor, torch.Tensor]: The result
            after partial conv. If ``mask`` is given and ``return_mask`` is
            True, a tuple ``(output, updated_mask)`` is returned instead.
        """
        assert input.dim() == 4, 'input must be a 4D tensor (n, c, h, w)'

        if mask is None:
            # Without a mask the layer is a plain conv2d and never returns
            # a mask, regardless of ``return_mask``.
            return super().forward(input)

        assert mask.dim() == 4, 'mask must be a 4D tensor'
        if self.multi_channel:
            assert mask.shape[1] == input.shape[1], (
                'multi-channel mask must have as many channels as input')
        else:
            assert mask.shape[1] == 1, (
                'single-channel mask must have exactly one channel')

        # update mask and compute mask ratio; no gradient flows through
        # the mask branch
        with torch.no_grad():
            updated_mask = F.conv2d(
                mask,
                self.weight_mask_updater,
                bias=None,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation)
            mask_ratio = self.mask_kernel_numel / (updated_mask + self.eps)

            updated_mask = torch.clamp(updated_mask, 0, 1)
            # zero the ratio wherever the clamped mask is zero
            mask_ratio = mask_ratio * updated_mask

        # standard conv2d on the masked input
        raw_out = super().forward(input * mask)

        if self.bias is None:
            output = raw_out * mask_ratio
        else:
            # rescale only the bias-free response, then add the bias back
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = (raw_out - bias_view) * mask_ratio + bias_view
            output = output * updated_mask

        if return_mask:
            return output, updated_mask

        return output
