Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend that you upgrade to MMEditing 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code, and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.mattors.dim

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import auto_fp16

from ..builder import build_loss
from ..registry import MODELS
from .base_mattor import BaseMattor
from .utils import get_unknown_tensor


@MODELS.register_module()
class DIM(BaseMattor):
    """Deep Image Matting model.

    https://arxiv.org/abs/1703.03872

    .. note::

        For ``(self.train_cfg.train_backbone, self.train_cfg.train_refiner)``:

        * ``(True, False)`` corresponds to the encoder-decoder stage in
          the paper.
        * ``(False, True)`` corresponds to the refinement stage in the
          paper.
        * ``(True, True)`` corresponds to the fine-tune stage in the paper.

    Args:
        backbone (dict): Config of backbone.
        refiner (dict): Config of refiner.
        train_cfg (dict): Config of training. In ``train_cfg``,
            ``train_backbone`` should be specified. If the model has a
            refiner, ``train_refiner`` should be specified.
        test_cfg (dict): Config of testing. In ``test_cfg``, if the model
            has a refiner, ``refine`` should be specified (it is the flag
            read by :meth:`forward_test`).
        pretrained (str): Path of pretrained model.
        loss_alpha (dict): Config of the alpha prediction loss.
            Default: None.
        loss_comp (dict): Config of the composition loss. Default: None.
        loss_refine (dict): Config of the loss of the refiner.
            Default: None.

    Raises:
        ValueError: If ``loss_alpha``, ``loss_comp`` and ``loss_refine``
            are all None.
    """

    def __init__(self,
                 backbone,
                 refiner=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 loss_alpha=None,
                 loss_comp=None,
                 loss_refine=None):
        super().__init__(backbone, refiner, train_cfg, test_cfg, pretrained)

        if all(v is None for v in (loss_alpha, loss_comp, loss_refine)):
            raise ValueError('Please specify one loss for DIM.')

        # Always bind the three loss attributes (``None`` when the loss is
        # not configured). ``forward_train`` checks them with
        # ``is not None``; leaving an attribute undefined would make that
        # check fail with AttributeError instead of skipping the loss.
        self.loss_alpha = (
            build_loss(loss_alpha) if loss_alpha is not None else None)
        self.loss_comp = (
            build_loss(loss_comp) if loss_comp is not None else None)
        self.loss_refine = (
            build_loss(loss_refine) if loss_refine is not None else None)

        # support fp16
        self.fp16_enabled = False

    @auto_fp16()
    def _forward(self, x, refine):
        """Run the backbone and, optionally, the refiner.

        Args:
            x (Tensor): Network input. The first three channels are fed to
                the refiner together with the predicted alpha.
            refine (bool): Whether to run the refiner.

        Returns:
            tuple[Tensor, Tensor]: ``(pred_alpha, pred_refine)``.
            ``pred_alpha`` is the sigmoid of the backbone output.
            ``pred_refine`` is the refiner output, or a zero scalar tensor
            when ``refine`` is false.
        """
        raw_alpha = self.backbone(x)
        pred_alpha = raw_alpha.sigmoid()

        if refine:
            # The refiner sees the first three input channels concatenated
            # with the coarse alpha along the channel dimension.
            refine_input = torch.cat((x[:, :3, :, :], pred_alpha), 1)
            pred_refine = self.refiner(refine_input, raw_alpha)
        else:
            # As ONNX does not support NoneType for output,
            # we choose to use zero tensor to represent None
            pred_refine = torch.zeros([])
        return pred_alpha, pred_refine

    def forward_dummy(self, inputs):
        """Forward ``inputs`` through the network without computing losses.

        The refiner is used according to ``self.with_refiner``.
        """
        return self._forward(inputs, self.with_refiner)

    def forward_train(self, merged, trimap, meta, alpha, ori_merged, fg, bg):
        """Defines the computation performed at every training call.

        Args:
            merged (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            trimap (Tensor): of shape (N, 1, H, W). Tensor of trimap read by
                opencv.
            meta (list[dict]): Meta data about the current data batch.
            alpha (Tensor): of shape (N, 1, H, W). Tensor of alpha read by
                opencv.
            ori_merged (Tensor): of shape (N, C, H, W). Tensor of origin
                merged image read by opencv (not normalized).
            fg (Tensor): of shape (N, C, H, W). Tensor of fg read by opencv.
            bg (Tensor): of shape (N, C, H, W). Tensor of bg read by opencv.

        Returns:
            dict: Contains the loss items and batch information.
        """
        pred_alpha, pred_refine = self._forward(
            torch.cat((merged, trimap), 1), self.train_cfg.train_refiner)

        # Weight passed to every loss, derived from the trimap.
        weight = get_unknown_tensor(trimap, meta)

        losses = dict()
        if self.train_cfg.train_backbone:
            # Each backbone loss is optional; only configured ones are used.
            if self.loss_alpha is not None:
                losses['loss_alpha'] = self.loss_alpha(pred_alpha, alpha,
                                                       weight)
            if self.loss_comp is not None:
                losses['loss_comp'] = self.loss_comp(pred_alpha, fg, bg,
                                                     ori_merged, weight)
        if self.train_cfg.train_refiner:
            # ``loss_refine`` must have been configured when the refiner is
            # being trained.
            losses['loss_refine'] = self.loss_refine(pred_refine, alpha,
                                                     weight)
        return {'losses': losses, 'num_samples': merged.size(0)}

    def forward_test(self,
                     merged,
                     trimap,
                     meta,
                     save_image=False,
                     save_path=None,
                     iteration=None):
        """Defines the computation performed at every test call.

        Args:
            merged (Tensor): Image to predict alpha matte.
            trimap (Tensor): Trimap of the input image.
            meta (list[dict]): Meta data about the current data batch.
                Currently only batch_size 1 is supported. It may contain
                information needed to calculate metrics (``ori_alpha`` and
                ``ori_trimap``) or save predicted alpha matte
                (``merged_path``).
            save_image (bool, optional): Whether save predicted alpha matte.
                Defaults to False.
            save_path (str, optional): The directory to save predicted alpha
                matte. Defaults to None.
            iteration (int, optional): If given as None, the saved alpha
                matte will have the same file name with ``merged_path`` in
                meta dict. If given as an int, the saved alpha matte would
                named with postfix ``_{iteration}.png``. Defaults to None.

        Returns:
            dict: Contains the predicted alpha and evaluation result.
        """
        pred_alpha, pred_refine = self._forward(
            torch.cat((merged, trimap), 1), self.test_cfg.refine)
        # When refinement is enabled, the refiner output is the final alpha.
        if self.test_cfg.refine:
            pred_alpha = pred_refine

        pred_alpha = pred_alpha.detach().cpu().numpy().squeeze()
        pred_alpha = self.restore_shape(pred_alpha, meta)
        eval_result = self.evaluate(pred_alpha, meta)

        if save_image:
            self.save_image(pred_alpha, meta, save_path, iteration)

        return {'pred_alpha': pred_alpha, 'eval_result': eval_result}
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.