Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated at the end of 2022. We recommend that you upgrade to MMEditing 1.0 to enjoy the rich new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.backbones.sr_backbones.tof

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import load_checkpoint

from mmedit.models.common import flow_warp
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger


class BasicModule(nn.Module):
    """Per-level flow estimator used inside SPyNet.

    The module is a stack of five 7x7 convolutions (stride 1, padding 3)
    whose channel widths go 8 -> 32 -> 64 -> 32 -> 16 -> 2.

    Note that unlike the common SPyNet architecture, the first four
    convolutions here are each followed by batch normalization and ReLU.
    The final convolution has neither normalization nor activation.
    """

    def __init__(self):
        super().__init__()

        # Channel width at every layer boundary, input first, output last.
        widths = (8, 32, 64, 32, 16, 2)
        final_idx = len(widths) - 2

        convs = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Only the last layer is a plain convolution.
            is_final = idx == final_idx
            convs.append(
                ConvModule(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=7,
                    stride=1,
                    padding=3,
                    norm_cfg=None if is_final else dict(type='BN'),
                    act_cfg=None if is_final else dict(type='ReLU')))

        # Keep the attribute name and layer order so checkpoint keys
        # (``basic_module.<index>...``) stay the same.
        self.basic_module = nn.Sequential(*convs)

    def forward(self, tensor_input):
        """Run the convolution stack on the concatenated input.

        Args:
            tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
                The 8 channels are laid out as:
                [reference image (3), neighbor image (3), initial flow (2)].

        Returns:
            Tensor: Estimated flow with shape (b, 2, h, w)
        """
        flow = self.basic_module(tensor_input)
        return flow


class SPyNet(nn.Module):
    """Coarse-to-fine SPyNet flow estimator, in the variant used by TOFlow.

    Compared with the common SPyNet, this implementation differs in that:
        1. Its basic modules contain BatchNorm.
        2. It performs no normalization or denormalization itself, as
            they are done in TOFlow.
    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network
    Code reference:
        https://github.com/Coldog2333/pytoflow
    """

    def __init__(self):
        super().__init__()

        # One estimator per pyramid level, ordered coarsest to finest.
        self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)])

    def forward(self, ref, supp):
        """Estimate the flow over a four-level image pyramid.

        Args:
            ref (Tensor): Reference image with shape of (b, 3, h, w).
            supp: The supporting image to be warped: (b, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (b, 2, h, w).
        """
        num_batches, _, h, w = ref.size()

        # Build both pyramids from fine to coarse (three 2x average
        # poolings each), then flip them so index 0 is the coarsest level.
        ref_levels = [ref]
        supp_levels = [supp]
        for _ in range(3):
            ref_levels.append(
                F.avg_pool2d(
                    ref_levels[-1],
                    kernel_size=2,
                    stride=2,
                    count_include_pad=False))
            supp_levels.append(
                F.avg_pool2d(
                    supp_levels[-1],
                    kernel_size=2,
                    stride=2,
                    count_include_pad=False))
        ref_levels.reverse()
        supp_levels.reverse()

        # Start from a zero flow at half the coarsest resolution; the first
        # 2x upsampling in the loop brings it to the coarsest level.
        flow = ref_levels[0].new_zeros(num_batches, 2, h // 16, w // 16)
        for estimator, ref_lvl, supp_lvl in zip(self.basic_module,
                                                ref_levels, supp_levels):
            # Upsample the previous flow and double its values to match
            # the doubled resolution.
            upsampled = F.interpolate(
                input=flow,
                scale_factor=2,
                mode='bilinear',
                align_corners=True)
            flow_up = upsampled * 2.0

            # Warp the supporting frame with the current flow, then let the
            # estimator predict a residual that is added to that flow.
            warped_supp = flow_warp(supp_lvl, flow_up.permute(0, 2, 3, 1))
            residual = estimator(
                torch.cat([ref_lvl, warped_supp, flow_up], 1))
            flow = flow_up + residual
        return flow


@BACKBONES.register_module()
class TOFlow(nn.Module):
    """PyTorch implementation of TOFlow.

    In TOFlow, the LR frames are pre-upsampled and have the same size with
    the GT frames.

    Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018
    Code reference:
        1. https://github.com/anchen1011/toflow
        2. https://github.com/Coldog2333/pytoflow

    Args:
        adapt_official_weights (bool): Whether to adapt the weights
            translated from the official implementation. Set to false if
            you want to train from scratch. Default: False
    """

    def __init__(self, adapt_official_weights=False):
        super().__init__()

        self.adapt_official_weights = adapt_official_weights
        # With the official weights the frames are reordered in ``forward``
        # so that the reference frame comes first; otherwise the reference
        # is the middle frame (index 3 of 7).
        self.ref_idx = 0 if adapt_official_weights else 3

        # The mean and std are for img with range (0, 1)
        self.register_buffer(
            'mean',
            torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer(
            'std',
            torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # flow estimation module
        self.spynet = SPyNet()

        # reconstruction module
        self.conv1 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
        self.conv2 = nn.Conv2d(64, 64, 9, 1, 4)
        self.conv3 = nn.Conv2d(64, 64, 1)
        self.conv4 = nn.Conv2d(64, 3, 1)

        # activation function
        self.relu = nn.ReLU(inplace=True)

    def normalize(self, img):
        """Normalize the input image.

        Args:
            img (Tensor): Input image.

        Returns:
            Tensor: Normalized image.
        """
        return (img - self.mean) / self.std

    def denormalize(self, img):
        """Denormalize the output image.

        Args:
            img (Tensor): Output image.

        Returns:
            Tensor: Denormalized image.
        """
        return img * self.std + self.mean

    def forward(self, lrs):
        """Align the supporting frames to the reference and reconstruct.

        Args:
            lrs: Input lr frames: (b, 7, 3, h, w).

        Returns:
            Tensor: SR frame: (b, 3, h, w).
        """
        # In the official implementation, the 0-th frame is the reference
        # frame
        if self.adapt_official_weights:
            lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]

        num_batches, num_lrs, _, h, w = lrs.size()

        # ``reshape`` rather than ``view``: ``view`` raises on
        # non-contiguous input (e.g. a permuted batch), while ``reshape``
        # is identical for contiguous tensors and copies only when needed.
        lrs = self.normalize(lrs.reshape(-1, 3, h, w))
        lrs = lrs.reshape(num_batches, num_lrs, 3, h, w)

        lr_ref = lrs[:, self.ref_idx, :, :, :]
        lr_aligned = []
        for i in range(7):  # 7 frames
            if i == self.ref_idx:
                # The reference frame needs no alignment.
                lr_aligned.append(lr_ref)
            else:
                lr_supp = lrs[:, i, :, :, :]
                flow = self.spynet(lr_ref, lr_supp)
                lr_aligned.append(
                    flow_warp(lr_supp, flow.permute(0, 2, 3, 1)))

        # reconstruction
        hr = torch.stack(lr_aligned, dim=1)
        # Fold the frame axis into the channel axis: (b, 7 * 3, h, w).
        hr = hr.view(num_batches, -1, h, w)
        hr = self.relu(self.conv1(hr))
        hr = self.relu(self.conv2(hr))
        hr = self.relu(self.conv3(hr))
        # The network predicts a residual on top of the reference frame.
        hr = self.conv4(hr) + lr_ref

        return self.denormalize(hr)

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether strictly load the pretrained
                model. Defaults to True.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is None:
            pass  # use default initialization
        else:
            raise TypeError('"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.