
Note

You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend you upgrade to MMEditing 1.0 to enjoy the new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code and documentation of MMEditing 1.0 for more details.

Source code for mmedit.datasets.base_sr_dataset

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from collections import defaultdict
from pathlib import Path

from mmcv import scandir

from .base_dataset import BaseDataset

IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm',
                  '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF')


class BaseSRDataset(BaseDataset):
    """Base class for super resolution datasets."""

    def __init__(self, pipeline, scale, test_mode=False):
        super().__init__(pipeline, test_mode)
        self.scale = scale
    @staticmethod
    def scan_folder(path):
        """Obtain image path list (including sub-folders) from a given folder.

        Args:
            path (str | :obj:`Path`): Folder path.

        Returns:
            list[str]: Image list obtained from the given folder.
        """
        if isinstance(path, (str, Path)):
            path = str(path)
        else:
            raise TypeError("'path' must be a str or a Path object, "
                            f'but received {type(path)}.')

        images = list(scandir(path, suffix=IMG_EXTENSIONS, recursive=True))
        images = [osp.join(path, v) for v in images]
        assert images, f'{path} has no valid image file.'
        return images
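    # Usage sketch (the folder path below is hypothetical, not part of this
    # module):
    #
    #     gt_paths = BaseSRDataset.scan_folder('data/DIV2K/DIV2K_train_HR')
    #     # -> ['data/DIV2K/DIV2K_train_HR/0001.png', ...]
    #
    # Subclasses typically call this from load_annotations() to collect LQ
    # and GT image paths before pairing them into data_infos.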
    def __getitem__(self, idx):
        """Get item at each call.

        Args:
            idx (int): Index for getting each item.
        """
        results = copy.deepcopy(self.data_infos[idx])
        results['scale'] = self.scale
        return self.pipeline(results)
    def evaluate(self, results, logger=None):
        """Evaluate with different metrics.

        Args:
            results (list[dict]): The output of forward_test() of the model;
                each item must contain an 'eval_result' dict of per-image
                metric values.
            logger (logging.Logger | None): Logger used during evaluation.
                Default: None.

        Returns:
            dict: Evaluation results dict.
        """
        if not isinstance(results, list):
            raise TypeError(f'results must be a list, but got {type(results)}')
        assert len(results) == len(self), (
            'The length of results is not equal to the dataset len: '
            f'{len(results)} != {len(self)}')

        results = [res['eval_result'] for res in results]  # a list of dict
        eval_result = defaultdict(list)  # a dict of list

        for res in results:
            for metric, val in res.items():
                eval_result[metric].append(val)
        for metric, val_list in eval_result.items():
            assert len(val_list) == len(self), (
                f'Length of evaluation result of {metric} is {len(val_list)}, '
                f'should be {len(self)}')

        # average the results
        eval_result = {
            metric: sum(values) / len(self)
            for metric, values in eval_result.items()
        }

        return eval_result
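For reference, a minimal sketch of how a subclass wires up data_infos and how evaluate() averages per-image metric values. ToySRDataset, its dummy annotations and the metric numbers below are placeholders invented for illustration only; real subclasses build data_infos from actual image folders or annotation files.

from mmedit.datasets.base_sr_dataset import BaseSRDataset


class ToySRDataset(BaseSRDataset):
    """Hypothetical toy subclass used only to illustrate evaluate()."""

    def __init__(self, pipeline, scale, test_mode=False):
        super().__init__(pipeline, scale, test_mode)
        # Real subclasses build this list from scan_folder() or an
        # annotation file.
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        return [dict(gt_path='0001.png'), dict(gt_path='0002.png')]


dataset = ToySRDataset(pipeline=[], scale=4, test_mode=True)

# Each element corresponds to one test sample and carries an 'eval_result'
# dict of per-image metric values (numbers are made up).
results = [
    {'eval_result': {'PSNR': 30.5, 'SSIM': 0.85}},
    {'eval_result': {'PSNR': 29.5, 'SSIM': 0.85}},
]
print(dataset.evaluate(results))  # {'PSNR': 30.0, 'SSIM': 0.85}

Each metric is summed over all samples and divided by len(self), so a missing or extra per-image entry is caught by the length assertions inside evaluate().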