Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will soon be deprecated by the end of 2022. We recommend you upgrade to MMEditing 1.0 to enjoy fruitful new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code and documentation of MMEditing 1.0 for more details.

Source code for mmedit.core.hooks.visualization

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import torch
from mmcv.runner import HOOKS, Hook
from mmcv.runner.dist_utils import master_only
from torchvision.utils import save_image

from mmedit.utils import deprecated_function


[docs]@HOOKS.register_module() class MMEditVisualizationHook(Hook): """Visualization hook. In this hook, we use the official api `save_image` in torchvision to save the visualization results. Args: output_dir (str): The file path to store visualizations. res_name_list (str): The list contains the name of results in outputs dict. The results in outputs dict must be a torch.Tensor with shape (n, c, h, w). interval (int): The interval of calling this hook. If set to -1, the visualization hook will not be called. Default: -1. filename_tmpl (str): Format string used to save images. The output file name will be formatted as this args. Default: 'iter_{}.png'. rerange (bool): Whether to rerange the output value from [-1, 1] to [0, 1]. We highly recommend users should preprocess the visualization results on their own. Here, we just provide a simple interface. Default: True. bgr2rgb (bool): Whether to reformat the channel dimension from BGR to RGB. The final image we will save is following RGB style. Default: True. nrow (int): The number of samples in a row. Default: 1. padding (int): The number of padding pixels between each samples. Default: 4. """ def __init__(self, output_dir, res_name_list, interval=-1, filename_tmpl='iter_{}.png', rerange=True, bgr2rgb=True, nrow=1, padding=4): assert mmcv.is_list_of(res_name_list, str) self.output_dir = output_dir self.res_name_list = res_name_list self.interval = interval self.filename_tmpl = filename_tmpl self.bgr2rgb = bgr2rgb self.rerange = rerange self.nrow = nrow self.padding = padding mmcv.mkdir_or_exist(self.output_dir)
[docs] @master_only def after_train_iter(self, runner): """The behavior after each train iteration. Args: runner (object): The runner. """ if not self.every_n_iters(runner, self.interval): return results = runner.outputs['results'] filename = self.filename_tmpl.format(runner.iter + 1) img_list = [x for k, x in results.items() if k in self.res_name_list] img_cat = torch.cat(img_list, dim=3).detach() if self.rerange: img_cat = ((img_cat + 1) / 2) if self.bgr2rgb: img_cat = img_cat[:, [2, 1, 0], ...] img_cat = img_cat.clamp_(0, 1) save_image( img_cat, osp.join(self.output_dir, filename), nrow=self.nrow, padding=self.padding)
[docs]@HOOKS.register_module() class VisualizationHook(MMEditVisualizationHook): @deprecated_function('0.16.0', '0.20.0', 'use \'MMEditVisualizationHook\'') def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs)
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.