Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend you upgrade to MMEditing 1.0 to enjoy the new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.video_interpolators.basic_interpolator

# Copyright (c) OpenMMLab. All rights reserved.
import numbers
import os.path as osp

import mmcv
import numpy as np
import torch
from mmcv.runner import auto_fp16

from mmedit.core import psnr, ssim, tensor2img
from ..base import BaseModel
from ..builder import build_backbone, build_loss
from ..registry import MODELS


@MODELS.register_module()
class BasicInterpolator(BaseModel):
    """Basic model for video frame interpolation.

    It must contain a generator that takes frames as inputs and outputs an
    interpolated frame. It also has a pixel-wise loss for training.

    The subclasses should overwrite the function `forward_train`,
    `forward_test` and `train_step`.

    Args:
        generator (dict): Config for the generator structure.
        pixel_loss (dict): Config for pixel-wise loss.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        required_frames (int): Required frames in each process. Default: 2
        step_frames (int): Step size of video frame interpolation.
            Default: 1
        pretrained (str): Path for pretrained model. Default: None.
    """

    # Maps a metric name from ``test_cfg.metrics`` to its scoring function.
    allowed_metrics = {'PSNR': psnr, 'SSIM': ssim}

    def __init__(self,
                 generator,
                 pixel_loss,
                 train_cfg=None,
                 test_cfg=None,
                 required_frames=2,
                 step_frames=1,
                 pretrained=None):
        super().__init__()

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # fp16 is off by default; a runner hook may enable it later.
        self.fp16_enabled = False

        # Build the generator network and (optionally) load its weights.
        self.generator = build_backbone(generator)
        self.init_weights(pretrained)

        # Pixel-wise loss used during training.
        self.pixel_loss = build_loss(pixel_loss)

        # Number of consecutive input frames consumed per forward pass.
        self.required_frames = required_frames
        # Stride between consecutive input windows when splitting a video.
        self.step_frames = step_frames
[docs] def init_weights(self, pretrained=None): """Init weights for models. Args: pretrained (str, optional): Path for pretrained weights. If given None, pretrained weights will not be loaded. Defaults to None. """ self.generator.init_weights(pretrained)
[docs] @auto_fp16(apply_to=('inputs', )) def forward(self, inputs, target=None, test_mode=False, **kwargs): """Forward function. Args: inputs (Tensor): Tensor of input frames. target (Tensor): Tensor of target frame. Default: None. test_mode (bool): Whether in test mode or not. Default: False. kwargs (dict): Other arguments. """ if test_mode: return self.forward_test(inputs, target, **kwargs) return self.forward_train(inputs, target)
[docs] def forward_train(self, inputs, target): """Training forward function. This is a basic function, interpolate a frame between the given two frames. Args: inputs (Tensor): Tensor of input frame(s). target (Tensor): Tensor of target frame(s). Returns: Tensor: Output tensor. """ losses = dict() output = self.generator(inputs) loss_pix = self.pixel_loss(output, target) losses['loss_pix'] = loss_pix outputs = dict( losses=losses, num_samples=len(target.data), results=dict( inputs=inputs.cpu(), target=target.cpu(), output=output.cpu())) return outputs
[docs] def evaluate(self, output, target): """Evaluation function. Args: output (Tensor): Model output. target (Tensor): GT Tensor. Returns: dict: Evaluation results. """ crop_border = self.test_cfg.get('crop_border', 0) convert_to = self.test_cfg.get('convert_to', None) eval_result = dict() for metric in self.test_cfg.metrics: if output.ndim == 5: # a sequence: (n, t, c, h, w) avg = [] for i in range(0, output.size(1)): output_i = tensor2img(output[:, i, :, :, :]) target_i = tensor2img(target[:, i, :, :, :]) avg.append(self.allowed_metrics[metric]( output_i, target_i, crop_border, convert_to=convert_to)) eval_result[metric] = np.mean(avg) elif output.ndim == 4: # an image: (n, c, h, w) output_img = tensor2img(output) target_img = tensor2img(target) value = self.allowed_metrics[metric]( output_img, target_img, crop_border, convert_to=convert_to) eval_result[metric] = value return eval_result
[docs] def forward_test(self, inputs, target=None, meta=None, save_image=False, save_path=None, iteration=None): """Testing forward function. This is a basic function, interpolate a frame between the given two frames. Args: inputs (Tensor): Tensor of input frames. target (Tensor): Tensor of target frame(s). Default: None. meta (list[dict]): Meta data, such as path of target file. Default: None. save_image (bool): Whether to save image. Default: False. save_path (str): Path to save image. Default: None. iteration (int): Iteration for the saving image name. Default: None. Returns: dict: Output results. """ output = self.generator(inputs).clamp(0, 1) if self.test_cfg is not None and self.test_cfg.get('metrics', None): assert target is not None, ( 'evaluation with metrics must have target images.') results = dict(eval_result=self.evaluate(output, target)) else: results = dict(inputs=inputs.cpu(), output=output.cpu()) if target is not None: results['target'] = target.cpu() # save image if save_image: self._save_image(meta, iteration, save_path, output) return results
@staticmethod def _save_image(meta, iteration, save_path, output): """Save the image. Args: meta (list[dict]): Meta data, such as path of target file. Default: None. These dictionaries should contain 'target_path' (str of a path) or 'inputs_path' (list of str) iteration (int): Iteration for the saving image name. Default: None. save_path (str): Path to save image. Default: None. output (Tensor): Output image. """ if output.ndim == 4: # an image img_name = meta[0]['key'].replace('/', '_') if isinstance(iteration, numbers.Number): save_path = osp.join(save_path, f'{img_name}-{iteration + 1:06d}.png') elif iteration is None: save_path = osp.join(save_path, f'{img_name}.png') else: raise ValueError('iteration should be number or None, ' f'but got {type(iteration)}') mmcv.imwrite(tensor2img(output), save_path) elif output.ndim == 5: # a sequence folder_name = meta[0]['key'].split('/')[0] for i in range(0, output.size(1)): if isinstance(iteration, numbers.Number): save_path_i = osp.join(save_path, folder_name, f'{i:08d}-{iteration + 1:06d}.png') elif iteration is None: save_path_i = osp.join(save_path, folder_name, f'{i:08d}.png') else: raise ValueError('iteration should be number or None, ' f'but got {type(iteration)}') mmcv.imwrite(tensor2img(output[:, i, :, :, :]), save_path_i)
[docs] def forward_dummy(self, img): """Used for computing network FLOPs. Args: img (Tensor): Input frames. Returns: Tensor: Output frame(s). """ out = self.generator(img) return out
[docs] def train_step(self, data_batch, optimizer): """Train step. Args: data_batch (dict): A batch of data. optimizer (obj): Optimizer. Returns: dict: Returned output. """ outputs = self(**data_batch, test_mode=False) loss, log_vars = self.parse_losses(outputs.pop('losses')) # optimize optimizer['generator'].zero_grad() loss.backward() optimizer['generator'].step() outputs.update({'log_vars': log_vars}) return outputs
[docs] def val_step(self, data_batch, **kwargs): """Validation step. Args: data_batch (dict): A batch of data. kwargs (dict): Other arguments for ``val_step``. Returns: dict: Returned output. """ output = self.forward_test(**data_batch, **kwargs) return output
[docs] def split_frames(self, input_tensors): """split input tensors for inference. Args: input_tensors (Tensor): Tensor of input frames with shape [1, t, c, h, w] Returns: Tensor: Split tensor with shape [t-1, 2, c, h, w] """ num_frames = input_tensors.shape[1] result = [ input_tensors[:, i:i + self.required_frames] for i in range(0, num_frames - self.required_frames + 1, self.step_frames) ] result = torch.cat(result, dim=0) return result
[docs] @staticmethod def merge_frames(input_tensors, output_tensors): """merge input frames and output frames. This is a basic function, interpolate a frame between the given two frames. Args: input_tensors (Tensor): The input frames with shape [n, 2, c, h, w] output_tensors (Tensor): The output frames with shape [n, 1, c, h, w]. Returns: list[np.array]: The final frames. in_frame, out_frame, in_frame, out_frame, in_frame ... """ num_frames = input_tensors.shape[0] result = [] for i in range(num_frames): result.append(tensor2img(input_tensors[i, 0])) result.append(tensor2img(output_tensors[i, 0])) result.append(tensor2img(input_tensors[-1, 1])) return result
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.