Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated at the end of 2022. We recommend that you upgrade to MMEditing 1.0 to enjoy fruitful new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code, and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.backbones.encoder_decoders.decoders.fba_decoder

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger


@COMPONENTS.register_module()
class FBADecoder(nn.Module):
    """Decoder for FBA matting.

    Predicts alpha, foreground (F) and background (B) from the multi-scale
    features produced by the FBA encoder, using a Pyramid Pooling Module
    followed by a sequence of upsample-and-concat stages.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
            NOTE(review): `__init__` reads ``self.norm_cfg['type']``
            unconditionally, so passing ``None`` here would raise — the
            "None" in this doc looks unsupported in practice; confirm.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self,
                 pool_scales,
                 in_channels,
                 channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super().__init__()
        assert isinstance(pool_scales, (list, tuple))
        # Pyramid Pooling Module
        self.pool_scales = pool_scales
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners
        self.batch_norm = False

        # One branch per pooling scale: adaptive-pool to `scale`, then a 1x1
        # ConvModule. `.children()` unpacks the ConvModule's conv/norm/act
        # sub-layers directly into the Sequential.
        self.ppm = []
        for scale in self.pool_scales:
            self.ppm.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(scale),
                    *(ConvModule(
                        self.in_channels,
                        self.channels,
                        kernel_size=1,
                        bias=True,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg).children())))
        self.ppm = nn.ModuleList(self.ppm)

        # Followed the author's implementation that
        # concatenate conv layers described in the supplementary
        # material between up operations.
        # Input width = encoder features + one `channels`(-wide) map per PPM
        # branch; the `256` here presumably matches `channels` in the intended
        # config — TODO confirm against the FBA config files.
        self.conv_up1 = nn.Sequential(*(list(
            ConvModule(
                self.in_channels + len(pool_scales) * 256,
                self.channels,
                padding=1,
                kernel_size=3,
                bias=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg).children()) + list(
                    ConvModule(
                        self.channels,
                        self.channels,
                        padding=1,
                        bias=True,
                        kernel_size=3,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg).children())))

        # `channels * 2`: upsampled features concatenated with the skip
        # connection from `conv_out[-4]` in `forward`.
        self.conv_up2 = nn.Sequential(*(list(
            ConvModule(
                self.channels * 2,
                self.channels,
                padding=1,
                kernel_size=3,
                bias=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg).children())))

        # Width of the `conv_out[-5]` skip connection depends on the encoder
        # norm type (128 channels for BN encoders, 64 otherwise).
        if (self.norm_cfg['type'] == 'BN'):
            d_up3 = 128
        else:
            d_up3 = 64

        self.conv_up3 = nn.Sequential(*(list(
            ConvModule(
                self.channels + d_up3,
                64,
                padding=1,
                kernel_size=3,
                bias=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg).children())))

        # NOTE(review): `self.unpool` is never used in `forward`; it appears
        # to be kept only for compatibility with the reference implementation.
        self.unpool = nn.MaxUnpool2d(2, stride=2)

        # Final prediction head. Input = 64 decoder channels + 3 (image) +
        # 3 (first 3 channels of `conv_out[-6]`) + 2 (two-channel trimap);
        # output = 7 channels: 1 alpha + 3 fg + 3 bg. The last conv has no
        # activation so the raw logits can be clamped/sigmoided in `forward`.
        conv_up4_list = list(
            ConvModule(
                64 + 3 + 3 + 2,
                32,
                padding=1,
                kernel_size=3,
                bias=True,
                act_cfg=self.act_cfg).children())
        conv_up4_list += list(
            ConvModule(
                32,
                16,
                padding=1,
                kernel_size=3,
                bias=True,
                act_cfg=self.act_cfg).children())
        conv_up4_list += list(
            ConvModule(
                16,
                7,
                padding=0,
                kernel_size=1,
                bias=True,
                act_cfg=None).children())
        self.conv_up4 = nn.Sequential(*conv_up4_list)

    def init_weights(self, pretrained=None):
        """Init weights for the model.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            # strict=False: allow partial checkpoints (e.g. encoder-only).
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            # Default scheme: Kaiming init for convs, constant 1 for norms.
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (dict): Output dict of FbaEncoder. Must contain keys
                'conv_out' (list of feature maps, indexed from the end
                below), 'merged' (the input image) and 'two_channel_trimap'.

        Returns:
            Tensor: Predicted alpha, fg and bg of the current batch.
        """
        conv_out = inputs['conv_out']
        img = inputs['merged']
        two_channel_trimap = inputs['two_channel_trimap']
        # Deepest encoder feature drives the Pyramid Pooling Module.
        conv5 = conv_out[-1]
        input_size = conv5.size()
        ppm_out = [conv5]
        # Each PPM branch output is upsampled back to conv5's spatial size
        # before channel-wise concatenation.
        for pool_scale in self.ppm:
            ppm_out.append(
                nn.functional.interpolate(
                    pool_scale(conv5), (input_size[2], input_size[3]),
                    mode='bilinear',
                    align_corners=self.align_corners))
        ppm_out = torch.cat(ppm_out, 1)
        x = self.conv_up1(ppm_out)

        # Three upsample-by-2 stages, each concatenating a progressively
        # shallower encoder skip connection (conv_out[-4], [-5], [-6]).
        x = torch.nn.functional.interpolate(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)
        x = torch.cat((x, conv_out[-4]), 1)
        x = self.conv_up2(x)

        x = torch.nn.functional.interpolate(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)
        x = torch.cat((x, conv_out[-5]), 1)
        x = self.conv_up3(x)

        x = torch.nn.functional.interpolate(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)
        # Only the first 3 channels of conv_out[-6] are used, alongside the
        # raw image and the two-channel trimap (matches conv_up4's 64+3+3+2
        # input width).
        x = torch.cat((x, conv_out[-6][:, :3], img, two_channel_trimap), 1)
        output = self.conv_up4(x)

        # Channel layout of the 7-channel head: [0:1]=alpha, [1:4]=fg,
        # [4:7]=bg. Alpha is clamped to [0, 1]; fg/bg squashed via sigmoid.
        alpha = torch.clamp(output[:, 0:1], 0, 1)
        F = torch.sigmoid(output[:, 1:4])
        B = torch.sigmoid(output[:, 4:7])
        return alpha, F, B
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.