• Docs >
  • Module code >
  • mmedit.models.backbones.encoder_decoders.encoders.pconv_encoder
Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMEditing 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code, and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.backbones.encoder_decoders.encoders.pconv_encoder

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmedit.models.common import MaskConvModule
from mmedit.models.registry import COMPONENTS


@COMPONENTS.register_module()
class PConvEncoder(nn.Module):
    """Encoder with partial conv.

    About the details for this architecture, pls see:
    Image Inpainting for Irregular Holes Using Partial Convolutions

    The encoder is a chain of ``MaskConvModule`` layers registered as
    ``enc1`` ... ``encN``. Every layer uses stride 2 and a ReLU activation;
    the first layer is built without a norm layer.

    Args:
        in_channels (int): The number of input channels. Default: 3.
        num_layers (int): The number of convolutional layers. Default 7.
            The first four layers are always constructed; additional
            512-channel layers are added only when ``num_layers > 4``.
            ``forward`` runs the first ``num_layers`` layers.
        conv_cfg (dict): Config for convolution module. Default:
            {'type': 'PConv', 'multi_channel': True}.
        norm_cfg (dict): Config for norm layer. Default:
            {'type': 'BN', 'requires_grad': True}.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effective on Batch
            Norm and its variants only. Default: False.
    """

    def __init__(self,
                 in_channels=3,
                 num_layers=7,
                 conv_cfg=dict(type='PConv', multi_channel=True),
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False):
        super().__init__()
        self.num_layers = num_layers
        self.norm_eval = norm_eval

        # One row per fixed layer:
        # (in_channels, out_channels, kernel_size, padding, norm_cfg).
        # Only the first layer skips normalization.
        fixed_layers = (
            (in_channels, 64, 7, 3, None),
            (64, 128, 5, 2, norm_cfg),
            (128, 256, 5, 2, norm_cfg),
            (256, 512, 3, 1, norm_cfg),
        )
        # Any layer past the fourth keeps 512 channels with a 3x3 kernel.
        extra_layers = ((512, 512, 3, 1, norm_cfg)
                        for _ in range(4, num_layers))

        # Layers are created and registered strictly in order so that module
        # names (``enc1``, ``enc2``, ...) and parameter initialization order
        # stay the same as with explicit per-layer assignments.
        layer_idx = 0
        for specs in (fixed_layers, extra_layers):
            for in_ch, out_ch, kernel, pad, layer_norm_cfg in specs:
                layer_idx += 1
                self.add_module(
                    f'enc{layer_idx}',
                    MaskConvModule(
                        in_ch,
                        out_ch,
                        kernel_size=kernel,
                        stride=2,
                        padding=pad,
                        conv_cfg=conv_cfg,
                        norm_cfg=layer_norm_cfg,
                        act_cfg=dict(type='ReLU')))

    def train(self, mode=True):
        """Switch between training and evaluation mode.

        When ``mode`` is truthy and ``self.norm_eval`` is set, every
        ``_BatchNorm`` submodule is put back into eval mode after the regular
        mode switch.

        Args:
            mode (bool): Whether to set training mode (True) or evaluation
                mode (False). Default: True.

        Returns:
            PConvEncoder: ``self``, as ``nn.Module.train`` does. Since
            ``nn.Module.eval()`` forwards the return value of
            ``train(False)``, returning ``self`` here keeps both
            ``encoder.train()`` and ``encoder.eval()`` chainable.
        """
        super().train(mode)
        if mode and self.norm_eval:
            for module in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(module, _BatchNorm):
                    module.eval()
        return self

    def forward(self, x, mask):
        """Forward function for partial conv encoder.

        Args:
            x (torch.Tensor): Masked image with shape (n, c, h, w).
            mask (torch.Tensor): Mask tensor with shape (n, c, h, w).

        Returns:
            dict: Contains the results and middle level features in this
                module. `hidden_feats` contain the middle feature maps and
                `hidden_masks` store updated masks. Both are keyed by
                'h0' (the inputs) through 'h{num_layers}'; `out` is the
                feature map stored under 'h{num_layers}'.
        """
        # 'h0' holds the raw inputs; 'h{i}' holds the outputs of layer enc{i}.
        hidden_feats = {'h0': x}
        hidden_masks = {'h0': mask}

        feat, cur_mask = x, mask
        for idx in range(1, self.num_layers + 1):
            layer = getattr(self, f'enc{idx}')
            # Each layer consumes the previous feature/mask pair and returns
            # the updated pair.
            feat, cur_mask = layer(feat, cur_mask)
            hidden_feats[f'h{idx}'] = feat
            hidden_masks[f'h{idx}'] = cur_mask

        return dict(
            out=hidden_feats[f'h{self.num_layers}'],
            hidden_feats=hidden_feats,
            hidden_masks=hidden_masks)
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.