You are reading the documentation for MMEditing 0.x, which will soon be deprecated by the end of 2022. We recommend you upgrade to MMEditing 1.0 to enjoy fruitful new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code and documentation of MMEditing 1.0 for more details.

Source code for mmedit.datasets.pipelines.generate_assistant

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch

from ..registry import PIPELINES
from .utils import make_coord

[docs]@PIPELINES.register_module() class GenerateHeatmap: """Generate heatmap from keypoint. Args: keypoint (str): Key of keypoint in dict. ori_size (int | Tuple[int]): Original image size of keypoint. target_size (int | Tuple[int]): Target size of heatmap. sigma (float): Sigma parameter of heatmap. Default: 1.0 """ def __init__(self, keypoint, ori_size, target_size, sigma=1.0): if isinstance(ori_size, int): ori_size = (ori_size, ori_size) else: ori_size = ori_size[:2] if isinstance(target_size, int): target_size = (target_size, target_size) else: target_size = target_size[:2] self.size_ratio = (target_size[0] / ori_size[0], target_size[1] / ori_size[1]) self.keypoint = keypoint self.sigma = sigma self.target_size = target_size self.ori_size = ori_size def __call__(self, results): """Call function. Args: results (dict): A dict containing the necessary information and data for augmentation. Require keypoint. Returns: dict: A dict containing the processed data and information. Add 'heatmap'. """ keypoint_list = [(keypoint[0] * self.size_ratio[0], keypoint[1] * self.size_ratio[1]) for keypoint in results[self.keypoint]] heatmap_list = [ self._generate_one_heatmap(keypoint) for keypoint in keypoint_list ] results['heatmap'] = np.stack(heatmap_list, axis=2) return results def _generate_one_heatmap(self, keypoint): """Generate One Heatmap. Args: landmark (Tuple[float]): Location of a landmark. results: heatmap (np.ndarray): A heatmap of landmark. """ w, h = self.target_size x_range = np.arange(start=0, stop=w, dtype=int) y_range = np.arange(start=0, stop=h, dtype=int) grid_x, grid_y = np.meshgrid(x_range, y_range) dist2 = (grid_x - keypoint[0])**2 + (grid_y - keypoint[1])**2 exponent = dist2 / 2.0 / self.sigma / self.sigma heatmap = np.exp(-exponent) return heatmap def __repr__(self): return (f'{self.__class__.__name__}, ' f'keypoint={self.keypoint}, ' f'ori_size={self.ori_size}, ' f'target_size={self.target_size}, ' f'sigma={self.sigma}')
[docs]@PIPELINES.register_module() class GenerateCoordinateAndCell: """Generate coordinate and cell. Generate coordinate from the desired size of SR image. Train or val: 1. Generate coordinate from GT. 2. Reshape GT image to (HgWg, 3) and transpose to (3, HgWg). where `Hg` and `Wg` represent the height and width of GT. Test: Generate coordinate from LQ and scale or target_size. Then generate cell from coordinate. Args: sample_quantity (int): The quantity of samples in coordinates. To ensure that the GT tensors in a batch have the same dimensions. Default: None. scale (float): Scale of upsampling. Default: None. target_size (tuple[int]): Size of target image. Default: None. The priority of getting 'size of target image' is: 1, results['gt'].shape[-2:] 2, results['lq'].shape[-2:] * scale 3, target_size """ def __init__(self, sample_quantity=None, scale=None, target_size=None): self.sample_quantity = sample_quantity self.scale = scale self.target_size = target_size def __call__(self, results): """Call function. Args: results (dict): A dict containing the necessary information and data for augmentation. Require either in results: 1. 'lq' (tensor), whose shape is similar as (3, H, W). 2. 'gt' (tensor), whose shape is similar as (3, H, W). 3. None, the premise is self.target_size and len(self.target_size) >= 2. Returns: dict: A dict containing the processed data and information. Reshape 'gt' to (-1, 3) and transpose to (3, -1) if 'gt' in results. Add 'coord' and 'cell'. """ # generate hr_coord (and hr_rgb) if 'gt' in results: crop_hr = results['gt'] self.target_size = crop_hr.shape hr_rgb = crop_hr.contiguous().view(3, -1).permute(1, 0) results['gt'] = hr_rgb elif self.scale is not None and 'lq' in results: _, h_lr, w_lr = results['lq'].shape self.target_size = (round(h_lr * self.scale), round(w_lr * self.scale)) else: assert self.target_size is not None assert len(self.target_size) >= 2 hr_coord = make_coord(self.target_size[-2:]) if self.sample_quantity is not None and 'gt' in results: sample_lst = np.random.choice( len(hr_coord), self.sample_quantity, replace=False) hr_coord = hr_coord[sample_lst] results['gt'] = results['gt'][sample_lst] # Preparations for cell decoding cell = torch.ones_like(hr_coord) cell[:, 0] *= 2 / self.target_size[-2] cell[:, 1] *= 2 / self.target_size[-1] results['coord'] = hr_coord results['cell'] = cell return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += (f'sample_quantity={self.sample_quantity}, ' f'scale={self.scale}, target_size={self.target_size}') return repr_str
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.