You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMEditing 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code, and documentation of MMEditing 1.0 for more details.
Source code for mmedit.models.losses.utils
# Copyright (c) OpenMMLab. All rights reserved.
import functools

import torch.nn.functional as F


def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Returns:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    if reduction_enum == 1:
        return loss.mean()

    return loss.sum()


def mask_reduce_loss(loss, weight=None, reduction='mean', sample_wise=False):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. It must have the same number
            of dimensions as ``loss``. When a channel axis (dim 1) exists,
            its size must be either 1 or equal to that of ``loss``.
            Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            "none", "mean" and "sum". Default: 'mean'.
        sample_wise (bool): Whether calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and
            `weight` (argument of `forward()`) is not None. It will first
            reduce loss with 'mean' per-sample, and then it means over all
            the samples. Default: False.

    Returns:
        Tensor: Processed loss values. When ``weight`` is given and
        ``reduction`` is neither 'sum' nor 'mean', the weighted
        element-wise loss is returned without further reduction.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        # The channel check only applies when a channel axis exists;
        # 1-D inputs have no dim 1 and would otherwise raise IndexError.
        if loss.dim() > 1:
            assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if weight is not specified or reduction is sum, just reduce the loss
    if weight is None or reduction == 'sum':
        loss = reduce_loss(loss, reduction)
    # if reduction is mean, then compute mean over masked region
    elif reduction == 'mean':
        # expand a single-channel weight to the full shape of the loss
        # (e.g. N1HW to NCHW) so that the normalizer counts every channel
        if weight.dim() > 1 and weight.size(1) == 1:
            weight = weight.expand_as(loss)
        # small value to prevent division by zero
        eps = 1e-12

        # perform sample-wise mean
        if sample_wise:
            # Sum the weight over every non-batch axis (NCHW to N111).
            # For 1-D inputs there is nothing to sum: each element is its
            # own sample, and an empty dim list would reduce all axes.
            non_batch_dims = list(range(1, weight.dim()))
            if non_batch_dims:
                weight = weight.sum(dim=non_batch_dims, keepdim=True)
            loss = (loss / (weight + eps)).sum() / weight.size(0)
        # perform pixel-wise mean
        else:
            loss = loss.sum() / (weight.sum() + eps)

    return loss


def masked_loss(loss_func):
    """Create a masked version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    sample_wise=False, **kwargs)`.

    :Example:

    >>> import torch
    >>> @masked_loss
    ... def l1_loss(pred, target):
    ...     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                sample_wise=False,
                **kwargs):
        # get element-wise loss
        loss = loss_func(pred, target, **kwargs)
        # weighting and reduction are delegated to mask_reduce_loss
        loss = mask_reduce_loss(loss, weight, reduction, sample_wise)
        return loss

    return wrapper