Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated at the end of 2022. We recommend upgrading to MMEditing 1.0 to benefit from the many new features and the better performance brought by OpenMMLab 2.0. See the changelog, code, and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.video_interpolators.cain

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..registry import MODELS
from .basic_interpolator import BasicInterpolator


@MODELS.register_module()
class CAIN(BasicInterpolator):
    """CAIN model for Video Interpolation.

    Paper: Channel Attention Is All You Need for Video Frame Interpolation
    Ref repo: https://github.com/myungsub/CAIN

    Args:
        generator (dict): Config for the generator structure.
        pixel_loss (dict): Config for pixel-wise loss.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    def forward_train(self, inputs, target):
        """Training forward function.

        Runs the generator on the input frames, computes the pixel-wise loss
        against the target frame and packs everything into a dict.

        Args:
            inputs (Tensor): Tensor of inputs frames with shape
                (n, 2, c, h, w).
            target (Tensor): Tensor of target frame with shape (n, c, h, w).

        Returns:
            dict: A dict with the keys

            - ``losses`` (dict): Contains ``loss_pix``, the pixel-wise loss.
            - ``num_samples`` (int): ``len(target.data)``.
            - ``results`` (dict): CPU copies of ``inputs``, ``target`` and
              the generator ``output``.
        """
        # The generator is called with padding disabled during training.
        # NOTE(review): presumably training inputs already have a size the
        # generator accepts -- confirm against the generator implementation.
        output = self.generator(inputs, padding_flag=False)

        losses = dict()
        losses['loss_pix'] = self.pixel_loss(output, target)

        outputs = dict(
            losses=losses,
            num_samples=len(target.data),
            results=dict(
                inputs=inputs.cpu(),
                target=target.cpu(),
                output=output.cpu()))
        return outputs

    def forward_test(self,
                     inputs,
                     target=None,
                     meta=None,
                     save_image=False,
                     save_path=None,
                     iteration=None):
        """Testing forward function.

        Args:
            inputs (Tensor): The input Tensor with shape (n, 2, c, h, w).
            target (Tensor): The target Tensor with shape (n, c, h, w).
                Required when ``test_cfg`` requests metrics. Default: None.
            meta (list[dict]): Meta data, such as path of target file.
                Default: None.
            save_image (bool): Whether to save image. Default: False.
            save_path (str): Path to save image. Default: None.
            iteration (int): Iteration for the saving image name.
                Default: None.

        Returns:
            dict: Output results, which contain either key(s)

            1. 'eval_result' (when ``test_cfg`` has non-empty 'metrics').
            2. 'inputs', 'output'.
            3. 'inputs', 'output', and 'target' (when ``target`` is given).

        Raises:
            AssertionError: If metrics are requested but ``target`` is None.
        """
        # Inference only: no gradients are tracked, and the prediction is
        # clamped to the [0, 1] range.
        with torch.no_grad():
            pred = self.generator(inputs, padding_flag=True)
            pred = pred.clamp(0, 1)

        if self.test_cfg is not None and self.test_cfg.get('metrics', None):
            # Evaluation mode: only the metric values are returned.
            assert target is not None, (
                'evaluation with metrics must have target images.')
            results = dict(eval_result=self.evaluate(pred, target))
        else:
            # Inference mode: return CPU copies of the tensors; note that the
            # prediction is stored under the key 'output'.
            results = dict(inputs=inputs.cpu(), output=pred.cpu())
            if target is not None:
                results['target'] = target.cpu()

        # save image
        if save_image:
            self._save_image(meta, iteration, save_path, pred)

        return results
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.