You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMEditing 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. See the changelog, code and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.restorers.esrgan

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..common import set_requires_grad
from ..registry import MODELS
from .srgan import SRGAN

@MODELS.register_module()
class ESRGAN(SRGAN):
    """Enhanced SRGAN model for single image super-resolution.

    Ref:
    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
    It uses RaGAN for GAN updates:
    The relativistic discriminator: a key element missing from standard GAN.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict): Config for the discriminator. Default: None.
        gan_loss (dict): Config for the gan loss.
            Note that the loss weight in gan loss is only for the generator.
        pixel_loss (dict): Config for the pixel loss. Default: None.
        perceptual_loss (dict): Config for the perceptual loss. Default: None.
        train_cfg (dict): Config for training. Default: None.
            You may change the training of gan by setting:
            `disc_steps`: how many discriminator updates after one generator
            update;
            `disc_init_steps`: how many discriminator updates at the start of
            the training.
            These two keys are useful when training with WGAN.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    def train_step(self, data_batch, optimizer):
        """Train step.

        Runs one optimisation iteration: an optional generator update
        followed by a discriminator update. Both use relativistic losses,
        i.e. each prediction is offset by the mean prediction of the opposite
        (real/fake) batch before being passed to ``self.gan_loss``.

        Args:
            data_batch (dict): A batch of data. Must provide the keys
                ``'lq'`` (generator input) and ``'gt'`` (target).
            optimizer (obj): Optimizer. Indexed with the keys
                ``'generator'`` and ``'discriminator'``.

        Returns:
            dict: Returned output with the keys ``log_vars`` (logged loss
            values without the aggregated ``'loss'`` entry), ``num_samples``
            (length of ``gt``) and ``results`` (``lq``, ``gt`` and the
            generator output moved to CPU).
        """
        # data
        lq = data_batch['lq']
        gt = data_batch['gt']

        # generator
        fake_g_output = self.generator(lq)

        losses = dict()
        log_vars = dict()

        # no updates to discriminator parameters.
        set_requires_grad(self.discriminator, False)

        # The generator is only optimised once every `disc_steps` iterations
        # and not before `disc_init_steps` iterations have elapsed.
        # NOTE(review): `step_counter`, `disc_steps` and `disc_init_steps`
        # are not set in this class - presumably provided by SRGAN.
        if (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps):
            if self.pixel_loss:
                losses['loss_pix'] = self.pixel_loss(fake_g_output, gt)
            if self.perceptual_loss:
                # Either component may be returned as None and is then
                # left out of the loss dict.
                loss_percep, loss_style = self.perceptual_loss(
                    fake_g_output, gt)
                if loss_percep is not None:
                    losses['loss_perceptual'] = loss_percep
                if loss_style is not None:
                    losses['loss_style'] = loss_style

            # gan loss for generator
            # The real prediction is detached: only the fake branch carries
            # gradient back to the generator.
            real_d_pred = self.discriminator(gt).detach()
            fake_g_pred = self.discriminator(fake_g_output)
            loss_gan_fake = self.gan_loss(
                fake_g_pred - torch.mean(real_d_pred),
                target_is_real=True,
                is_disc=False)
            loss_gan_real = self.gan_loss(
                real_d_pred - torch.mean(fake_g_pred),
                target_is_real=False,
                is_disc=False)
            losses['loss_gan'] = (loss_gan_fake + loss_gan_real) / 2

            # parse loss
            loss_g, log_vars_g = self.parse_losses(losses)
            log_vars.update(log_vars_g)

            # optimize
            optimizer['generator'].zero_grad()
            loss_g.backward()
            optimizer['generator'].step()

        # discriminator
        set_requires_grad(self.discriminator, True)
        # real
        # The fake prediction only serves as a constant offset here, hence
        # the detach.
        fake_d_pred = self.discriminator(fake_g_output).detach()
        real_d_pred = self.discriminator(gt)
        loss_d_real = self.gan_loss(
            real_d_pred - torch.mean(fake_d_pred),
            target_is_real=True,
            is_disc=True
        ) * 0.5  # 0.5 for averaging loss_d_real and loss_d_fake
        loss_d, log_vars_d = self.parse_losses(dict(loss_d_real=loss_d_real))
        optimizer['discriminator'].zero_grad()
        loss_d.backward()
        log_vars.update(log_vars_d)
        # fake
        # The generator output is detached so that this backward pass does
        # not propagate into the generator.
        fake_d_pred = self.discriminator(fake_g_output.detach())
        loss_d_fake = self.gan_loss(
            fake_d_pred - torch.mean(real_d_pred.detach()),
            target_is_real=False,
            is_disc=True
        ) * 0.5  # 0.5 for averaging loss_d_real and loss_d_fake
        loss_d, log_vars_d = self.parse_losses(dict(loss_d_fake=loss_d_fake))
        # Gradients of the real and fake halves accumulate; a single
        # optimizer step below applies both.
        loss_d.backward()
        log_vars.update(log_vars_d)

        optimizer['discriminator'].step()

        self.step_counter += 1

        log_vars.pop('loss')  # remove the unnecessary 'loss'
        outputs = dict(
            log_vars=log_vars,
            # Number of samples in this batch (length of the first dimension
            # of `gt`; assumes `gt` is a batched tensor).
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu()))

        return outputs
