
Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend you upgrade to MMEditing 1.0 to enjoy the new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code, and documentation of MMEditing 1.0 for more details.

Source code for mmedit.models.base

# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import torch
import torch.nn as nn


class BaseModel(nn.Module, metaclass=ABCMeta):
    """Base model.

    All models should subclass it. All subclasses should overwrite:

        ``init_weights``, to initialize the model weights.

        ``forward_train``, the forward pass used during training.

        ``forward_test``, the forward pass used during testing.

        ``train_step``, to run one training step.
    """

    @abstractmethod
    def init_weights(self):
        """Abstract method for initializing the weights.

        All subclasses should overwrite it.
        """

    @abstractmethod
    def forward_train(self, imgs, labels):
        """Abstract method for the training forward pass.

        All subclasses should overwrite it.
        """

    @abstractmethod
    def forward_test(self, imgs):
        """Abstract method for the testing forward pass.

        All subclasses should overwrite it.
        """

    def forward(self, imgs, labels, test_mode, **kwargs):
        """Forward function for the base model.

        Args:
            imgs (Tensor): Input image(s).
            labels (Tensor): Ground-truth label(s).
            test_mode (bool): Whether in test mode.
            kwargs (dict): Other arguments.

        Returns:
            Tensor: Forward results.
        """
        if test_mode:
            return self.forward_test(imgs, **kwargs)

        return self.forward_train(imgs, labels, **kwargs)

    @abstractmethod
    def train_step(self, data_batch, optimizer):
        """Abstract method for one training step.

        All subclasses should overwrite it.
        """

    def val_step(self, data_batch, **kwargs):
        """One validation step.

        By default, this simply forwards the data batch through
        ``forward_test``. Subclasses may overwrite it.
        """
        output = self.forward_test(**data_batch, **kwargs)
        return output

    def parse_losses(self, losses):
        """Parse the losses dict for different loss variants.

        Args:
            losses (dict): Loss dict whose values are tensors or lists
                of tensors.

        Returns:
            loss (Tensor): Sum of all loss items whose names contain 'loss'.
            log_vars (dict): Scalar value of each loss item, for logging.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        log_vars['loss'] = loss
        for name in log_vars:
            log_vars[name] = log_vars[name].item()

        return loss, log_vars
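The abstract methods above only define a contract; a concrete model wires them together. Below is a minimal, hypothetical sketch, not part of mmedit itself: the class name ``ToyModel``, the loss key ``loss_pix``, and the data-batch keys ``imgs``/``labels`` are illustrative assumptions. It shows how a subclass might implement the required methods and use ``parse_losses`` inside ``train_step``.

# Hypothetical subclass sketch (illustrative only, not mmedit code).
import torch
import torch.nn as nn

from mmedit.models.base import BaseModel


class ToyModel(BaseModel):
    """Toy example: a single conv layer trained with an L1 loss."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)

    def init_weights(self):
        # Simplest possible initialization; real models typically load
        # pretrained checkpoints or use dedicated weight-init helpers here.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)

    def forward_train(self, imgs, labels):
        # Return a dict of named losses; ``parse_losses`` later sums every
        # entry whose key contains 'loss'.
        pred = self.conv(imgs)
        return dict(loss_pix=nn.functional.l1_loss(pred, labels))

    def forward_test(self, imgs):
        return dict(output=self.conv(imgs))

    def train_step(self, data_batch, optimizer):
        # 'imgs'/'labels' are assumed data-batch keys in this sketch.
        losses = self.forward_train(data_batch['imgs'], data_batch['labels'])
        loss, log_vars = self.parse_losses(losses)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return dict(log_vars=log_vars, num_samples=len(data_batch['imgs']))

In such a subclass, calling ``model(imgs, labels, test_mode=False)`` routes to ``forward_train`` and ``model(imgs, None, test_mode=True)`` routes to ``forward_test``. ``parse_losses`` turns ``dict(loss_pix=...)`` into a summed scalar ``loss`` (still a tensor, so it can be backpropagated) plus a ``log_vars`` dict of Python floats; only keys containing 'loss' contribute to the sum, so auxiliary metrics can be logged without affecting the gradient.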