Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMEditing 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. See the changelog, code, and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.common.ensemble

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn


class SpatialTemporalEnsemble(nn.Module):
    """Apply spatial and temporal ensemble and compute outputs.

    The spatial ensemble runs the model on all 8 combinations of the
    'vertical', 'horizontal' and 'transpose' transforms of the input. It maps
    each output back to the original orientation and averages them. The
    optional temporal ensemble repeats this on the temporally reversed
    sequence (dim 1) and averages the two results.

    Args:
        is_temporal_ensemble (bool, optional): Whether to apply ensemble
            temporally. If True, the sequence will also be flipped
            temporally. If the input is an image, this argument must be set
            to False. Default: False.
    """

    def __init__(self, is_temporal_ensemble=False):
        super().__init__()
        self.is_temporal_ensemble = is_temporal_ensemble

    def _transform(self, imgs, mode):
        """Apply a spatial transform (flip or transpose) to the images.

        Args:
            imgs (torch.Tensor): The images to be transformed, of size
                (n, t, c, h, w) or (n, c, h, w).
            mode (str): The mode of transform. Supported values are
                'vertical' (flip along the w dim), 'horizontal' (flip along
                the h dim) and 'transpose' (swap the h and w dims). Each
                transform is its own inverse.

        Returns:
            torch.Tensor: The transformed images, with the same number of
            dims as ``imgs``.

        Raises:
            ValueError: If ``imgs`` is 4-D while ``is_temporal_ensemble`` is
                True, or if ``mode`` is not one of the supported values.
        """
        is_single_image = False
        if imgs.ndim == 4:
            if self.is_temporal_ensemble:
                raise ValueError('"is_temporal_ensemble" must be False if '
                                 'the input is an image.')
            is_single_image = True
            # Insert a dummy temporal dim so the same dim indices below work
            # for both images and sequences.
            imgs = imgs.unsqueeze(1)

        if mode == 'vertical':
            imgs = imgs.flip(4).clone()
        elif mode == 'horizontal':
            imgs = imgs.flip(3).clone()
        elif mode == 'transpose':
            imgs = imgs.permute(0, 1, 2, 4, 3).clone()
        else:
            # An unknown mode must not silently pass the tensor through:
            # that would quietly skew the ensemble average.
            raise ValueError(
                'Unsupported transform mode: {!r}. Supported modes are '
                "'vertical', 'horizontal' and 'transpose'.".format(mode))

        if is_single_image:
            imgs = imgs.squeeze(1)

        return imgs

    def spatial_ensemble(self, imgs, model):
        """Apply spatial ensemble.

        The model is called once per transformed copy (8 calls in total).
        Each copy is moved to ``imgs.device`` only for the model call. The
        outputs are gathered on CPU, mapped back to the original
        orientation, and averaged.

        Args:
            imgs (torch.Tensor): The images to be processed by the model. Its
                size should be either (n, t, c, h, w) or (n, c, h, w).
            model (nn.Module): The model to process the images. Its outputs
                are passed through the same spatial transforms, so they must
                also be 4-D or 5-D with the spatial dims last.

        Returns:
            torch.Tensor: Output of the model with spatial ensemble applied,
            on the same device as ``imgs``.
        """
        # Build the 8 transformed copies. Copy ``i`` has 'vertical' applied
        # iff bit 0 of ``i`` is set, then 'horizontal' iff bit 1, then
        # 'transpose' iff bit 2.
        img_list = [imgs.cpu()]
        for mode in ['vertical', 'horizontal', 'transpose']:
            img_list.extend([self._transform(t, mode) for t in img_list])

        output_list = [model(t.to(imgs.device)).cpu() for t in img_list]

        restored = []
        for i, output in enumerate(output_list):
            # Undo the transforms in the reverse order of application.
            if i & 4:
                output = self._transform(output, 'transpose')
            if i & 2:
                output = self._transform(output, 'horizontal')
            if i & 1:
                output = self._transform(output, 'vertical')
            restored.append(output)

        outputs = torch.stack(restored, dim=0)
        outputs = outputs.mean(dim=0, keepdim=False)

        return outputs.to(imgs.device)

    def forward(self, imgs, model):
        """Apply spatial and temporal ensemble.

        Args:
            imgs (torch.Tensor): The images to be processed by the model. Its
                size should be either (n, t, c, h, w) or (n, c, h, w).
            model (nn.Module): The model to process the images.

        Returns:
            torch.Tensor: Output of the model with spatial ensemble applied,
            and additionally temporal ensemble when ``is_temporal_ensemble``
            is True.
        """
        outputs = self.spatial_ensemble(imgs, model)
        if self.is_temporal_ensemble:
            # Run on the time-reversed sequence, then reverse the outputs
            # back along dim 1 so frames line up before averaging.
            outputs += self.spatial_ensemble(imgs.flip(1), model).flip(1)
            outputs *= 0.5

        return outputs
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.