Source code for mmedit.models.backbones.sr_backbones.real_basicvsr_net

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint

from mmedit.models.backbones.sr_backbones.basicvsr_net import (
    BasicVSRNet, ResidualBlocksWithInputConv)
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger

[docs]@BACKBONES.register_module() class RealBasicVSRNet(nn.Module): """RealBasicVSR network structure for real-world video super-resolution. Support only x4 upsampling. Paper: Investigating Tradeoffs in Real-World Video Super-Resolution, arXiv Args: mid_channels (int, optional): Channel number of the intermediate features. Default: 64. num_propagation_blocks (int, optional): Number of residual blocks in each propagation branch. Default: 20. num_cleaning_blocks (int, optional): Number of residual blocks in the image cleaning module. Default: 20. dynamic_refine_thres (int, optional): Stop cleaning the images when the residue is smaller than this value. Default: 255. spynet_pretrained (str, optional): Pre-trained model path of SPyNet. Default: None. is_fix_cleaning (bool, optional): Whether to fix the weights of the image cleaning module during training. Default: False. is_sequential_cleaning (bool, optional): Whether to clean the images sequentially. This is used to save GPU memory, but the speed is slightly slower. Default: False. """ def __init__(self, mid_channels=64, num_propagation_blocks=20, num_cleaning_blocks=20, dynamic_refine_thres=255, spynet_pretrained=None, is_fix_cleaning=False, is_sequential_cleaning=False): super().__init__() self.dynamic_refine_thres = dynamic_refine_thres / 255. self.is_sequential_cleaning = is_sequential_cleaning # image cleaning module self.image_cleaning = nn.Sequential( ResidualBlocksWithInputConv(3, mid_channels, num_cleaning_blocks), nn.Conv2d(mid_channels, 3, 3, 1, 1, bias=True), ) if is_fix_cleaning: # keep the weights of the cleaning module fixed self.image_cleaning.requires_grad_(False) # BasicVSR self.basicvsr = BasicVSRNet(mid_channels, num_propagation_blocks, spynet_pretrained) self.basicvsr.spynet.requires_grad_(False)
[docs] def forward(self, lqs, return_lqs=False): n, t, c, h, w = lqs.size() for _ in range(0, 3): # at most 3 cleaning, determined empirically if self.is_sequential_cleaning: residues = [] for i in range(0, t): residue_i = self.image_cleaning(lqs[:, i, :, :, :]) lqs[:, i, :, :, :] += residue_i residues.append(residue_i) residues = torch.stack(residues, dim=1) else: # time -> batch, then apply cleaning at once lqs = lqs.view(-1, c, h, w) residues = self.image_cleaning(lqs) lqs = (lqs + residues).view(n, t, c, h, w) # determine whether to continue cleaning if torch.mean(torch.abs(residues)) < self.dynamic_refine_thres: break # Super-resolution (BasicVSR) outputs = self.basicvsr(lqs) if return_lqs: return outputs, lqs else: return outputs
[docs] def init_weights(self, pretrained=None, strict=True): """Init weights for models. Args: pretrained (str, optional): Path for pretrained weights. If given None, pretrained weights will not be loaded. Defaults: None. strict (boo, optional): Whether strictly load the pretrained model. Defaults to True. """ if isinstance(pretrained, str): logger = get_root_logger() load_checkpoint(self, pretrained, strict=strict, logger=logger) elif pretrained is not None: raise TypeError(f'"pretrained" must be a str or None. ' f'But received {type(pretrained)}.')
Read the Docs v: v0.13.0
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.