Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated at the end of 2022. We recommend that you upgrade to MMEditing 1.0 to benefit from the many new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code, and documentation of MMEditing 1.0 for more details.

# Source code for mmedit.core.misc

# Copyright (c) OpenMMLab. All rights reserved.
import math

import numpy as np
import torch
from torchvision.utils import make_grid

def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    Values are clamped to ``min_max`` and then linearly rescaled to [0, 1].
    The input tensors themselves are left unmodified.

    For different tensor shapes, this function will have different behaviors:

        1. 4D mini-batch Tensor of shape (N x 3/1 x H x W):
            Use `make_grid` to stitch images in the batch dimension, and then
            convert it to numpy array.
        2. 3D Tensor of shape (3/1 x H x W) and 2D Tensor of shape (H x W):
            Directly change to numpy array.

    Leading dimensions of size 1 are squeezed first (at most two), so
    (1 x 3 x H x W) is handled as 3D and (1 x 1 x H x W) or (1 x H x W) as
    2D.

    Note that the image channel in input tensors should be RGB order. This
    function will convert it to cv2 convention, i.e., (H x W x C) with BGR
    order.

    Args:
        tensor (Tensor | list[Tensor]): Input tensors.
        out_type (numpy type): Output types. If np.uint8, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: np.uint8.
        min_max (tuple): min and max values for clamp. Default: (0, 1).

    Returns:
        (ndarray | list[ndarray]): 3D ndarray of shape (H x W x C) or 2D
        ndarray of shape (H x W). A single ndarray is returned when exactly
        one tensor was given (either bare or as a one-element list);
        otherwise a list with one ndarray per input tensor.

    Raises:
        TypeError: If ``tensor`` is neither a Tensor nor a list of Tensors.
        ValueError: If a tensor is not 4D, 3D or 2D after squeezing.
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list)
             and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(
            f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]

    low, high = min_max[0], min_max[1]
    result = []
    for _tensor in tensor:
        # Squeeze two times so that:
        # 1. (1, 1, h, w) -> (h, w) or
        # 2. (1, 3, h, w) -> (3, h, w) or
        # 3. (n>1, 3/1, h, w) -> (n>1, 3/1, h, w)
        _tensor = _tensor.squeeze(0).squeeze(0)
        # Out-of-place clamp: for a float32 CPU input, float()/detach()/cpu()
        # all share storage with the caller's tensor, so an in-place clamp_
        # would silently overwrite the caller's data.
        _tensor = _tensor.float().detach().cpu().clamp(low, high)
        _tensor = (_tensor - low) / (high - low)

        n_dim = _tensor.dim()
        if n_dim == 4:
            # Tile the batch into one roughly square image grid.
            img_np = make_grid(
                _tensor,
                nrow=int(math.sqrt(_tensor.size(0))),
                normalize=False).numpy()
        elif n_dim in (2, 3):
            img_np = _tensor.numpy()
        else:
            raise ValueError('Only support 4D, 3D or 2D tensor. '
                             f'But received with dimension: {n_dim}')

        if n_dim != 2:
            # (C, H, W) in RGB order -> (H, W, C) in BGR order.
            img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))

        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)

    result = result[0] if len(result) == 1 else result
    return result


© Copyright 2020, MMEditing Authors. Revision 1e7e3d17.

Built with Sphinx using a theme provided by Read the Docs.
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x