You are reading the documentation for MMEditing 0.x, which will be deprecated at the end of 2022. We recommend upgrading to MMEditing 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. See the changelog, code, and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.backbones.sr_backbones.real_basicvsr_net

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint

from mmedit.models.backbones.sr_backbones.basicvsr_net import (
    BasicVSRNet, ResidualBlocksWithInputConv)
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger

@BACKBONES.register_module()
class RealBasicVSRNet(nn.Module):
    """RealBasicVSR network structure for real-world video super-resolution.

    Support only x4 upsampling.
    Paper:
        Investigating Tradeoffs in Real-World Video Super-Resolution, arXiv

    The network first applies an image cleaning module to the input frames
    (up to three times, see :meth:`forward`) and then feeds the cleaned
    frames to a BasicVSR network.

    Args:
        mid_channels (int, optional): Channel number of the intermediate
            features. Default: 64.
        num_propagation_blocks (int, optional): Number of residual blocks in
            each propagation branch. Default: 20.
        num_cleaning_blocks (int, optional): Number of residual blocks in the
            image cleaning module. Default: 20.
        dynamic_refine_thres (int | float, optional): Stop cleaning the
            images when the residue is smaller than this value. The value is
            divided by 255 internally before it is compared with the mean
            absolute residue. Default: 255.
        spynet_pretrained (str, optional): Pre-trained model path of SPyNet.
            Default: None.
        is_fix_cleaning (bool, optional): Whether to fix the weights of
            the image cleaning module during training. Default: False.
        is_sequential_cleaning (bool, optional): Whether to clean the images
            sequentially. This is used to save GPU memory, but the speed is
            slightly slower. Default: False.
    """

    def __init__(self,
                 mid_channels=64,
                 num_propagation_blocks=20,
                 num_cleaning_blocks=20,
                 dynamic_refine_thres=255,
                 spynet_pretrained=None,
                 is_fix_cleaning=False,
                 is_sequential_cleaning=False):

        super().__init__()

        # the threshold is stored after division by 255 so that it can be
        # compared directly with the mean absolute residue in ``forward``
        self.dynamic_refine_thres = dynamic_refine_thres / 255.
        self.is_sequential_cleaning = is_sequential_cleaning

        # image cleaning module
        self.image_cleaning = nn.Sequential(
            ResidualBlocksWithInputConv(3, mid_channels, num_cleaning_blocks),
            nn.Conv2d(mid_channels, 3, 3, 1, 1, bias=True),
        )

        if is_fix_cleaning:  # keep the weights of the cleaning module fixed
            self.image_cleaning.requires_grad_(False)

        # BasicVSR
        self.basicvsr = BasicVSRNet(mid_channels, num_propagation_blocks,
                                    spynet_pretrained)
        # the optical flow network inside BasicVSR is never trained here
        self.basicvsr.spynet.requires_grad_(False)

    def forward(self, lqs, return_lqs=False):
        """Clean the input frames, then super-resolve them with BasicVSR.

        The cleaning module predicts a residue that is added to the frames.
        Cleaning is repeated at most three times and stops early once the
        mean absolute residue drops below ``self.dynamic_refine_thres``.

        The input tensor is never modified in place; the cleaned frames are
        always stored in a new tensor.

        Args:
            lqs (Tensor): Input low-quality sequence with shape
                (n, t, c, h, w). The cleaning module is built with 3 as its
                first argument, so c is presumably expected to be 3.
            return_lqs (bool, optional): Whether to also return the cleaned
                sequence. Default: False.

        Returns:
            Tensor | tuple[Tensor, Tensor]: The output of the BasicVSR
            network. If ``return_lqs`` is True, a tuple
            ``(outputs, cleaned_lqs)`` is returned, where ``cleaned_lqs``
            has shape (n, t, c, h, w).
        """
        n, t, c, h, w = lqs.size()

        for _ in range(3):  # at most 3 cleaning, determined empirically
            if self.is_sequential_cleaning:
                # Clean one frame at a time. The cleaned frames are collected
                # out of place: an in-place ``+=`` on a slice of ``lqs``
                # would overwrite the caller's tensor and would also modify
                # a tensor that was just used as input of the cleaning
                # module, which autograd may still need for backward.
                residues = []
                cleaned = []
                for i in range(t):
                    frame = lqs[:, i, :, :, :]
                    residue_i = self.image_cleaning(frame)
                    cleaned.append(frame + residue_i)
                    residues.append(residue_i)
                residues = torch.stack(residues, dim=1)
                lqs = torch.stack(cleaned, dim=1)
            else:  # time -> batch, then apply cleaning at once
                # ``reshape`` (unlike ``view``) also accepts non-contiguous
                # inputs, e.g. sequences produced by ``permute`` or slicing
                flat_lqs = lqs.reshape(-1, c, h, w)
                residues = self.image_cleaning(flat_lqs)
                lqs = (flat_lqs + residues).reshape(n, t, c, h, w)

            # determine whether to continue cleaning
            if torch.mean(torch.abs(residues)) < self.dynamic_refine_thres:
                break

        # Super-resolution (BasicVSR)
        outputs = self.basicvsr(lqs)

        if return_lqs:
            return outputs, lqs
        return outputs

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to
                None.
            strict (bool, optional): Whether strictly load the pretrained
                model. Defaults to True.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is not None:
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
