Learn about Configs

We use Python files as our config system. All the provided configs are under $MMEditing/configs.

Config Name Style

We name config files using the style below. Contributors are advised to follow the same style.

{model}_[model setting]_{backbone}_[refiner]_[norm setting]_[misc]_[gpu x batch_per_gpu]_{schedule}_{dataset}

{xxx} is a required field and [yyy] is an optional one.

  • {model}: model type like srcnn, dim, etc.

  • [model setting]: specific setting for some model, like resolution for input images, stage name for training, etc.

  • {backbone}: backbone type like r50 (ResNet-50), x101 (ResNeXt-101).

  • [refiner]: refiner type like pln (Plain Refiner).

  • [norm setting]: bn (Batch Normalization) is used unless otherwise specified; other norm layer types include gn (Group Normalization) and syncbn (Synchronized Batch Normalization).

  • [misc]: miscellaneous setting/plugins of model, e.g. dconv, gcb, attention, albu, mstrain.

  • [gpu x batch_per_gpu]: number of GPUs and samples per GPU; 8x2 is used by default.

  • {schedule}: training schedule, such as 20k (20,000 iterations) or 100k (100,000 iterations).

  • {dataset}: dataset like places (for inpainting), comp1k (for matting), div2k (for restoration) and paired (for generation).
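As a rough illustration of the naming scheme, the sketch below pulls a few fields out of a config name. It assumes that {schedule} and {dataset} are the last two underscore-separated tokens and that neither contains an underscore. The helper names are hypothetical and not part of MMEditing. The example name is the experiment name used in the EDSR config later in this document.

```python
import re


def parse_schedule(token):
    """Convert a schedule token such as '20k' into a number of iterations."""
    match = re.fullmatch(r'(\d+)k', token)
    if match is None:
        raise ValueError(f'not a schedule token: {token!r}')
    return int(match.group(1)) * 1000


def summarize_config_name(name):
    """Extract model, GPU setting, schedule and dataset from a config name.

    Hypothetical helper, for illustration only.
    """
    fields = name.split('_')
    # {schedule} and {dataset} are required and assumed to come last.
    schedule, dataset = fields[-2], fields[-1]
    # [gpu x batch_per_gpu] is optional; fall back to the default 8x2.
    gpus, samples_per_gpu = 8, 2
    for field in fields[1:-2]:
        match = re.fullmatch(r'(\d+)x(\d+)', field)
        if match:
            gpus, samples_per_gpu = int(match.group(1)), int(match.group(2))
    return dict(
        model=fields[0],
        gpus=gpus,
        samples_per_gpu=samples_per_gpu,
        total_iters=parse_schedule(schedule),
        dataset=dataset)


info = summarize_config_name('edsr_x2c64b16_1x16_300k_div2k')
print(info)
```

Names whose dataset or schedule token contains an underscore would need a stricter parser than this split-based one.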

Config System for Generation

Like MMDetection, we build modular and inheritance design into our config system, which makes it convenient to run various experiments.
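One way to picture the inheritance design: a child config states only the fields it changes, and those fields are merged into the base config recursively. The sketch below is a plain-Python illustration of that idea, not the actual implementation. The function `merge_config` and the child values are made up for the example, while the base values are borrowed from the pix2pix config.

```python
import copy


def merge_config(base, override):
    """Return a new dict in which `override` is merged into `base`.

    Nested dicts are merged key by key; any other value replaces the
    base value. Hypothetical helper, for illustration only.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


base = dict(
    model=dict(
        type='Pix2Pix',
        generator=dict(type='UnetGenerator', num_down=8, use_dropout=True)),
    total_iters=80000)

# The child only lists what differs from the base.
child = dict(
    model=dict(generator=dict(use_dropout=False)),
    total_iters=40000)

merged = merge_config(base, child)
print(merged)
```

Fields that the child does not mention, such as `num_down`, keep their base values, which is what makes a short child config sufficient for a new experiment.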

An Example - pix2pix

To give users a basic idea of a complete config and the modules in a generation system, we briefly comment on the pix2pix config below. For more detailed usage and the alternatives available for each module, please refer to the API documentation.

## model settings
model = dict(
    type='Pix2Pix',  ## The name of synthesizer
    generator=dict(
        type='UnetGenerator',  ## The name of generator
        in_channels=3,  ## The input channels of generator
        out_channels=3,  ## The output channels of generator
        num_down=8,  ## The number of downsamplings in the generator
        base_channels=64,  ## The number of channels at the last conv layer of generator
        norm_cfg=dict(type='BN'),  ## The config of norm layer
        use_dropout=True,  ## Whether to use dropout layers in the generator
        init_cfg=dict(type='normal', gain=0.02)),  ## The config of initialization
    discriminator=dict(
        type='PatchDiscriminator',  ## The name of discriminator
        in_channels=6,  ## The input channels of discriminator
        base_channels=64,  ## The number of channels at the first conv layer of discriminator
        num_conv=3,  ## The number of stacked intermediate conv layers (excluding input and output conv layer) in the discriminator
        norm_cfg=dict(type='BN'),  ## The config of norm layer
        init_cfg=dict(type='normal', gain=0.02)),  ## The config of initialization
    gan_loss=dict(
        type='GANLoss',  ## The name of GAN loss
        gan_type='vanilla',  ## The type of GAN loss
        real_label_val=1.0,  ## The value for real label of GAN loss
        fake_label_val=0.0,  ## The value for fake label of GAN loss
        loss_weight=1.0),  ## The weight of GAN loss
    pixel_loss=dict(type='L1Loss', loss_weight=100.0, reduction='mean'))
## model training and testing settings
train_cfg = dict(
    direction='b2a')  ## Image-to-image translation direction (the model training direction, same as testing direction) for pix2pix. Model default: a2b
test_cfg = dict(
    direction='b2a',   ## Image-to-image translation direction (the model training direction, same as testing direction) for pix2pix. Model default: a2b
    show_input=True)  ## Whether to show input real images when saving testing images for pix2pix

## dataset settings
train_dataset_type = 'GenerationPairedDataset'  ## The type of dataset for training
val_dataset_type = 'GenerationPairedDataset'  ## The type of dataset for validation/testing
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  ## Image normalization config to normalize the input images
train_pipeline = [
    dict(
        type='LoadPairedImageFromFile',  ## Load a pair of images from file path pipeline
        io_backend='disk',  ## IO backend where images are stored
        key='pair',  ## Keys to find corresponding path
        flag='color'),  ## Loading flag for images
    dict(
        type='Resize',  ## Resize pipeline
        keys=['img_a', 'img_b'],  ## The keys of images to be resized
        scale=(286, 286),  ## The scale to resize images
        interpolation='bicubic'),  ## Algorithm used for interpolation when resizing images
    dict(
        type='FixedCrop',  ## Fixed crop pipeline, cropping paired images to a specific size at a specific position for pix2pix training
        keys=['img_a', 'img_b'],  ## The keys of images to be cropped
        crop_size=(256, 256)),  ## The size to crop images
    dict(
        type='Flip',  ## Flip pipeline
        keys=['img_a', 'img_b'],  ## The keys of images to be flipped
        direction='horizontal'),  ## Flip images horizontally or vertically
    dict(
        type='RescaleToZeroOne',  ## Rescale images from [0, 255] to [0, 1]
        keys=['img_a', 'img_b']),  ## The keys of images to be rescaled
    dict(
        type='Normalize',  ## Image normalization pipeline
        keys=['img_a', 'img_b'],  ## The keys of images to be normalized
        to_rgb=True,  ## Whether to convert image channels from BGR to RGB
        **img_norm_cfg),  ## Image normalization config (see above for the definition of `img_norm_cfg`)
    dict(
        type='ImageToTensor',  ## Image to tensor pipeline
        keys=['img_a', 'img_b']),  ## The keys of images to be converted from image to tensor
    dict(
        type='Collect',  ## Pipeline that decides which keys in the data should be passed to the synthesizer
        keys=['img_a', 'img_b'],  ## The keys of images
        meta_keys=['img_a_path', 'img_b_path'])  ## The meta keys of images
]
test_pipeline = [
    dict(
        type='LoadPairedImageFromFile',  ## Load a pair of images from file path pipeline
        io_backend='disk',  ## IO backend where images are stored
        key='pair',  ## Keys to find corresponding path
        flag='color'),  ## Loading flag for images
    dict(
        type='Resize',  ## Resize pipeline
        keys=['img_a', 'img_b'],  ## The keys of images to be resized
        scale=(256, 256),  ## The scale to resize images
        interpolation='bicubic'),  ## Algorithm used for interpolation when resizing images
    dict(
        type='RescaleToZeroOne',  ## Rescale images from [0, 255] to [0, 1]
        keys=['img_a', 'img_b']),  ## The keys of images to be rescaled
    dict(
        type='Normalize',  ## Image normalization pipeline
        keys=['img_a', 'img_b'],  ## The keys of images to be normalized
        to_rgb=True,  ## Whether to convert image channels from BGR to RGB
        **img_norm_cfg),  ## Image normalization config (see above for the definition of `img_norm_cfg`)
    dict(
        type='ImageToTensor',  ## Image to tensor pipeline
        keys=['img_a', 'img_b']),  ## The keys of images to be converted from image to tensor
    dict(
        type='Collect',  ## Pipeline that decides which keys in the data should be passed to the synthesizer
        keys=['img_a', 'img_b'],  ## The keys of images
        meta_keys=['img_a_path', 'img_b_path'])  ## The meta keys of images
]
data_root = 'data/pix2pix/facades'  ## The root path of data
data = dict(
    samples_per_gpu=1,  ## Batch size of a single GPU
    workers_per_gpu=4,  ## Number of workers to pre-fetch data for each GPU
    drop_last=True,  ## Whether to drop out the last batch of data in training
    val_samples_per_gpu=1,  ## Batch size of a single GPU in validation
    val_workers_per_gpu=0,  ## Number of workers to pre-fetch data for each GPU in validation
    train=dict(  ## Training dataset config
        type=train_dataset_type,
        dataroot=data_root,
        pipeline=train_pipeline,
        test_mode=False),
    val=dict(  ## Validation dataset config
        type=val_dataset_type,
        dataroot=data_root,
        pipeline=test_pipeline,
        test_mode=True),
    test=dict(  ## Testing dataset config
        type=val_dataset_type,
        dataroot=data_root,
        pipeline=test_pipeline,
        test_mode=True))

## optimizer
optimizers = dict(  ## Config used to build the optimizers. All PyTorch optimizers are supported, with the same arguments as in PyTorch
    generator=dict(type='Adam', lr=2e-4, betas=(0.5, 0.999)),
    discriminator=dict(type='Adam', lr=2e-4, betas=(0.5, 0.999)))

## learning policy
lr_config = dict(policy='Fixed', by_epoch=False)  ## Learning rate scheduler config used to register LrUpdater hook

## checkpoint saving
checkpoint_config = dict(interval=4000, save_optimizer=True, by_epoch=False)  ## Config to set the checkpoint hook. Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for the implementation.
evaluation = dict(  ## The config to build the evaluation hook
    interval=4000,  ## Evaluation interval
    save_image=True)  ## Whether to save images
log_config = dict(  ## config to register logger hook
    interval=100,  ## Interval to print the log
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),   ## The logger used to record the training process
        ## dict(type='TensorboardLoggerHook')  # The Tensorboard logger is also supported
    ])
visual_config = None  ## The config to build the visualization hook

## runtime settings
total_iters = 80000  ## Total iterations to train the model
cudnn_benchmark = True  ## Set cudnn_benchmark
dist_params = dict(backend='nccl')  ## Parameters to set up distributed training; the port can also be set
log_level = 'INFO'  ## The level of logging
load_from = None  ## Load models as a pre-trained model from a given path. This will not resume training
resume_from = None  ## Resume from a checkpoint at a given path; training resumes from the point at which the checkpoint was saved
workflow = [('train', 1)]  ## Workflow for runner. [('train', 1)] means there is only one workflow and the workflow named 'train' is executed once. Keep this unchanged when training current generation models
exp_name = 'pix2pix_facades'  ## The experiment name
work_dir = f'./work_dirs/{exp_name}'  ## Directory to save the model checkpoints and logs for the current experiments.
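Each dict in the config above names a module with `type` and passes the remaining keys as arguments. A minimal sketch of that pattern follows, assuming a simple name-to-class registry. The registry, the `build` function and the two toy transforms are stand-ins written for this example, not the library's classes, and the toy `Normalize` omits `to_rgb`.

```python
# Hypothetical registry mapping a 'type' string to a class.
PIPELINES = {}


def register(cls):
    """Record a class under its own name so configs can refer to it."""
    PIPELINES[cls.__name__] = cls
    return cls


@register
class RescaleToZeroOne:
    """Toy stand-in: rescale values from [0, 255] to [0, 1]."""

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        for key in self.keys:
            results[key] = [v / 255.0 for v in results[key]]
        return results


@register
class Normalize:
    """Toy stand-in: per-channel (value - mean) / std."""

    def __init__(self, keys, mean, std):
        self.keys = keys
        self.mean = mean
        self.std = std

    def __call__(self, results):
        for key in self.keys:
            results[key] = [(v - m) / s for v, m, s in
                            zip(results[key], self.mean, self.std)]
        return results


def build(cfg):
    """Instantiate the class named by cfg['type'] with the remaining keys."""
    args = dict(cfg)  # shallow copy so the config itself is not modified
    cls = PIPELINES[args.pop('type')]
    return cls(**args)


img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
pipeline_cfg = [
    dict(type='RescaleToZeroOne', keys=['img_a', 'img_b']),
    dict(type='Normalize', keys=['img_a', 'img_b'], **img_norm_cfg),
]
transforms = [build(cfg) for cfg in pipeline_cfg]

# One RGB pixel per "image" keeps the example small.
results = dict(img_a=[0, 255, 255], img_b=[255, 0, 0])
for transform in transforms:
    results = transform(results)
print(results)
```

With mean and std both 0.5, the two steps together map the raw range [0, 255] to [-1, 1].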

Config System for Inpainting

Config Name Style

Like MMDetection, we build modular and inheritance design into our config system, which makes it convenient to run various experiments.

Config Field Description

To give users a basic idea of a complete config and the modules in an inpainting system, we briefly comment on the Global&Local config below. For more detailed usage and the alternatives available for each module, please refer to the API documentation.

model = dict(
    type='GLInpaintor', ## The name of inpaintor
    encdec=dict(
        type='GLEncoderDecoder', ## The name of encoder-decoder
        encoder=dict(type='GLEncoder', norm_cfg=dict(type='SyncBN')), ## The config of encoder
        decoder=dict(type='GLDecoder', norm_cfg=dict(type='SyncBN')), ## The config of decoder
        dilation_neck=dict(
            type='GLDilationNeck', norm_cfg=dict(type='SyncBN'))), ## The config of dilation neck
    disc=dict(
        type='GLDiscs', ## The name of discriminator
        global_disc_cfg=dict(
            in_channels=3, ## The input channel of discriminator
            max_channels=512, ## The maximum middle channel in discriminator
            fc_in_channels=512 * 4 * 4, ## The input channel of last fc layer
            fc_out_channels=1024, ## The output channel of last fc layer
            num_convs=6, ## The number of convs used in discriminator
            norm_cfg=dict(type='SyncBN') ## The config of norm layer
        ),
        local_disc_cfg=dict(
            in_channels=3, ## The input channel of discriminator
            max_channels=512, ## The maximum middle channel in discriminator
            fc_in_channels=512 * 4 * 4, ## The input channel of last fc layer
            fc_out_channels=1024, ## The output channel of last fc channel
            num_convs=5, ## The number of convs used in discriminator
            norm_cfg=dict(type='SyncBN') ## The config of norm layer
        ),
    ),
    loss_gan=dict(
        type='GANLoss', ## The name of GAN loss
        gan_type='vanilla', ## The type of GAN loss
        loss_weight=0.001 ## The weight of GAN loss
    ),
    loss_l1_hole=dict(
        type='L1Loss', ## The type of l1 loss
        loss_weight=1.0 ## The weight of l1 loss
    ),
    pretrained=None) ## The path of pretrained weight

train_cfg = dict(
    disc_step=1, ## The steps of training discriminator before training generator
    iter_tc=90000, ## Iterations of warming up generator
    iter_td=100000, ## Iterations of warming up discriminator
    start_iter=0, ## Starting iteration
    local_size=(128, 128)) ## The size of local patches
test_cfg = dict(metrics=['l1']) ## The config of testing scheme

dataset_type = 'ImgInpaintingDataset' ## The type of dataset
input_shape = (256, 256) ## The shape of input image

train_pipeline = [
    dict(type='LoadImageFromFile', key='gt_img'), ## The config of loading image
    dict(
        type='LoadMask', ## The type of loading mask pipeline
        mask_mode='bbox', ## The type of mask
        mask_config=dict(
            max_bbox_shape=(128, 128), ## The shape of bbox
            max_bbox_delta=40, ## The changing delta of bbox height and width
            min_margin=20,  ## The minimum margin from bbox to the image border
            img_shape=input_shape)),  ## The input image shape
    dict(
        type='Crop', ## The type of crop pipeline
        keys=['gt_img'],  ## The keys of images to be cropped
        crop_size=(384, 384),  ## The size of cropped patch
        random_crop=True,  ## Whether to use random crop
    ),
    dict(
        type='Resize',  ## The type of resizing pipeline
        keys=['gt_img'],  ## The keys of images to be resized
        scale=input_shape,  ## The scale of resizing function
        keep_ratio=False,  ## Whether to keep ratio during resizing
    ),
    dict(
        type='Normalize',  ## The type of normalizing pipeline
        keys=['gt_img'],  ## The keys of images to be normed
        mean=[127.5] * 3,  ## Mean value used in normalization
        std=[127.5] * 3,  ## Std value used in normalization
        to_rgb=False),  ## Whether to transfer image channels to rgb
    dict(type='GetMaskedImage'),  ## The config of getting masked image pipeline
    dict(
        type='Collect',  ## The type of collecting data from current pipeline
        keys=['gt_img', 'masked_img', 'mask', 'mask_bbox'],  ## The keys of data to be collected
        meta_keys=['gt_img_path']),  ## The meta keys of data to be collected
    dict(type='ImageToTensor', keys=['gt_img', 'masked_img', 'mask']),  ## The config dict of image to tensor pipeline
    dict(type='ToTensor', keys=['mask_bbox'])  ## The config dict of ToTensor pipeline
]

test_pipeline = train_pipeline  ## Constructing testing/validation pipeline

data_root = 'data/places365'  ## Set data root

data = dict(
    samples_per_gpu=12,  ## Batch size of a single GPU
    workers_per_gpu=8,  ## Number of workers to pre-fetch data for each GPU
    val_samples_per_gpu=1,  ## Batch size of a single GPU in validation
    val_workers_per_gpu=8,  ## Number of workers to pre-fetch data for each GPU in validation
    drop_last=True,  ## Whether to drop out the last batch of data
    train=dict(  ## Train dataset config
        type=dataset_type,
        ann_file=f'{data_root}/train_places_img_list_total.txt',
        data_prefix=data_root,
        pipeline=train_pipeline,
        test_mode=False),
    val=dict(  ## Validation dataset config
        type=dataset_type,
        ann_file=f'{data_root}/val_places_img_list.txt',
        data_prefix=data_root,
        pipeline=test_pipeline,
        test_mode=True))

optimizers = dict(  ## Config used to build the optimizers. All PyTorch optimizers are supported, with the same arguments as in PyTorch
    generator=dict(type='Adam', lr=0.0004), disc=dict(type='Adam', lr=0.0004))

lr_config = dict(policy='Fixed', by_epoch=False)  ## Learning rate scheduler config used to register LrUpdater hook

checkpoint_config = dict(by_epoch=False, interval=50000)  ## Config to set the checkpoint hook. Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for the implementation.
log_config = dict(  ## config to register logger hook
    interval=100,  ## Interval to print the log
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        ## dict(type='TensorboardLoggerHook'),  # The Tensorboard logger is also supported
        ## dict(type='PaviLoggerHook', init_kwargs=dict(project='mmedit'))
    ])  ## The logger used to record the training process.

visual_config = dict(  ## config to register visualization hook
    type='VisualizationHook',
    output_dir='visual',
    interval=1000,
    res_name_list=[
        'gt_img', 'masked_img', 'fake_res', 'fake_img', 'fake_gt_local'
    ],
)  ## The logger used to visualize the training process.

evaluation = dict(interval=50000)  ## The config to build the evaluation hook

total_iters = 500002
dist_params = dict(backend='nccl')  ## Parameters to set up distributed training; the port can also be set.
log_level = 'INFO'  ## The level of logging.
work_dir = None  ## Directory to save the model checkpoints and logs for the current experiments.
load_from = None  ## load models as a pre-trained model from a given path. This will not resume training.
resume_from = None  ## Resume from a checkpoint at a given path; training resumes from the point at which the checkpoint was saved.
workflow = [('train', 10000)]  ## Workflow for runner. [('train', 10000)] means there is only one workflow, named 'train'. The training length is set by total_iters above; this config has no total_epochs.
exp_name = 'gl_places'  ## The experiment name
find_unused_parameters = False  ## Whether to set find unused parameters in ddp
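The Normalize step in this config uses mean and std of 127.5 on raw pixel values, which maps [0, 255] to [-1, 1] in one step. The sketch below checks that arithmetic and compares it with the two-step alternative of rescaling to [0, 1] first and then using mean and std of 0.5. The function names are made up for the example.

```python
def normalize(value, mean, std):
    """Apply (value - mean) / std to a single channel value."""
    return (value - mean) / std


def normalize_raw(value):
    # Normalize raw [0, 255] values directly, as with mean=std=127.5.
    return normalize(value, 127.5, 127.5)


def rescale_then_normalize(value):
    # Alternative: rescale to [0, 1] first, then use mean=std=0.5.
    return normalize(value / 255.0, 0.5, 0.5)


for raw in (0, 127.5, 255):
    print(raw, normalize_raw(raw), rescale_then_normalize(raw))
```

Both routes give the same output range, so the choice mainly affects which units the mean and std are written in.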

Config System for Matting

Like MMDetection, we build modular and inheritance design into our config system, which makes it convenient to run various experiments.

An Example - Deep Image Matting Model

To give users a basic idea of a complete config, we briefly comment on the config of the original DIM model we implemented below. For more detailed usage and the alternatives available for each module, please refer to the API documentation.

## model settings
model = dict(
    type='DIM',  ## The name of model (we call mattor).
    backbone=dict(  ## The config of the backbone.
        type='SimpleEncoderDecoder',  ## The type of the backbone.
        encoder=dict(  ## The config of the encoder.
            type='VGG16'),  ## The type of the encoder.
        decoder=dict(  ## The config of the decoder.
            type='PlainDecoder')),  ## The type of the decoder.
    pretrained='./weights/vgg_state_dict.pth',  ## The pretrained weight of the encoder to be loaded.
    loss_alpha=dict(  ## The config of the alpha loss.
        type='CharbonnierLoss',  ## The type of the loss for predicted alpha matte.
        loss_weight=0.5),  ## The weight of the alpha loss.
    loss_comp=dict(  ## The config of the composition loss.
        type='CharbonnierCompLoss',  ## The type of the composition loss.
        loss_weight=0.5))  ## The weight of the composition loss.
train_cfg = dict(  ## Config of training DIM model.
    train_backbone=True,  ## In DIM stage1, backbone is trained.
    train_refiner=False)  ## In DIM stage1, refiner is not trained.
test_cfg = dict(  ## Config of testing DIM model.
    refine=False,  ## Whether to use the refiner output as the output; in stage1, we don't use it.
    metrics=['SAD', 'MSE', 'GRAD', 'CONN'])  ## The metrics used when testing.

## data settings
dataset_type = 'AdobeComp1kDataset'  ## Dataset type, this will be used to define the dataset.
data_root = 'data/adobe_composition-1k'  ## Root path of data.
img_norm_cfg = dict(  ## Image normalization config to normalize the input images.
    mean=[0.485, 0.456, 0.406],  ## Mean values used when pre-training the backbone models.
    std=[0.229, 0.224, 0.225],  ## Standard deviations used when pre-training the backbone models.
    to_rgb=True)  ## The channel order of images used when pre-training the backbone models.
train_pipeline = [  ## Training data processing pipeline.
    dict(
        type='LoadImageFromFile',  ## Load alpha matte from file.
        key='alpha',  ## Key of alpha matte in annotation file. The pipeline will read alpha matte from path `alpha_path`.
        flag='grayscale'),  ## Load as grayscale image which has shape (height, width).
    dict(
        type='LoadImageFromFile',  ## Load image from file.
        key='fg'),  ## Key of image to load. The pipeline will read fg from path `fg_path`.
    dict(
        type='LoadImageFromFile',  ## Load image from file.
        key='bg'),  ## Key of image to load. The pipeline will read bg from path `bg_path`.
    dict(
        type='LoadImageFromFile',  ## Load image from file.
        key='merged'),  ## Key of image to load. The pipeline will read merged from path `merged_path`.
    dict(
        type='CropAroundUnknown',  ## Crop images around unknown area (semi-transparent area).
        keys=['alpha', 'merged', 'ori_merged', 'fg', 'bg'],  ## Images to crop.
        crop_sizes=[320, 480, 640]),  ## Candidate crop size.
    dict(
        type='Flip',  ## Augmentation pipeline that flips the images.
        keys=['alpha', 'merged', 'ori_merged', 'fg', 'bg']),  ## Images to be flipped.
    dict(
        type='Resize',  ## Augmentation pipeline that resizes the images.
        keys=['alpha', 'merged', 'ori_merged', 'fg', 'bg'],  ## Images to be resized.
        scale=(320, 320),  ## Target size.
        keep_ratio=False),  ## Whether to keep the ratio between height and width.
    dict(
        type='GenerateTrimap',  ## Generate trimap from alpha matte.
        kernel_size=(1, 30)),  ## Kernel size range of the erode/dilate kernel.
    dict(
        type='RescaleToZeroOne',  ## Rescale images from [0, 255] to [0, 1].
        keys=['merged', 'alpha', 'ori_merged', 'fg', 'bg']),  ## Images to be rescaled.
    dict(
        type='Normalize',  ## Augmentation pipeline that normalizes the input images.
        keys=['merged'],  ## Images to be normalized.
        **img_norm_cfg),  ## Normalization config. See above for definition of `img_norm_cfg`
    dict(
        type='Collect',  ## Pipeline that decides which keys in the data should be passed to the model
        keys=['merged', 'alpha', 'trimap', 'ori_merged', 'fg', 'bg'],  ## Keys to pass to the model
        meta_keys=[]),  ## Meta information keys. In training, meta information is not needed.
    dict(
        type='ImageToTensor',  ## Convert images to tensor.
        keys=['merged', 'alpha', 'trimap', 'ori_merged', 'fg', 'bg']),  ## Images to be converted to Tensor.
]
test_pipeline = [
    dict(
        type='LoadImageFromFile',  ## Load alpha matte.
        key='alpha',  ## Key of alpha matte in annotation file. The pipeline will read alpha matte from path `alpha_path`.
        flag='grayscale',
        save_original_img=True),
    dict(
        type='LoadImageFromFile',  ## Load image from file
        key='trimap',  ## Key of image to load. The pipeline will read trimap from path `trimap_path`.
        flag='grayscale',  ## Load as grayscale image which has shape (height, width).
        save_original_img=True),  ## Save a copy of trimap for calculating metrics. It will be saved with key `ori_trimap`
    dict(
        type='LoadImageFromFile',  ## Load image from file
        key='merged'),  ## Key of image to load. The pipeline will read merged from path `merged_path`.
    dict(
        type='Pad',  ## Pipeline to pad images to align with the downsample factor of the model.
        keys=['trimap', 'merged'],  ## Images to be padded.
        mode='reflect'),  ## Mode of the padding.
    dict(
        type='RescaleToZeroOne',  ## Same as in train_pipeline.
        keys=['merged', 'ori_alpha']),  ## Images to be rescaled.
    dict(
        type='Normalize',  ## Same as in train_pipeline.
        keys=['merged'],
        **img_norm_cfg),
    dict(
        type='Collect',  ## Same as in train_pipeline.
        keys=['merged', 'trimap'],
        meta_keys=[
            'merged_path', 'pad', 'merged_ori_shape', 'ori_alpha',
            'ori_trimap'
        ]),
    dict(
        type='ImageToTensor',  ## Same as in train_pipeline.
        keys=['merged', 'trimap']),
]
data = dict(
    samples_per_gpu=1,  ## Batch size of a single GPU.
    workers_per_gpu=4,  ## Number of workers to pre-fetch data for each GPU.
    drop_last=True,  ## Use drop_last in data_loader.
    train=dict(  ## Train dataset config.
        type=dataset_type,  ## Type of dataset.
        ann_file=f'{data_root}/training_list.json',  ## Path of annotation file
        data_prefix=data_root,  ## Prefix of image path.
        pipeline=train_pipeline),  ## See above for train_pipeline
    val=dict(  ## Validation dataset config.
        type=dataset_type,
        ann_file=f'{data_root}/test_list.json',
        data_prefix=data_root,
        pipeline=test_pipeline),  ## See above for test_pipeline
    test=dict(  ## Test dataset config.
        type=dataset_type,
        ann_file=f'{data_root}/test_list.json',
        data_prefix=data_root,
        pipeline=test_pipeline))  ## See above for test_pipeline

## optimizer
optimizers = dict(type='Adam', lr=0.00001)  ## Config used to build the optimizer. All PyTorch optimizers are supported, with the same arguments as in PyTorch.
## learning policy
lr_config = dict(  ## Learning rate scheduler config used to register LrUpdater hook
    policy='Fixed')  ## The scheduler policy; CosineAnnealing, Cyclic, etc. are also supported. See the supported LrUpdater hooks at https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py#L9.

## checkpoint saving
checkpoint_config = dict(  ## Config to set the checkpoint hook. Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for the implementation.
    interval=40000,  ## The save interval is 40000 iterations.
    by_epoch=False)  ## Count by iterations.
evaluation = dict(  ## The config to build the evaluation hook.
    interval=40000)  ## Evaluation interval.
log_config = dict(  ## Config to register logger hook.
    interval=10,  ## Interval to print the log.
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),  ## The logger used to record the training process.
        ## dict(type='TensorboardLoggerHook')  # The Tensorboard logger is also supported.
    ])

## runtime settings
total_iters = 1000000  ## Total iterations to train the model.
dist_params = dict(backend='nccl')  ## Parameters to set up distributed training; the port can also be set.
log_level = 'INFO'  ## The level of logging.
work_dir = './work_dirs/dim_stage1'  ## Directory to save the model checkpoints and logs for the current experiments.
load_from = None  ## load models as a pre-trained model from a given path. This will not resume training.
resume_from = None  ## Resume from a checkpoint at a given path; training resumes from the point at which the checkpoint was saved.
workflow = [('train', 1)]  ## Workflow for runner. [('train', 1)] means there is only one workflow and the workflow named 'train' is executed once. Keep this unchanged when training current matting models.
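The config above weights an alpha loss and a composition loss equally. The composition term relies on the standard compositing equation, merged = alpha * fg + (1 - alpha) * bg. The sketch below applies it to single values and pairs it with a Charbonnier-style penalty. The exact form and epsilon used by CharbonnierLoss and CharbonnierCompLoss are not stated here, so the `eps` value and the function names are assumptions made for illustration.

```python
import math


def composite(alpha, fg, bg):
    """Blend foreground and background with an alpha value in [0, 1]."""
    return alpha * fg + (1.0 - alpha) * bg


def charbonnier(pred, target, eps=1e-6):
    """Smooth L1-like penalty, sqrt(diff^2 + eps^2) (assumed form)."""
    return math.sqrt((pred - target) ** 2 + eps ** 2)


fg, bg = 0.8, 0.2
true_alpha, pred_alpha = 0.5, 0.75

merged = composite(true_alpha, fg, bg)      # image built with the true alpha
recomposed = composite(pred_alpha, fg, bg)  # image implied by the prediction

alpha_loss = charbonnier(pred_alpha, true_alpha)
comp_loss = charbonnier(recomposed, merged)
# Both terms use loss_weight=0.5 in the config above.
total_loss = 0.5 * alpha_loss + 0.5 * comp_loss
print(merged, recomposed, total_loss)
```

The composition term penalizes an alpha error in proportion to how different the foreground and background are at that pixel, which is why it complements the direct alpha term.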

Config System for Restoration

An Example - EDSR

To give users a basic idea of a complete config, we briefly comment on the config of the EDSR model we implemented below. For more detailed usage and the alternatives available for each module, please refer to the API documentation.
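The EDSR config that follows sets scale = 2 and uses PairedRandomCrop with gt_patch_size=96. A paired crop has to cut the same region from the low-quality (lq) and ground-truth (gt) images, whose sizes differ by the scale factor. The sketch below shows one way to do this on nested lists. It assumes the lq patch size is gt_patch_size // scale, and `paired_crop` is a hypothetical helper, not the library's implementation.

```python
import random


def paired_crop(lq, gt, gt_patch_size, scale, rng=random):
    """Crop aligned patches from lq and gt (2D lists; gt is scale x larger)."""
    lq_patch_size = gt_patch_size // scale
    lq_h, lq_w = len(lq), len(lq[0])
    # Pick the top-left corner in lq coordinates (randint is inclusive).
    top = rng.randint(0, lq_h - lq_patch_size)
    left = rng.randint(0, lq_w - lq_patch_size)
    lq_patch = [row[left:left + lq_patch_size]
                for row in lq[top:top + lq_patch_size]]
    # The matching gt region starts at the scaled corner.
    gt_top, gt_left = top * scale, left * scale
    gt_patch = [row[gt_left:gt_left + gt_patch_size]
                for row in gt[gt_top:gt_top + gt_patch_size]]
    return lq_patch, gt_patch


scale = 2
# Each "pixel" stores its own (row, column) so alignment is easy to check.
lq = [[(y, x) for x in range(60)] for y in range(60)]
gt = [[(y, x) for x in range(120)] for y in range(120)]

lq_patch, gt_patch = paired_crop(
    lq, gt, gt_patch_size=96, scale=scale, rng=random.Random(0))
print(len(lq_patch), len(gt_patch), lq_patch[0][0], gt_patch[0][0])
```

With scale = 2, a 96-pixel gt patch corresponds to a 48-pixel lq patch, and the gt corner is always twice the lq corner.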

exp_name = 'edsr_x2c64b16_1x16_300k_div2k'  ## The experiment name

scale = 2  ## Scale factor for upsampling
## model settings
model = dict(
    type='BasicRestorer',  ## Name of the model
    generator=dict(  ## Config of the generator
        type='EDSR',  ## Type of the generator
        in_channels=3,  ## Channel number of inputs
        out_channels=3,  ## Channel number of outputs
        mid_channels=64,  ## Channel number of intermediate features
        num_blocks=16,  ## Block number in the trunk network
        upscale_factor=scale, ## Upsampling factor
        res_scale=1,  ## Used to scale the residual in residual block
        rgb_mean=(0.4488, 0.4371, 0.4040),  ## Image mean in RGB orders
        rgb_std=(1.0, 1.0, 1.0)),  ## Image std in RGB orders
    pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'))  ## Config for pixel loss
## model training and testing settings
train_cfg = None  ## Training config
test_cfg = dict(  ## Test config
    metrics=['PSNR'],  ## Metrics used during testing
    crop_border=scale)  ## Crop border during evaluation

## dataset settings
train_dataset_type = 'SRAnnotationDataset'  ## Dataset type for training
val_dataset_type = 'SRFolderDataset'  ##  Dataset type for validation
train_pipeline = [  ## Training data processing pipeline
    dict(type='LoadImageFromFile',  ## Load images from files
        io_backend='disk',  ## io backend
        key='lq',  ## Keys in results to find corresponding path
        flag='unchanged'),  ## flag for reading images
    dict(type='LoadImageFromFile',  ## Load images from files
        io_backend='disk',  ## io backend
        key='gt',  ## Keys in results to find corresponding path
        flag='unchanged'),  ## flag for reading images
    dict(type='RescaleToZeroOne', keys=['lq', 'gt']),  ## Rescale images from [0, 255] to [0, 1]
    dict(type='Normalize',  ## Augmentation pipeline that normalizes the input images
        keys=['lq', 'gt'],  ## Images to be normalized
        mean=[0, 0, 0],  ## Mean values
        std=[1, 1, 1],  ## Standard deviations
        to_rgb=True),  ## Change to RGB channel
    dict(type='PairedRandomCrop', gt_patch_size=96),  ## Paired random crop
    dict(type='Flip',  ## Flip images
        keys=['lq', 'gt'],  ## Images to be flipped
        flip_ratio=0.5,  ## Flip ratio
        direction='horizontal'),  ## Flip direction
    dict(type='Flip',  ## Flip images
        keys=['lq', 'gt'],  ## Images to be flipped
        flip_ratio=0.5,  ## Flip ratio
        direction='vertical'),  ## Flip direction
    dict(type='RandomTransposeHW',  ## Random transpose h and w for images
        keys=['lq', 'gt'],  ## Images to be transposed
        transpose_ratio=0.5  ## Transpose ratio
        ),
    dict(type='Collect',  ## Pipeline that decides which keys in the data should be passed to the model
        keys=['lq', 'gt'],  ## Keys to pass to the model
        meta_keys=['lq_path', 'gt_path']), ## Meta information keys. In training, meta information is not needed
    dict(type='ImageToTensor',  ## Convert images to tensor
        keys=['lq', 'gt'])  ## Images to be converted to Tensor
]
test_pipeline = [  ## Test pipeline
    dict(
        type='LoadImageFromFile',  ## Load images from files
        io_backend='disk',  ## io backend
        key='lq',  ## Keys in results to find corresponding path
        flag='unchanged'),  ## flag for reading images
    dict(
        type='LoadImageFromFile',  ## Load images from files
        io_backend='disk',  ## io backend
        key='gt',  ## Keys in results to find corresponding path
        flag='unchanged'),  ## flag for reading images
    dict(type='RescaleToZeroOne', keys=['lq', 'gt']),  ## Rescale images from [0, 255] to [0, 1]
    dict(
        type='Normalize',  ## Augmentation pipeline that normalizes the input images
        keys=['lq', 'gt'],  ## Images to be normalized
        mean=[0, 0, 0],  ## Mean values
        std=[1, 1, 1],  ## Standard deviation values
        to_rgb=True),  ## Convert to RGB channel order
    dict(type='Collect',  ## Pipeline that decides which keys in the data should be passed to the model
        keys=['lq', 'gt'],  ## Keys to pass to the model
        meta_keys=['lq_path', 'gt_path']),  ## Meta information keys
    dict(type='ImageToTensor',  ## Convert images to tensor
        keys=['lq', 'gt'])  ## Images to be converted to Tensor
]

data = dict(
    ## train
    samples_per_gpu=16,  ## Batch size of a single GPU
    workers_per_gpu=6,  ## Number of workers to pre-fetch data for each GPU
    drop_last=True,  ## Drop the last incomplete batch in the data loader
    train=dict(  ## Train dataset config
        type='RepeatDataset',  ## Repeated dataset for iter-based model
        times=1000,  ## Repeated times for RepeatDataset
        dataset=dict(
            type=train_dataset_type,  ## Type of dataset
            lq_folder='data/DIV2K/DIV2K_train_LR_bicubic/X2_sub',  ## Path for lq folder
            gt_folder='data/DIV2K/DIV2K_train_HR_sub',  ## Path for gt folder
            ann_file='data/DIV2K/meta_info_DIV2K800sub_GT.txt',  ## Path for annotation file
            pipeline=train_pipeline,  ## See above for train_pipeline
            scale=scale)),  ## Scale factor for upsampling
    ## val
    val_samples_per_gpu=1,  ## Batch size of a single GPU for validation
    val_workers_per_gpu=1,  ## Number of workers to pre-fetch data for each GPU during validation
    val=dict(
        type=val_dataset_type,  ## Type of dataset
        lq_folder='data/val_set5/Set5_bicLRx2',  ## Path for lq folder
        gt_folder='data/val_set5/Set5_mod12',  ## Path for gt folder
        pipeline=test_pipeline,  ## See above for test_pipeline
        scale=scale,  ## Scale factor for upsampling
        filename_tmpl='{}'),  ## filename template
    ## test
    test=dict(
        type=val_dataset_type,  ## Type of dataset
        lq_folder='data/val_set5/Set5_bicLRx2',  ## Path for lq folder
        gt_folder='data/val_set5/Set5_mod12',  ## Path for gt folder
        pipeline=test_pipeline,  ## See above for test_pipeline
        scale=scale,  ## Scale factor for upsampling
        filename_tmpl='{}'))  ## filename template

## optimizer
optimizers = dict(generator=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999)))  ## Config used to build the optimizer. All PyTorch optimizers are supported, with the same arguments as in PyTorch

## learning policy
total_iters = 300000  ## Total training iters
lr_config = dict(  ## Learning rate scheduler config used to register the LrUpdater hook
    policy='Step', by_epoch=False, step=[200000], gamma=0.5)  ## Scheduler policy. CosineAnnealing, Cyclic, etc. are also supported. For details of the supported LrUpdater hooks, see https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py#L9.

checkpoint_config = dict(  ## Config to set the checkpoint hook. Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for the implementation.
    interval=5000,  ## The save interval is 5000 iterations
    save_optimizer=True,  ## Also save optimizers
    by_epoch=False)  ## Count by iterations
evaluation = dict(  ## The config to build the evaluation hook
    interval=5000,  ## Evaluation interval
    save_image=True,  ## Save images during evaluation
    gpu_collect=True)  ## Use GPU to collect results
log_config = dict(  ## Config to register logger hook
    interval=100,  ## Interval to print the log
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),  ## The logger used to record the training process
        dict(type='TensorboardLoggerHook'),  ## The Tensorboard logger is also supported
    ])
visual_config = None  ## Visual config; not used here

## runtime settings
dist_params = dict(backend='nccl')  ## Parameters to set up distributed training; the port can also be set
log_level = 'INFO'  ## The level of logging
work_dir = f'./work_dirs/{exp_name}'  ## Directory to save the model checkpoints and logs for the current experiment
load_from = None  ## Load a pre-trained model from the given path. This does not resume training
resume_from = None  ## Resume from the checkpoint at the given path. Training resumes from the iteration at which the checkpoint was saved
workflow = [('train', 1)]  ## Workflow for runner. [('train', 1)] means there is only one workflow and the workflow named 'train' is executed once. Keep this unchanged when training current restoration models
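
To make the learning policy above concrete, the following is a minimal sketch of how a step schedule with `step=[200000]` and `gamma=0.5` would change the learning rate over `total_iters = 300000`. It does not call MMCV. The helper `step_lr` is hypothetical and written only for illustration, and it assumes the decay takes effect once the iteration count reaches a milestone.

```python
from bisect import bisect_right

BASE_LR = 1e-4         # `lr` in the `optimizers` entry of the config
MILESTONES = [200000]  # `step` in `lr_config`
GAMMA = 0.5            # `gamma` in `lr_config`
TOTAL_ITERS = 300000   # `total_iters`


def step_lr(iteration, base_lr=BASE_LR, milestones=MILESTONES, gamma=GAMMA):
    """Hypothetical helper: learning rate at a given iteration.

    Assumption: the rate is multiplied by `gamma` once for every
    milestone that the iteration count has already reached.
    """
    num_decays = bisect_right(milestones, iteration)
    return base_lr * gamma ** num_decays


if __name__ == '__main__':
    # Print the rate just before and just after the single milestone.
    for it in (0, 199999, 200000, TOTAL_ITERS - 1):
        print(f'iter {it:>6}: lr = {step_lr(it):.1e}')
```

Under these assumptions the rate stays at 1e-4 for the first 200,000 iterations and is halved to 5e-5 for the remaining 100,000. Adding more entries to `step` would halve it again at each additional milestone.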