You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend that you upgrade to MMEditing 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code, and documentation of MMEditing 1.0 for more details.

# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContextualAttentionModule(nn.Module):
    """Contextual attention module.

    The details of this module can be found in:
    Generative Image Inpainting with Contextual Attention

    Args:
        unfold_raw_kernel_size (int): Kernel size used in unfolding raw
            feature. Default: 4.
        unfold_raw_stride (int): Stride used in unfolding raw feature.
            Default: 2.
        unfold_raw_padding (int): Padding used in unfolding raw feature.
            Default: 1.
        unfold_corr_kernel_size (int): Kernel size used in unfolding
            context for computing correlation maps. Default: 3.
        unfold_corr_stride (int): Stride used in unfolding context for
            computing correlation maps. Default: 1.
        unfold_corr_dilation (int): Dilation used in unfolding context for
            computing correlation maps. Default: 1.
        unfold_corr_padding (int): Padding used in unfolding context for
            computing correlation maps. Default: 1.
        scale (float): The rescale factor used to resize input features.
            Default: 0.5.
        fuse_kernel_size (int): The kernel size used in fusion module.
            Fusion is only enabled when this value is larger than 1, in
            which case it must be odd. Default: 3.
        softmax_scale (float): The scale factor for softmax function.
            Default: 10.
        return_attention_score (bool): If True, the attention score will be
            returned. Default: True.
    """

    def __init__(self,
                 unfold_raw_kernel_size=4,
                 unfold_raw_stride=2,
                 unfold_raw_padding=1,
                 unfold_corr_kernel_size=3,
                 unfold_corr_stride=1,
                 unfold_corr_dilation=1,
                 unfold_corr_padding=1,
                 scale=0.5,
                 fuse_kernel_size=3,
                 softmax_scale=10,
                 return_attention_score=True):
        super().__init__()
        self.unfold_raw_kernel_size = unfold_raw_kernel_size
        self.unfold_raw_stride = unfold_raw_stride
        self.unfold_raw_padding = unfold_raw_padding
        self.unfold_corr_kernel_size = unfold_corr_kernel_size
        self.unfold_corr_stride = unfold_corr_stride
        self.unfold_corr_dilation = unfold_corr_dilation
        self.unfold_corr_padding = unfold_corr_padding
        self.scale = scale
        self.fuse_kernel_size = fuse_kernel_size
        self.with_fuse_correlation = fuse_kernel_size > 1
        self.softmax_scale = softmax_scale
        self.return_attention_score = return_attention_score

        if self.with_fuse_correlation:
            # an odd kernel is required so that the 'same' padding below
            # keeps the size of the fused map unchanged
            assert fuse_kernel_size % 2 == 1
            fuse_kernel = torch.eye(fuse_kernel_size).view(
                1, 1, fuse_kernel_size, fuse_kernel_size)
            # registered as a buffer so it follows the module across
            # devices/dtypes without being a learnable parameter
            self.register_buffer('fuse_kernel', fuse_kernel)
            padding = (fuse_kernel_size - 1) // 2
            self.fuse_conv = partial(F.conv2d, padding=padding, stride=1)

        # softmax over the patch dimension (dim=1) of the correlation map
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, context, mask=None):
        """Forward Function.

        Args:
            x (torch.Tensor): Tensor with shape (n, c, h, w).
            context (torch.Tensor): Tensor with shape (n, c, h, w).
            mask (torch.Tensor): Tensor with shape (n, 1, h, w).
                Default: None.

        Returns:
            tuple(torch.Tensor) | torch.Tensor: Features after contextual
                attention. If ``return_attention_score`` is True, a tuple
                ``(output, attention_score)`` is returned, where the
                attention score is viewed as
                (n, h_unfold, w_unfold, h_map, w_map).
        """
        # raw features to be used in copy (deconv)
        raw_context_cols = self.im2col(
            context,
            kernel_size=self.unfold_raw_kernel_size,
            stride=self.unfold_raw_stride,
            padding=self.unfold_raw_padding,
            normalize=False,
            return_cols=True)

        # resize the feature to reduce computational cost
        x = F.interpolate(x, scale_factor=self.scale)
        context = F.interpolate(context, scale_factor=self.scale)

        context_cols = self.im2col(
            context,
            kernel_size=self.unfold_corr_kernel_size,
            stride=self.unfold_corr_stride,
            padding=self.unfold_corr_padding,
            dilation=self.unfold_corr_dilation,
            normalize=True,
            return_cols=True)
        h_unfold, w_unfold = self.calculate_unfold_hw(
            context.size()[-2:],
            kernel_size=self.unfold_corr_kernel_size,
            stride=self.unfold_corr_stride,
            padding=self.unfold_corr_padding,
            dilation=self.unfold_corr_dilation,
        )

        # reshape context_cols to
        # (n*h_unfold*w_unfold, c, unfold_mks, unfold_mks)
        # 'mks' is short for 'mask_kernel_size'
        context_cols = context_cols.reshape(-1, *context_cols.shape[2:])

        # the shape of correlation map should be:
        # (n, h_unfold*w_unfold, h', w')
        correlation_map = self.patch_correlation(x, context_cols)

        # fuse correlation map to enlarge consistent attention region.
        if self.with_fuse_correlation:
            correlation_map = self.fuse_correlation_map(
                correlation_map, h_unfold, w_unfold)

        correlation_map = self.mask_correlation_map(correlation_map, mask=mask)

        attention_score = self.softmax(correlation_map * self.softmax_scale)

        raw_context_filter = raw_context_cols.reshape(
            -1, *raw_context_cols.shape[2:])
        output = self.patch_copy_deconv(attention_score, raw_context_filter)

        # deconv will cause overlap and we need to remove the effects of that
        overlap_factor = self.calculate_overlap_factor(attention_score)
        output /= overlap_factor

        if self.return_attention_score:
            n, _, h_s, w_s = attention_score.size()
            attention_score = attention_score.view(n, h_unfold, w_unfold,
                                                   h_s, w_s)
            return output, attention_score

        return output

    def patch_correlation(self, x, kernel):
        """Calculate patch correlation.

        The whole batch is processed with a single grouped convolution: the
        batch dimension of ``x`` is folded into the channel dimension and
        ``groups=n`` makes each sample use only its own kernels.

        Args:
            x (torch.Tensor): Input tensor.
            kernel (torch.Tensor): Kernel tensor.

        Returns:
            torch.Tensor: Tensor with shape of (n, l, h, w).
        """
        n, _, h_in, w_in = x.size()

        patch_corr = F.conv2d(
            x.view(1, -1, h_in, w_in),
            kernel,
            stride=self.unfold_corr_stride,
            padding=self.unfold_corr_padding,
            dilation=self.unfold_corr_dilation,
            groups=n)
        h_out, w_out = patch_corr.size()[-2:]

        return patch_corr.view(n, -1, h_out, w_out)

    def patch_copy_deconv(self, attention_score, context_filter):
        """Copy patches using deconv.

        Args:
            attention_score (torch.Tensor): Tensor with shape of (n, l, h, w).
            context_filter (torch.Tensor): Filter kernel.

        Returns:
            torch.Tensor: Tensor with shape of (n, c, h, w).
        """
        n, _, h, w = attention_score.size()
        # fold the batch into channels so that one grouped transposed
        # convolution handles every sample with its own filters
        attention_score = attention_score.view(1, -1, h, w)
        output = F.conv_transpose2d(
            attention_score,
            context_filter,
            stride=self.unfold_raw_stride,
            padding=self.unfold_raw_padding,
            groups=n)
        h_out, w_out = output.size()[-2:]

        return output.view(n, -1, h_out, w_out)

    def fuse_correlation_map(self, correlation_map, h_unfold, w_unfold):
        """Fuse correlation map.

        This operation is to fuse correlation map for increasing large
        consistent correlation regions.

        The mechanism behind this op is simple and easy to understand. A
        standard 'Eye' matrix will be applied as a filter on the correlation
        map in horizontal and vertical direction.

        The shape of input correlation map is (n, h_unfold*w_unfold, h, w).
        When adopting fusing, we will apply convolutional filter in the
        reshaped feature map with shape of (n, 1, h*w, h_unfold*w_unfold).

        A simple specification for horizontal direction is shown below:

        .. code-block:: python

                   (h, 0)  (h, 1)  (h, 2)  (h, 3)  ...
            (h, 0)
            (h, 1)            1
            (h, 2)                    1
            (h, 3)                            1
            ...

        Args:
            correlation_map (torch.Tensor): Correlation map with shape of
                (n, h_unfold*w_unfold, h_map, w_map).
            h_unfold (int): Number of unfolded patches along the height.
            w_unfold (int): Number of unfolded patches along the width.

        Returns:
            torch.Tensor: Fused correlation map with the same shape as the
                input, i.e. (n, h_unfold*w_unfold, h_map, w_map). The
                spatial size of the map does not have to match
                (h_unfold, w_unfold).
        """
        n, _, h_map, w_map = correlation_map.size()
        num_pos = h_map * w_map
        num_patches = h_unfold * w_unfold

        # horizontal direction: rows index map positions (h_map, w_map) and
        # columns index context patches (h_unfold, w_unfold), both flattened
        # in row-major order
        map_ = correlation_map.permute(0, 2, 3, 1)
        map_ = map_.reshape(n, num_pos, num_patches, 1)
        map_ = map_.permute(0, 3, 1, 2).contiguous()
        map_ = self.fuse_conv(map_, self.fuse_kernel)

        correlation_map = map_.reshape(n, h_map, w_map, h_unfold, w_unfold)

        # vertical direction: swap (h, w) on both the map and the patch axes
        # so that flattening becomes column-major before fusing again
        map_ = correlation_map.permute(0, 2, 1, 4, 3).reshape(
            n, 1, num_pos, num_patches)
        map_ = self.fuse_conv(map_, self.fuse_kernel)

        # Note that the dimension should be transposed since the convolution
        # of eye matrix will put the normed scores into the last several
        # dimension
        correlation_map = map_.reshape(n, w_map, h_map, w_unfold,
                                       h_unfold).permute(0, 4, 3, 2, 1)
        # back to (n, h_unfold*w_unfold, h_map, w_map)
        correlation_map = correlation_map.reshape(n, num_patches, h_map,
                                                  w_map)

        return correlation_map

    def calculate_unfold_hw(self,
                            input_size,
                            kernel_size=3,
                            stride=1,
                            dilation=1,
                            padding=0):
        """Calculate (h, w) after unfolding.

        The official implementation of `unfold` in pytorch will put the
        dimension (h, w) into `L`. Thus, this function is just to calculate
        the (h, w) according to the equation given in the documentation of
        ``torch.nn.Unfold``.

        Args:
            input_size (tuple[int]): Spatial size (h, w) of the input.
            kernel_size (int): Kernel size of unfolding. Default: 3.
            stride (int): Stride of unfolding. Default: 1.
            dilation (int): Dilation of unfolding. Default: 1.
            padding (int): Padding of unfolding. Default: 0.

        Returns:
            tuple[int]: The number of unfolded blocks along height and width.
        """
        h_in, w_in = input_size
        # effective extent covered by one (dilated) kernel minus one stride
        # step; shared by both spatial dimensions
        span = dilation * (kernel_size - 1) + 1

        h_unfold = int((h_in + 2 * padding - span) // stride + 1)
        w_unfold = int((w_in + 2 * padding - span) // stride + 1)
        return h_unfold, w_unfold

    def calculate_overlap_factor(self, attention_score):
        """Calculate the overlap factor after applying deconv.

        Args:
            attention_score (torch.Tensor): The attention score with shape of
                (n, c, h, w).

        Returns:
            torch.Tensor: The overlap factor will be returned. It has a
                single batch and channel dimension and counts how many
                deconv patches contribute to every output pixel.
        """
        h, w = attention_score.shape[-2:]
        kernel_size = self.unfold_raw_kernel_size

        # `.to(attention_score)` matches both device and dtype
        ones_input = torch.ones(1, 1, h, w).to(attention_score)
        ones_filter = torch.ones(1, 1, kernel_size,
                                 kernel_size).to(attention_score)
        overlap = F.conv_transpose2d(
            ones_input,
            ones_filter,
            stride=self.unfold_raw_stride,
            padding=self.unfold_raw_padding)

        # avoid division by zero
        overlap[overlap == 0] = 1.
        return overlap

    def mask_correlation_map(self, correlation_map, mask):
        """Add mask weight for correlation map.

        Add a negative infinity number to the masked regions so that softmax
        function will result in 'zero' in those regions.

        Note that ``correlation_map`` is modified in place when a mask is
        given.

        Args:
            correlation_map (torch.Tensor): Correlation map with shape of
                (n, h_unfold*w_unfold, h_map, w_map).
            mask (torch.Tensor): Mask tensor with shape of (n, c, h, w). '1'
                in the mask indicates masked region while '0' indicates valid
                region.

        Returns:
            torch.Tensor: Updated correlation map with mask.
        """
        if mask is None:
            return correlation_map

        mask = F.interpolate(mask, scale_factor=self.scale)
        # if any pixel is masked in patch, the patch is considered to be
        # masked
        mask_cols = self.im2col(
            mask,
            kernel_size=self.unfold_corr_kernel_size,
            stride=self.unfold_corr_stride,
            padding=self.unfold_corr_padding,
            dilation=self.unfold_corr_dilation)
        mask_cols = (mask_cols.sum(dim=1, keepdim=True) > 0).float()
        # (n, 1, L) -> (n, L, 1, 1) so that it broadcasts over the map
        mask_cols = mask_cols.permute(0, 2, 1).reshape(mask.size(0), -1, 1, 1)
        # add negative inf will bring zero in softmax
        mask_cols[mask_cols == 1] = -float('inf')
        correlation_map += mask_cols

        return correlation_map

    def im2col(self,
               img,
               kernel_size,
               stride=1,
               padding=0,
               dilation=1,
               normalize=False,
               return_cols=False):
        """Reshape image-style feature to columns.

        This function is used for unfold feature maps to columns. The
        details of this function can be found in the documentation of
        ``torch.nn.Unfold``.

        Args:
            img (torch.Tensor): Features to be unfolded. The shape of this
                feature should be (n, c, h, w).
            kernel_size (int): In this function, we only support square
                kernel with same height and width.
            stride (int): Stride number in unfolding. Default: 1.
            padding (int): Padding number in unfolding. Default: 0.
            dilation (int): Dilation number in unfolding. Default: 1.
            normalize (bool): If True, the unfolded feature will be
                normalized (every column is divided by its L2 norm, which is
                lower-bounded by 1e-4). Default: False.
            return_cols (bool): The official implementation in PyTorch of
                unfolding will return features with shape of
                (n, c*kernel_size**2, L). If True, the features will be
                reshaped to (n, L, c, kernel_size, kernel_size). Otherwise,
                the results will maintain the shape as the official
                implementation. Default: False.

        Returns:
            torch.Tensor: Unfolded columns. If `return_cols` is True, the
                shape of output tensor is
                `(n, L, c, kernel_size, kernel_size)`. Otherwise, the shape
                will be `(n, c*kernel_size**2, L)`.
        """
        # unfold img to columns with shape (n, c*kernel_size**2, num_cols)
        img_unfold = F.unfold(
            img,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation)

        # normalize the feature map
        if normalize:
            norm = torch.sqrt((img_unfold**2).sum(dim=1, keepdim=True))
            # lower bound on the norm guards against division by zero
            img_unfold = img_unfold / norm.clamp(min=1e-4)

        if return_cols:
            img_unfold_ = img_unfold.permute(0, 2, 1)
            n, num_cols = img_unfold_.size()[:2]
            img_cols = img_unfold_.view(n, num_cols, img.size(1), kernel_size,
                                        kernel_size)
            return img_cols

        return img_unfold
