Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMEditing 1.0 to benefit from its many new features and the better performance brought by OpenMMLab 2.0. Check out the changelog, code, and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.restorers.srgan

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import auto_fp16

from ..builder import build_backbone, build_component, build_loss
from ..common import set_requires_grad
from ..registry import MODELS
from .basic_restorer import BasicRestorer


@MODELS.register_module()
class SRGAN(BasicRestorer):
    """SRGAN model for single image super-resolution.

    Ref:
    Photo-Realistic Single Image Super-Resolution Using a Generative
    Adversarial Network.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict): Config for the discriminator. Default: None.
            Although optional at construction time, ``train_step`` calls the
            discriminator unconditionally, so training requires one.
        gan_loss (dict): Config for the gan loss. Note that the loss weight
            in gan loss is only for the generator. Default: None. Like the
            discriminator, it is used unconditionally by ``train_step``.
        pixel_loss (dict): Config for the pixel loss. Default: None.
        perceptual_loss (dict): Config for the perceptual loss.
            Default: None.
        train_cfg (dict): Config for training. Default: None.
            You may change the training of gan by setting:
            `disc_steps`: how many discriminator updates after one generator
            update;
            `disc_init_steps`: how many discriminator updates at the start of
            the training.
            These two keys are useful when training with WGAN.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_loss=None,
                 pixel_loss=None,
                 perceptual_loss=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        # Deliberately start the MRO lookup *after* BasicRestorer, so
        # BasicRestorer.__init__ is skipped; the components this model needs
        # are built explicitly below instead.
        super(BasicRestorer, self).__init__()

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # generator
        self.generator = build_backbone(generator)
        # discriminator (a missing or empty config means no discriminator)
        self.discriminator = build_component(
            discriminator) if discriminator else None

        # support fp16: read by the ``auto_fp16`` decorator on ``forward``;
        # casting stays disabled until this flag is switched on.
        self.fp16_enabled = False

        # losses (a missing or empty config leaves the loss as None)
        self.gan_loss = build_loss(gan_loss) if gan_loss else None
        self.pixel_loss = build_loss(pixel_loss) if pixel_loss else None
        self.perceptual_loss = build_loss(
            perceptual_loss) if perceptual_loss else None

        # GAN update schedule used by ``train_step``: the generator is only
        # updated on steps that are multiples of ``disc_steps`` and not
        # before ``disc_init_steps``; the discriminator is updated every step.
        self.disc_steps = 1 if self.train_cfg is None else self.train_cfg.get(
            'disc_steps', 1)
        self.disc_init_steps = (0 if self.train_cfg is None else
                                self.train_cfg.get('disc_init_steps', 0))
        self.step_counter = 0  # counting training steps

        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Init weights for models.

        The same ``pretrained`` argument is forwarded to both the generator
        and (if present) the discriminator.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
        """
        self.generator.init_weights(pretrained=pretrained)
        # Compare against None explicitly: nn.Module truthiness can be False
        # for container modules that define ``__len__`` (e.g. an empty
        # ``nn.Sequential``), which would silently skip initialization.
        if self.discriminator is not None:
            self.discriminator.init_weights(pretrained=pretrained)

    @auto_fp16(apply_to=('lq', ))
    def forward(self, lq, gt=None, test_mode=False, **kwargs):
        """Forward function.

        Only the test-mode path is supported; ``train_step`` drives training
        by calling the generator and discriminator directly.

        Args:
            lq (Tensor): Input lq images.
            gt (Tensor): Ground-truth image. Default: None.
            test_mode (bool): Whether in test mode or not. Default: False.
            kwargs (dict): Other arguments, forwarded to ``forward_test``.

        Returns:
            Whatever the inherited ``forward_test`` returns.

        Raises:
            ValueError: If ``test_mode`` is False.
        """
        if test_mode:
            return self.forward_test(lq, gt, **kwargs)

        raise ValueError(
            'SRGAN model does not support `forward_train` function.')

    def train_step(self, data_batch, optimizer):
        """Train step: an optional generator update, then a discriminator
        update.

        Args:
            data_batch (dict): A batch of data; must contain the keys ``'lq'``
                and ``'gt'``.
            optimizer (dict): Optimizers keyed by ``'generator'`` and
                ``'discriminator'``.

        Returns:
            dict: ``log_vars`` (logged loss values, with the ``'loss'`` entry
            from ``parse_losses`` removed), ``num_samples`` (``len(gt)``) and
            ``results`` (``lq``, ``gt`` and the generator ``output``, all on
            CPU; ``output`` is detached from the autograd graph).
        """
        # data
        lq = data_batch['lq']
        gt = data_batch['gt']

        # generator forward pass (its output is also fed to the
        # discriminator below, even on steps where the generator is not
        # updated)
        fake_g_output = self.generator(lq)

        losses = dict()
        log_vars = dict()
        # no updates to discriminator parameters during the generator step.
        set_requires_grad(self.discriminator, False)

        # Update the generator once every ``disc_steps`` steps, and only
        # after the first ``disc_init_steps`` steps.
        if (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps):
            if self.pixel_loss is not None:
                losses['loss_pix'] = self.pixel_loss(fake_g_output, gt)
            if self.perceptual_loss is not None:
                loss_percep, loss_style = self.perceptual_loss(
                    fake_g_output, gt)
                if loss_percep is not None:
                    losses['loss_perceptual'] = loss_percep
                if loss_style is not None:
                    losses['loss_style'] = loss_style
            # gan loss for generator: push fake predictions towards "real"
            fake_g_pred = self.discriminator(fake_g_output)
            losses['loss_gan'] = self.gan_loss(
                fake_g_pred, target_is_real=True, is_disc=False)

            # parse loss
            loss_g, log_vars_g = self.parse_losses(losses)
            log_vars.update(log_vars_g)

            # optimize
            optimizer['generator'].zero_grad()
            loss_g.backward()
            optimizer['generator'].step()

        # discriminator: the real and fake losses are backpropagated
        # separately; their gradients accumulate between the single
        # zero_grad() and the single step() below.
        set_requires_grad(self.discriminator, True)
        # real
        real_d_pred = self.discriminator(gt)
        loss_d_real = self.gan_loss(
            real_d_pred, target_is_real=True, is_disc=True)
        loss_d, log_vars_d = self.parse_losses(dict(loss_d_real=loss_d_real))
        optimizer['discriminator'].zero_grad()
        loss_d.backward()
        log_vars.update(log_vars_d)
        # fake (detached so the discriminator loss does not reach the
        # generator)
        fake_d_pred = self.discriminator(fake_g_output.detach())
        loss_d_fake = self.gan_loss(
            fake_d_pred, target_is_real=False, is_disc=True)
        loss_d, log_vars_d = self.parse_losses(dict(loss_d_fake=loss_d_fake))
        loss_d.backward()
        log_vars.update(log_vars_d)

        optimizer['discriminator'].step()

        self.step_counter += 1

        log_vars.pop('loss')  # remove the unnecessary 'loss'
        # Detach the generator output before returning it: ``.cpu()`` alone
        # is differentiable and keeps the generator's autograd graph alive.
        # On steps where the generator update is skipped, backward() never
        # runs on that graph, so its saved activations would otherwise stay
        # referenced for as long as the returned results are held.
        outputs = dict(
            log_vars=log_vars,
            num_samples=len(gt.data),
            results=dict(
                lq=lq.cpu(),
                gt=gt.cpu(),
                output=fake_g_output.detach().cpu()))

        return outputs
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.