Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend that you upgrade to MMEditing 1.0 to enjoy the many new features and the better performance brought by OpenMMLab 2.0. Check out the changelog, code and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.common.gated_conv_module

# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer


class SimpleGatedConvModule(nn.Module):
    """Simple Gated Convolutional Module.

    This module is a simple gated convolutional module. The detailed formula
    is:

    .. math::
        y = \\phi(conv1(x)) * \\sigma(conv2(x)),

    where `phi` is the feature activation function and `sigma` is the gate
    activation function. By default, the gate activation function is sigmoid.

    The two convolutions in the formula are implemented as a single
    ``ConvModule`` producing ``out_channels * 2`` channels; its output is
    split in half along the channel dimension, the first half being the
    feature branch and the second half the gate branch.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): The number of channels of the output feature.
            Note that `out_channels` in the conv module is doubled since
            this module contains two convolutions for feature and gate
            separately.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        feat_act_cfg (dict | None): Config dict for feature activation layer.
            If None, no activation is applied to the feature branch.
        gate_act_cfg (dict | None): Config dict for gate activation layer.
            If None, no activation is applied to the gate branch.
        kwargs (keyword arguments): Same as `ConvModule`. Any `act_cfg`
            given here is overridden with None, because the activations are
            applied by this module after the channel split.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 feat_act_cfg=dict(type='ELU'),
                 gate_act_cfg=dict(type='Sigmoid'),
                 **kwargs):
        super().__init__()
        # The activation functions must be specified outside the conv module,
        # so work on a private copy of the kwargs and disable its activation.
        kwargs_ = copy.deepcopy(kwargs)
        kwargs_['act_cfg'] = None

        self.with_feat_act = feat_act_cfg is not None
        self.with_gate_act = gate_act_cfg is not None

        # A single conv computes both the feature and the gate branch, hence
        # the doubled number of output channels.
        self.conv = ConvModule(in_channels, out_channels * 2, kernel_size,
                               **kwargs_)

        if self.with_feat_act:
            self.feat_act = build_activation_layer(feat_act_cfg)

        if self.with_gate_act:
            self.gate_act = build_activation_layer(gate_act_cfg)

    def forward(self, x):
        """Forward Function.

        Args:
            x (torch.Tensor): Input tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Output tensor with shape of (n, c_out, h', w'),
                where c_out is half of the channels produced by the conv
                (i.e. `out_channels`).
        """
        x = self.conv(x)
        # First half of the channels is the feature, second half is the gate.
        x, gate = torch.split(x, x.size(1) // 2, dim=1)
        if self.with_feat_act:
            x = self.feat_act(x)
        if self.with_gate_act:
            gate = self.gate_act(gate)
        x = x * gate

        return x
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.