Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMEditing 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.components.discriminators.deepfill_disc

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import normal_init
from mmcv.runner import load_checkpoint

from mmedit.models import build_component
from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger


@COMPONENTS.register_module()
class DeepFillv1Discriminators(nn.Module):
    """Discriminators used in DeepFillv1 model.

    In DeepFillv1 model, the discriminators are independent without any
    concatenation like Global&Local model. Thus, we call this model
    `DeepFillv1Discriminators`. There exists a global discriminator and a
    local discriminator with global and local input respectively.

    The details can be found in:
    Generative Image Inpainting with Contextual Attention.

    Args:
        global_disc_cfg (dict): Config dict for global discriminator.
        local_disc_cfg (dict): Config dict for local discriminator.
    """

    def __init__(self, global_disc_cfg, local_disc_cfg):
        super().__init__()
        # Both sub-discriminators are built from their config dicts and kept
        # as separate sub-modules; their outputs are never merged here.
        self.global_disc = build_component(global_disc_cfg)
        self.local_disc = build_component(local_disc_cfg)

    def forward(self, x):
        """Forward function.

        Args:
            x (tuple[torch.Tensor]): Contains global image and the local
                image patch, in that order.

        Returns:
            tuple[torch.Tensor]: Contains the prediction from discriminators
                in global image and local image patch.
        """
        global_img, local_img = x

        global_pred = self.global_disc(global_img)
        local_pred = self.local_disc(local_img)

        return global_pred, local_pred

    def init_weights(self, pretrained=None):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded and every
                ``nn.Linear`` / ``nn.Conv2d`` sub-module is initialized with
                ``normal_init`` (mean 0, std 0.02) instead.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            # The checkpoint is loaded with strict=False, as in the original.
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                # Linear and Conv2d layers share the same initialization.
                if isinstance(m, (nn.Linear, nn.Conv2d)):
                    normal_init(m, 0, std=0.02)
        else:
            # The trailing space after "got" keeps the type name from being
            # glued to the preceding word in the rendered message.
            raise TypeError('pretrained must be a str or None but got '
                            f'{type(pretrained)} instead.')
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.