Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated at the end of 2022. We recommend upgrading to MMEditing 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. See the changelog, code and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.common.aspp

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from torch import nn
from torch.nn import functional as F

from .separable_conv_module import DepthwiseSeparableConvModule


class ASPPPooling(nn.Sequential):
    """Global-pooling branch: pool to 1x1, apply a 1x1 conv, resize back.

    The module is a two-layer ``nn.Sequential`` (adaptive average pooling
    followed by a 1x1 ``ConvModule``) whose ``forward`` additionally
    upsamples the result to the spatial size of the input.

    Args:
        in_channels (int): Input channels of the 1x1 conv module.
        out_channels (int): Output channels of the 1x1 conv module.
        conv_cfg (dict): Forwarded to ``ConvModule`` as ``conv_cfg``.
        norm_cfg (dict): Forwarded to ``ConvModule`` as ``norm_cfg``.
        act_cfg (dict): Forwarded to ``ConvModule`` as ``act_cfg``.
    """

    def __init__(self, in_channels, out_channels, conv_cfg, norm_cfg, act_cfg):
        # Collapse every channel to a single spatial location.
        global_pool = nn.AdaptiveAvgPool2d(1)
        # Kernel size 1: mixes channels of the pooled descriptor only.
        pointwise_conv = ConvModule(
            in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        super().__init__(global_pool, pointwise_conv)

    def forward(self, x):
        """Run the pooled branch and resize it to the input's last two dims.

        Args:
            x (Tensor): Input tensor; its last two dimensions are used as
                the target size of the interpolation.

        Returns:
            Tensor: Branch output, bilinearly resized to ``x.shape[-2:]``.
        """
        # Remember the spatial size before pooling discards it.
        spatial_size = x.shape[-2:]
        # nn.Sequential.forward applies the registered layers in order.
        pooled = super().forward(x)
        return F.interpolate(
            pooled, size=spatial_size, mode='bilinear', align_corners=False)


class ASPP(nn.Module):
    """ASPP module from DeepLabV3.

    The code is adopted from
    https://github.com/pytorch/vision/blob/master/torchvision/models/
    segmentation/deeplabv3.py

    For more information about the module:
    `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

    The module runs several branches on the same input (one 1x1 conv, one
    3x3 dilated conv per entry of ``dilations``, and one global-pooling
    branch), concatenates their outputs along the channel dimension and
    projects the result with a 1x1 conv followed by dropout.

    Args:
        in_channels (int): Input channels of the module.
        out_channels (int): Output channels of the module. Default: 256.
        mid_channels (int): Output channels of the intermediate ASPP conv
            modules. Default: 256.
        dilations (Sequence[int]): Dilation rates of the ASPP conv modules;
            one dilated 3x3 branch is created per rate.
            Default: (12, 24, 36).
        conv_cfg (dict): Config dict for convolution layer. If "None",
            nn.Conv2d will be applied. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        separable_conv (bool): Whether replace normal conv with depthwise
            separable conv which is faster. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels=256,
                 mid_channels=256,
                 dilations=(12, 24, 36),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 separable_conv=False):
        super().__init__()

        # Only the dilated 3x3 branches switch to the separable variant;
        # the 1x1 branch and the projection always use ConvModule.
        if separable_conv:
            conv_module = DepthwiseSeparableConvModule
        else:
            conv_module = ConvModule

        # Branch 1: plain 1x1 convolution.
        modules = [
            ConvModule(
                in_channels,
                mid_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ]

        # One 3x3 branch per dilation rate, with padding equal to the
        # dilation.
        for dilation in dilations:
            modules.append(
                conv_module(
                    in_channels,
                    mid_channels,
                    3,
                    padding=dilation,
                    dilation=dilation,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

        # Last branch: global pooling, resized back to the input size.
        modules.append(
            ASPPPooling(in_channels, mid_channels, conv_cfg, norm_cfg,
                        act_cfg))

        self.convs = nn.ModuleList(modules)

        # Every branch emits ``mid_channels`` channels, so the projection
        # input width follows the actual number of branches instead of
        # assuming exactly three dilation rates (the former hard-coded 5).
        self.project = nn.Sequential(
            ConvModule(
                len(self.convs) * mid_channels,
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            nn.Dropout(0.5))

    def forward(self, x):
        """Forward function for ASPP module.

        Args:
            x (Tensor): Input tensor with shape (N, C, H, W).

        Returns:
            Tensor: Output tensor, i.e. the projection of all branch
            outputs concatenated along dim 1.
        """
        branch_outputs = [conv(x) for conv in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.