Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMEditing 1.0 to enjoy the many new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code, and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.backbones.encoder_decoders.gl_encoder_decoder

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner import auto_fp16, load_checkpoint

from mmedit.models.builder import build_component
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger


@BACKBONES.register_module()
class GLEncoderDecoder(nn.Module):
    """Encoder-decoder backbone of the Global&Local inpainting model.

    Follows the paper "Globally and locally Consistent Image Completion".

    Data flows through three stages, in this order:
    (conv2d x 6) --> (dilated conv2d x 4) --> (conv2d or deconv2d x 7)

    Args:
        encoder (dict): Config dict used to build the encoder.
        decoder (dict): Config dict used to build the decoder.
        dilation_neck (dict): Config dict used to build the dilation neck.
    """

    def __init__(self,
                 encoder=dict(type='GLEncoder'),
                 decoder=dict(type='GLDecoder'),
                 dilation_neck=dict(type='GLDilationNeck')):
        super().__init__()

        # Build every component from its config first ...
        encoder_module = build_component(encoder)
        decoder_module = build_component(decoder)
        neck_module = build_component(dilation_neck)

        # ... then attach them. The assignment order (encoder, decoder,
        # dilation_neck) fixes the order in which the submodules are
        # registered on this module, so it is kept as is.
        self.encoder = encoder_module
        self.decoder = decoder_module
        self.dilation_neck = neck_module

        # Flag read by ``auto_fp16``; fp16 is off by default.
        self.fp16_enabled = False

    @auto_fp16()
    def forward(self, x):
        """Run the encoder, the dilation neck and the decoder in sequence.

        Args:
            x (torch.Tensor): Input tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Output tensor with shape of (n, c, h', w').
        """
        encoded = self.encoder(x)
        # When the encoder returns a dict, only its 'out' entry is passed on.
        features = encoded['out'] if isinstance(encoded, dict) else encoded
        return self.decoder(self.dilation_neck(features))

    def init_weights(self, pretrained=None):
        """Initialize the weights, optionally from a checkpoint.

        Args:
            pretrained (str, optional): Path for pretrained weights. When it
                is None, no pretrained weights are loaded.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if pretrained is None:
            # Nothing to load: rely on the default initialization in
            # `ConvModule`.
            return

        if not isinstance(pretrained, str):
            raise TypeError('pretrained must be a str or None')

        load_checkpoint(
            self, pretrained, strict=False, logger=get_root_logger())
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.