Shortcuts

Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend that you upgrade to MMEditing 1.0 to enjoy the rich new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code and documentation of MMEditing 1.0 for more details.

Source code for mmedit.datasets.generation_unpaired_dataset

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import numpy as np

from .base_generation_dataset import BaseGenerationDataset
from .registry import DATASETS


@DATASETS.register_module()
class GenerationUnpairedDataset(BaseGenerationDataset):
    """Unpaired image-folder dataset for image generation.

    The images of the two domains are expected in sibling folders under
    ``dataroot``: ``<dataroot>/trainA`` and ``<dataroot>/trainB`` for
    training, or ``<dataroot>/testA`` and ``<dataroot>/testB`` when
    ``test_mode`` is set.

    Args:
        dataroot (str | :obj:`Path`): Path to the folder root of unpaired
            images.
        pipeline (List[dict | callable]): A sequence of data transformations.
        test_mode (bool): Store `True` when building test dataset.
            Default: `False`.
    """

    def __init__(self, dataroot, pipeline, test_mode=False):
        super().__init__(pipeline, test_mode)

        # The folder prefix depends only on whether this is a test dataset.
        if test_mode:
            split = 'test'
        else:
            split = 'train'

        root = str(dataroot)
        self.dataroot_a = osp.join(root, split + 'A')
        self.dataroot_b = osp.join(root, split + 'B')

        # Each domain is scanned independently; the two lists are never
        # aligned, so their lengths may differ.
        self.data_infos_a = self.load_annotations(self.dataroot_a)
        self.data_infos_b = self.load_annotations(self.dataroot_b)
        self.len_a = len(self.data_infos_a)
        self.len_b = len(self.data_infos_b)

    def load_annotations(self, dataroot):
        """Collect the image paths of a single domain.

        Args:
            dataroot (str): Path to the folder root for unpaired images of
                one domain.

        Returns:
            list[dict]: One ``{'path': <image path>}`` entry per item
                returned by ``self.scan_folder(dataroot)``, in sorted order.
        """
        return [
            {'path': img_path}
            for img_path in sorted(self.scan_folder(dataroot))
        ]

    def prepare_train_data(self, idx):
        """Build one unpaired training sample.

        The domain-A image is selected by ``idx`` (wrapping around with a
        modulo when ``idx`` exceeds the number of A images); the domain-B
        image is drawn at random with ``np.random.randint`` so the pairing
        is not fixed.

        Args:
            idx (int): Index of the requested sample.

        Returns:
            The output of ``self.pipeline`` applied to a dict holding
            ``img_a_path`` and ``img_b_path``.
        """
        info_a = self.data_infos_a[idx % self.len_a]
        path_a = info_a['path']

        random_b = np.random.randint(0, self.len_b)
        path_b = self.data_infos_b[random_b]['path']

        sample = {'img_a_path': path_a, 'img_b_path': path_b}
        return self.pipeline(sample)

    def prepare_test_data(self, idx):
        """Build one unpaired test sample.

        Both domains are indexed deterministically: ``idx`` wraps around
        each domain's own length with a modulo.

        Args:
            idx (int): Index of the requested sample.

        Returns:
            The output of ``self.pipeline`` applied to a dict holding
            ``img_a_path`` and ``img_b_path``.
        """
        path_a = self.data_infos_a[idx % self.len_a]['path']
        path_b = self.data_infos_b[idx % self.len_b]['path']

        sample = {'img_a_path': path_a, 'img_b_path': path_b}
        return self.pipeline(sample)

    def __len__(self):
        """Return the size of the larger of the two domains."""
        return max(self.len_a, self.len_b)
Read the Docs v: latest
Versions
latest
stable
1.x
v0.16.0
v0.15.2
v0.15.1
v0.15.0
v0.14.0
v0.13.0
v0.12.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.