
Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMEditing 1.0 for the new features and better performance brought by OpenMMLab 2.0. See the changelog, code, and documentation of MMEditing 1.0 for details.

Source code for mmedit.models.backbones.encoder_decoders.necks.contextual_attention_neck

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmedit.models.common import SimpleGatedConvModule
from mmedit.models.common.contextual_attention import ContextualAttentionModule
from mmedit.models.registry import COMPONENTS


@COMPONENTS.register_module()
class ContextualAttentionNeck(nn.Module):
    """Neck with contextual attention module.

    Args:
        in_channels (int): The number of input channels.
        conv_type (str): The type of conv module. In the DeepFillv1 model,
            `conv_type` should be 'conv'. In the DeepFillv2 model,
            `conv_type` should be 'gated_conv'.
        conv_cfg (dict | None): Config of conv module. Default: None.
        norm_cfg (dict | None): Config of norm module. Default: None.
        act_cfg (dict | None): Config of activation layer.
            Default: dict(type='ELU').
        contextual_attention_args (dict): Config of contextual attention
            module. Default: dict(softmax_scale=10.).
        **kwargs: Other keyword arguments passed to the conv modules.
    """

    _conv_type = dict(conv=ConvModule, gated_conv=SimpleGatedConvModule)

    def __init__(self,
                 in_channels,
                 conv_type='conv',
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ELU'),
                 contextual_attention_args=dict(softmax_scale=10.),
                 **kwargs):
        super().__init__()
        self.contextual_attention = ContextualAttentionModule(
            **contextual_attention_args)
        conv_module = self._conv_type[conv_type]
        self.conv1 = conv_module(
            in_channels,
            in_channels,
            3,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)
        self.conv2 = conv_module(
            in_channels,
            in_channels,
            3,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)
    def forward(self, x, mask):
        """Forward function.

        Args:
            x (torch.Tensor): Input tensor with shape of (n, c, h, w).
            mask (torch.Tensor): Input tensor with shape of (n, 1, h, w).

        Returns:
            tuple[torch.Tensor]: Output tensor with shape of (n, c, h', w')
                and the offset tensor computed by the contextual attention
                module.
        """
        x, offset = self.contextual_attention(x, x, mask)
        x = self.conv1(x)
        x = self.conv2(x)
        return x, offset
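
For illustration, below is a minimal usage sketch of this neck. The channel count, spatial size, and hole location are assumed example values, not requirements of the module; the mask convention (1 marking the missing region) follows the DeepFill inpainting setup.

import torch

from mmedit.models.backbones.encoder_decoders.necks.contextual_attention_neck import \
    ContextualAttentionNeck

# Build the neck as used in DeepFillv2 ('gated_conv'); 96 channels and
# 64x64 feature maps are assumed values for this sketch.
neck = ContextualAttentionNeck(in_channels=96, conv_type='gated_conv')

feat = torch.rand(1, 96, 64, 64)  # encoder features (n, c, h, w)
mask = torch.zeros(1, 1, 64, 64)  # hole mask (n, 1, h, w); 1 = missing
mask[..., 16:48, 16:48] = 1.

out, offset = neck(feat, mask)  # out keeps the input channel count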