Shortcuts

Source code for mmedit.datasets.base_dataset

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from abc import ABCMeta, abstractmethod

from torch.utils.data import Dataset

from .pipelines import Compose


[docs]class BaseDataset(Dataset, metaclass=ABCMeta): """Base class for datasets. All datasets should subclass it. All subclasses should overwrite: ``load_annotations``, supporting to load information and generate image lists. Args: pipeline (list[dict | callable]): A sequence of data transforms. test_mode (bool): If True, the dataset will work in test mode. Otherwise, in train mode. """ def __init__(self, pipeline, test_mode=False): super().__init__() self.test_mode = test_mode self.pipeline = Compose(pipeline)
[docs] @abstractmethod def load_annotations(self): """Abstract function for loading annotation. All subclasses should overwrite this function """
[docs] def prepare_train_data(self, idx): """Prepare training data. Args: idx (int): Index of the training batch data. Returns: dict: Returned training batch. """ results = copy.deepcopy(self.data_infos[idx]) return self.pipeline(results)
[docs] def prepare_test_data(self, idx): """Prepare testing data. Args: idx (int): Index for getting each testing batch. Returns: Tensor: Returned testing batch. """ results = copy.deepcopy(self.data_infos[idx]) return self.pipeline(results)
def __len__(self): """Length of the dataset. Returns: int: Length of the dataset. """ return len(self.data_infos) def __getitem__(self, idx): """Get item at each call. Args: idx (int): Index for getting each item. """ if self.test_mode: return self.prepare_test_data(idx) return self.prepare_train_data(idx)
Read the Docs v: v0.15.0
Versions
latest
stable
v0.15.0
v0.14.0
v0.13.0
v0.12.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.