Note
You are reading the documentation for MMEditing 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMEditing 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. See the changelog, code and documentation of MMEditing 1.0 for more details.
Learn about Configs
We use Python files as our config system. You can find all the provided configs under $MMEditing/configs.
Config Name Style
We follow the style below to name config files. Contributors are advised to follow the same style.
{model}_[model setting]_{backbone}_[refiner]_[norm setting]_[misc]_[gpu x batch_per_gpu]_{schedule}_{dataset}
{xxx} is a required field and [yyy] is optional.
{model}: model type like srcnn, dim, etc.
[model setting]: specific setting for some model, like resolution for input images, stage name for training, etc.
{backbone}: backbone type like r50 (ResNet-50), x101 (ResNeXt-101).
[refiner]: refiner type like pln (Plain Refiner).
[norm setting]: bn (Batch Normalization) is used unless specified; other norm layer types could be gn (Group Normalization), syncbn (Synchronized Batch Normalization).
[misc]: miscellaneous setting/plugins of model, e.g. dconv, gcb, attention, albu, mstrain.
[gpu x batch_per_gpu]: GPUs and samples per GPU, 8x2 is used by default.
{schedule}: training schedule, 20k, 100k, etc. 20k means 20,000 iterations. 100k means 100,000 iterations.
{dataset}: dataset like places (for inpainting), comp1k (for matting), div2k (for restoration) and paired (for generation).
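The naming scheme above can be sketched as a small helper that joins the fields in the documented order. This is only an illustration: build_config_name and schedule_to_iters are hypothetical names, not part of MMEditing. The helper treats every field except the model, schedule and dataset as optional, because the EDSR experiment name later in this document (edsr_x2c64b16_1x16_300k_div2k) omits the backbone field.

```python
# Hypothetical helpers illustrating the naming style; not part of MMEditing.
FIELD_ORDER = ('model', 'model_setting', 'backbone', 'refiner',
               'norm_setting', 'misc', 'gpu_batch', 'schedule', 'dataset')


def build_config_name(**fields):
    """Join the given fields with underscores in the documented order."""
    unknown = set(fields) - set(FIELD_ORDER)
    if unknown:
        raise ValueError(f'unknown fields: {sorted(unknown)}')
    # Assumption: only these three fields are enforced as required here.
    for required in ('model', 'schedule', 'dataset'):
        if not fields.get(required):
            raise ValueError(f'missing required field: {required}')
    return '_'.join(fields[k] for k in FIELD_ORDER if fields.get(k))


def schedule_to_iters(schedule):
    """Convert a schedule such as '20k' into an iteration count (20000)."""
    if not schedule.endswith('k'):
        raise ValueError(f'unsupported schedule: {schedule!r}')
    return int(schedule[:-1]) * 1000


name = build_config_name(model='edsr', model_setting='x2c64b16',
                         gpu_batch='1x16', schedule='300k', dataset='div2k')
print(name)
print(schedule_to_iters('20k'))
```

Skipped optional fields simply leave no trace in the name, which is why names of different lengths can all follow the same style.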
Config System for Generation
As in MMDetection, our config system uses a modular, inheritance-based design, which makes it convenient to run various experiments.
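As a rough illustration of what inheritance means for dict-based configs, the sketch below recursively merges a child config into a base config. merge_config is a hypothetical helper written for this example. It is a simplified stand-in, not the config loader's actual implementation, and the two dicts are made-up fragments modelled on a pix2pix-style config.

```python
import copy


def merge_config(base, child):
    """Return a new dict where `child` overrides `base`, merging nested dicts."""
    merged = copy.deepcopy(base)
    for key, value in child.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge field by field instead of replacing.
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Made-up base and child fragments, for illustration only.
base = dict(
    model=dict(type='Pix2Pix',
               generator=dict(type='UnetGenerator', num_down=8)),
    total_iters=80000)
child = dict(model=dict(generator=dict(num_down=7)), total_iters=40000)

cfg = merge_config(base, child)
print(cfg['model']['generator'])
print(cfg['total_iters'])
```

The child only lists the fields it changes; everything else is taken from the base, which keeps experiment variants short.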
An Example - pix2pix
To give users a basic idea of a complete config and the modules in a generation system, we briefly comment on the pix2pix config below. For more detailed usage and the alternatives for each module, please refer to the API documentation.
## model settings
model = dict(
type='Pix2Pix', ## The name of synthesizer
generator=dict(
type='UnetGenerator', ## The name of generator
in_channels=3, ## The input channels of generator
out_channels=3, ## The output channels of generator
num_down=8, ## The number of downsamplings in the generator
base_channels=64, ## The number of channels at the last conv layer of generator
norm_cfg=dict(type='BN'), ## The config of norm layer
use_dropout=True, ## Whether to use dropout layers in the generator
init_cfg=dict(type='normal', gain=0.02)), ## The config of initialization
discriminator=dict(
type='PatchDiscriminator', ## The name of discriminator
in_channels=6, ## The input channels of discriminator
base_channels=64, ## The number of channels at the first conv layer of discriminator
num_conv=3, ## The number of stacked intermediate conv layers (excluding input and output conv layer) in the discriminator
norm_cfg=dict(type='BN'), ## The config of norm layer
init_cfg=dict(type='normal', gain=0.02)), ## The config of initialization
gan_loss=dict(
type='GANLoss', ## The name of GAN loss
gan_type='vanilla', ## The type of GAN loss
real_label_val=1.0, ## The value for real label of GAN loss
fake_label_val=0.0, ## The value for fake label of GAN loss
loss_weight=1.0), ## The weight of GAN loss
pixel_loss=dict(type='L1Loss', loss_weight=100.0, reduction='mean'))
## model training and testing settings
train_cfg = dict(
direction='b2a') ## Image-to-image translation direction (the model training direction, same as testing direction) for pix2pix. Model default: a2b
test_cfg = dict(
direction='b2a', ## Image-to-image translation direction (the model training direction, same as testing direction) for pix2pix. Model default: a2b
show_input=True) ## Whether to show input real images when saving testing images for pix2pix
## dataset settings
train_dataset_type = 'GenerationPairedDataset' ## The type of dataset for training
val_dataset_type = 'GenerationPairedDataset' ## The type of dataset for validation/testing
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ## Image normalization config to normalize the input images
train_pipeline = [
dict(
type='LoadPairedImageFromFile', ## Load a pair of images from file path pipeline
io_backend='disk', ## IO backend where images are stored
key='pair', ## Keys to find corresponding path
flag='color'), ## Loading flag for images
dict(
type='Resize', ## Resize pipeline
keys=['img_a', 'img_b'], ## The keys of images to be resized
scale=(286, 286), ## The scale to resize images
interpolation='bicubic'), ## Algorithm used for interpolation when resizing images
dict(
type='FixedCrop', ## Fixed crop pipeline, cropping paired images to a specific size at a specific position for pix2pix training
keys=['img_a', 'img_b'], ## The keys of images to be cropped
crop_size=(256, 256)), ## The size to crop images
dict(
type='Flip', ## Flip pipeline
keys=['img_a', 'img_b'], ## The keys of images to be flipped
direction='horizontal'), ## Flip images horizontally or vertically
dict(
type='RescaleToZeroOne', ## Rescale images from [0, 255] to [0, 1]
keys=['img_a', 'img_b']), ## The keys of images to be rescaled
dict(
type='Normalize', ## Image normalization pipeline
keys=['img_a', 'img_b'], ## The keys of images to be normalized
to_rgb=True, ## Whether to convert image channels from BGR to RGB
**img_norm_cfg), ## Image normalization config (see above for the definition of `img_norm_cfg`)
dict(
type='ImageToTensor', ## Image to tensor pipeline
keys=['img_a', 'img_b']), ## The keys of images to be converted from image to tensor
dict(
type='Collect', ## Pipeline that decides which keys in the data should be passed to the synthesizer
keys=['img_a', 'img_b'], ## The keys of images
meta_keys=['img_a_path', 'img_b_path']) ## The meta keys of images
]
test_pipeline = [
dict(
type='LoadPairedImageFromFile', ## Load a pair of images from file path pipeline
io_backend='disk', ## IO backend where images are stored
key='pair', ## Keys to find corresponding path
flag='color'), ## Loading flag for images
dict(
type='Resize', ## Resize pipeline
keys=['img_a', 'img_b'], ## The keys of images to be resized
scale=(256, 256), ## The scale to resize images
interpolation='bicubic'), ## Algorithm used for interpolation when resizing images
dict(
type='RescaleToZeroOne', ## Rescale images from [0, 255] to [0, 1]
keys=['img_a', 'img_b']), ## The keys of images to be rescaled
dict(
type='Normalize', ## Image normalization pipeline
keys=['img_a', 'img_b'], ## The keys of images to be normalized
to_rgb=True, ## Whether to convert image channels from BGR to RGB
**img_norm_cfg), ## Image normalization config (see above for the definition of `img_norm_cfg`)
dict(
type='ImageToTensor', ## Image to tensor pipeline
keys=['img_a', 'img_b']), ## The keys of images to be converted from image to tensor
dict(
type='Collect', ## Pipeline that decides which keys in the data should be passed to the synthesizer
keys=['img_a', 'img_b'], ## The keys of images
meta_keys=['img_a_path', 'img_b_path']) ## The meta keys of images
]
data_root = 'data/pix2pix/facades' ## The root path of data
data = dict(
samples_per_gpu=1, ## Batch size of a single GPU
workers_per_gpu=4, ## Worker to pre-fetch data for each single GPU
drop_last=True, ## Whether to drop out the last batch of data in training
val_samples_per_gpu=1, ## Batch size of a single GPU in validation
val_workers_per_gpu=0, ## Worker to pre-fetch data for each single GPU in validation
train=dict( ## Training dataset config
type=train_dataset_type,
dataroot=data_root,
pipeline=train_pipeline,
test_mode=False),
val=dict( ## Validation dataset config
type=val_dataset_type,
dataroot=data_root,
pipeline=test_pipeline,
test_mode=True),
test=dict( ## Testing dataset config
type=val_dataset_type,
dataroot=data_root,
pipeline=test_pipeline,
test_mode=True))
## optimizer
optimizers = dict( ## Config used to build optimizer, support all the optimizers in PyTorch whose arguments are also the same as those in PyTorch
generator=dict(type='Adam', lr=2e-4, betas=(0.5, 0.999)),
discriminator=dict(type='Adam', lr=2e-4, betas=(0.5, 0.999)))
## learning policy
lr_config = dict(policy='Fixed', by_epoch=False) ## Learning rate scheduler config used to register LrUpdater hook
## checkpoint saving
checkpoint_config = dict(interval=4000, save_optimizer=True, by_epoch=False) ## Config to set the checkpoint hook, Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for implementation.
evaluation = dict( ## The config to build the evaluation hook
interval=4000, ## Evaluation interval
save_image=True) ## Whether to save images
log_config = dict( ## config to register logger hook
interval=100, ## Interval to print the log
hooks=[
dict(type='TextLoggerHook', by_epoch=False), ## The logger used to record the training process
## dict(type='TensorboardLoggerHook') # The Tensorboard logger is also supported
])
visual_config = None ## The config to build the visualization hook
## runtime settings
total_iters = 80000 ## Total iterations to train the model
cudnn_benchmark = True ## Set cudnn_benchmark
dist_params = dict(backend='nccl') ## Parameters to setup distributed training, the port can also be set
log_level = 'INFO' ## The level of logging
load_from = None ## Load models as a pre-trained model from a given path. This will not resume training
resume_from = None ## Resume checkpoints from a given path, the training will be resumed from the iteration at which the checkpoint was saved
workflow = [('train', 1)] ## Workflow for runner. [('train', 1)] means there is only one workflow and the workflow named 'train' is executed once. Keep this unchanged when training current generation models
exp_name = 'pix2pix_facades' ## The experiment name
work_dir = f'./work_dirs/{exp_name}' ## Directory to save the model checkpoints and logs for the current experiments.
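The RescaleToZeroOne and Normalize steps in the pipelines above can be checked by hand. The sketch below applies both to single pixel values, assuming that Normalize computes (x - mean) / std per channel. Under that assumption, mean 0.5 and std 0.5 map the original [0, 255] range to [-1, 1]. The function name is made up for this example.

```python
def rescale_and_normalize(value, mean=0.5, std=0.5):
    """Map a pixel value in [0, 255] to [0, 1], then apply (x - mean) / std.

    Assumption: this mirrors RescaleToZeroOne followed by Normalize for one
    channel; it is not the pipeline code itself.
    """
    rescaled = value / 255.0
    return (rescaled - mean) / std


low = rescale_and_normalize(0)
high = rescale_and_normalize(255)
print(low, high)
```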
Config System for Inpainting
Config Name Style
As in MMDetection, our config system uses a modular, inheritance-based design, which makes it convenient to run various experiments.
Config Field Description
To give users a basic idea of a complete config and the modules in an inpainting system, we briefly comment on the Global&Local config below. For more detailed usage and the alternatives for each module, please refer to the API documentation.
model = dict(
type='GLInpaintor', ## The name of inpaintor
encdec=dict(
type='GLEncoderDecoder', ## The name of encoder-decoder
encoder=dict(type='GLEncoder', norm_cfg=dict(type='SyncBN')), ## The config of encoder
decoder=dict(type='GLDecoder', norm_cfg=dict(type='SyncBN')), ## The config of decoder
dilation_neck=dict(
type='GLDilationNeck', norm_cfg=dict(type='SyncBN'))), ## The config of dilation neck
disc=dict(
type='GLDiscs', ## The name of discriminator
global_disc_cfg=dict(
in_channels=3, ## The input channel of discriminator
max_channels=512, ## The maximum middle channel in discriminator
fc_in_channels=512 * 4 * 4, ## The input channel of last fc layer
fc_out_channels=1024, ## The output channel of last fc layer
num_convs=6, ## The number of convs used in discriminator
norm_cfg=dict(type='SyncBN') ## The config of norm layer
),
local_disc_cfg=dict(
in_channels=3, ## The input channel of discriminator
max_channels=512, ## The maximum middle channel in discriminator
fc_in_channels=512 * 4 * 4, ## The input channel of last fc layer
fc_out_channels=1024, ## The output channel of last fc layer
num_convs=5, ## The number of convs used in discriminator
norm_cfg=dict(type='SyncBN') ## The config of norm layer
),
),
loss_gan=dict(
type='GANLoss', ## The name of GAN loss
gan_type='vanilla', ## The type of GAN loss
loss_weight=0.001 ## The weight of GAN loss
),
loss_l1_hole=dict(
type='L1Loss', ## The type of l1 loss
loss_weight=1.0 ## The weight of l1 loss
),
pretrained=None) ## The path of pretrained weight
train_cfg = dict(
disc_step=1, ## The steps of training discriminator before training generator
iter_tc=90000, ## Iterations of warming up generator
iter_td=100000, ## Iterations of warming up discriminator
start_iter=0, ## Starting iteration
local_size=(128, 128)) ## The size of local patches
test_cfg = dict(metrics=['l1']) ## The config of testing scheme
dataset_type = 'ImgInpaintingDataset' ## The type of dataset
input_shape = (256, 256) ## The shape of input image
train_pipeline = [
dict(type='LoadImageFromFile', key='gt_img'), ## The config of loading image
dict(
type='LoadMask', ## The type of loading mask pipeline
mask_mode='bbox', ## The type of mask
mask_config=dict(
max_bbox_shape=(128, 128), ## The shape of bbox
max_bbox_delta=40, ## The changing delta of bbox height and width
min_margin=20, ## The minimum margin from bbox to the image border
img_shape=input_shape)), ## The input image shape
dict(
type='Crop', ## The type of crop pipeline
keys=['gt_img'], ## The keys of images to be cropped
crop_size=(384, 384), ## The size of cropped patch
random_crop=True, ## Whether to use random crop
),
dict(
type='Resize', ## The type of resizing pipeline
keys=['gt_img'], ## The keys of images to be resized
scale=input_shape, ## The scale of resizing function
keep_ratio=False, ## Whether to keep ratio during resizing
),
dict(
type='Normalize', ## The type of normalizing pipeline
keys=['gt_img'], ## The keys of images to be normed
mean=[127.5] * 3, ## Mean value used in normalization
std=[127.5] * 3, ## Std value used in normalization
to_rgb=False), ## Whether to transfer image channels to rgb
dict(type='GetMaskedImage'), ## The config of getting masked image pipeline
dict(
type='Collect', ## The type of collecting data from current pipeline
keys=['gt_img', 'masked_img', 'mask', 'mask_bbox'], ## The keys of data to be collected
meta_keys=['gt_img_path']), ## The meta keys of data to be collected
dict(type='ImageToTensor', keys=['gt_img', 'masked_img', 'mask']), ## The config dict of image to tensor pipeline
dict(type='ToTensor', keys=['mask_bbox']) ## The config dict of ToTensor pipeline
]
test_pipeline = train_pipeline ## Constructing testing/validation pipeline
data_root = 'data/places365' ## Set data root
data = dict(
samples_per_gpu=12, ## Batch size of a single GPU
workers_per_gpu=8, ## Worker to pre-fetch data for each single GPU
val_samples_per_gpu=1, ## Batch size of a single GPU in validation
val_workers_per_gpu=8, ## Worker to pre-fetch data for each single GPU in validation
drop_last=True, ## Whether to drop out the last batch of data
train=dict( ## Train dataset config
type=dataset_type,
ann_file=f'{data_root}/train_places_img_list_total.txt',
data_prefix=data_root,
pipeline=train_pipeline,
test_mode=False),
val=dict( ## Validation dataset config
type=dataset_type,
ann_file=f'{data_root}/val_places_img_list.txt',
data_prefix=data_root,
pipeline=test_pipeline,
test_mode=True))
optimizers = dict( ## Config used to build optimizer, support all the optimizers in PyTorch whose arguments are also the same as those in PyTorch
generator=dict(type='Adam', lr=0.0004), disc=dict(type='Adam', lr=0.0004))
lr_config = dict(policy='Fixed', by_epoch=False) ## Learning rate scheduler config used to register LrUpdater hook
checkpoint_config = dict(by_epoch=False, interval=50000) ## Config to set the checkpoint hook, Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for implementation.
log_config = dict( ## config to register logger hook
interval=100, ## Interval to print the log
hooks=[
dict(type='TextLoggerHook', by_epoch=False),
## dict(type='TensorboardLoggerHook'), # The Tensorboard logger is also supported
## dict(type='PaviLoggerHook', init_kwargs=dict(project='mmedit'))
]) ## The logger used to record the training process.
visual_config = dict( ## config to register visualization hook
type='MMEditVisualizationHook',
output_dir='visual',
interval=1000,
res_name_list=[
'gt_img', 'masked_img', 'fake_res', 'fake_img', 'fake_gt_local'
],
) ## The logger used to visualize the training process.
evaluation = dict(interval=50000) ## The config to build the evaluation hook
total_iters = 500002
dist_params = dict(backend='nccl') ## Parameters to setup distributed training, the port can also be set.
log_level = 'INFO' ## The level of logging.
work_dir = None ## Directory to save the model checkpoints and logs for the current experiments.
load_from = None ## load models as a pre-trained model from a given path. This will not resume training.
resume_from = None ## Resume checkpoints from a given path, the training will be resumed from the iteration at which the checkpoint was saved.
workflow = [('train', 10000)] ## Workflow for runner. There is only one workflow, named 'train'; the total training length is set by total_iters above.
exp_name = 'gl_places' ## The experiment name
find_unused_parameters = False ## Whether to set find unused parameters in ddp
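The iter_tc and iter_td values in train_cfg suggest a three-phase schedule: the generator is trained alone first, then the discriminator alone, then both together. The sketch below encodes that reading. It is an assumption based on the comments above, not the GLInpaintor code; the function name, the phase labels and the exact boundary handling are all hypothetical.

```python
def training_phase(cur_iter, iter_tc=90000, iter_td=100000):
    """Return which networks would be updated at `cur_iter`.

    Assumption: iter_tc and iter_td are cumulative iteration counts, and the
    boundaries are inclusive. Both points are guesses made for illustration.
    """
    if cur_iter <= iter_tc:
        return 'generator_only'
    if cur_iter <= iter_td:
        return 'discriminator_only'
    return 'joint'


for it in (1, 90000, 95000, 100001, 500002):
    print(it, training_phase(it))
```

Under this reading, the discriminator is trained alone for only 10,000 iterations, and most of the 500,002 total iterations are spent in the joint phase.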
Config System for Matting
As in MMDetection, our config system uses a modular, inheritance-based design, which makes it convenient to run various experiments.
An Example - Deep Image Matting Model
To give users a basic idea of a complete config, we briefly comment on the config of the original DIM model we implemented below. For more detailed usage and the alternatives for each module, please refer to the API documentation.
## model settings
model = dict(
type='DIM', ## The name of model (we call mattor).
backbone=dict( ## The config of the backbone.
type='SimpleEncoderDecoder', ## The type of the backbone.
encoder=dict( ## The config of the encoder.
type='VGG16'), ## The type of the encoder.
decoder=dict( ## The config of the decoder.
type='PlainDecoder')), ## The type of the decoder.
pretrained='./weights/vgg_state_dict.pth', ## The pretrained weight of the encoder to be loaded.
loss_alpha=dict( ## The config of the alpha loss.
type='CharbonnierLoss', ## The type of the loss for predicted alpha matte.
loss_weight=0.5), ## The weight of the alpha loss.
loss_comp=dict( ## The config of the composition loss.
type='CharbonnierCompLoss', ## The type of the composition loss.
loss_weight=0.5)) ## The weight of the composition loss.
train_cfg = dict( ## Config of training DIM model.
train_backbone=True, ## In DIM stage1, backbone is trained.
train_refiner=False) ## In DIM stage1, refiner is not trained.
test_cfg = dict( ## Config of testing DIM model.
refine=False, ## Whether to use the refiner output as the final output; in stage 1, we don't use it.
metrics=['SAD', 'MSE', 'GRAD', 'CONN']) ## The metrics used when testing.
## data settings
dataset_type = 'AdobeComp1kDataset' ## Dataset type, this will be used to define the dataset.
data_root = 'data/adobe_composition-1k' ## Root path of data.
img_norm_cfg = dict( ## Image normalization config to normalize the input images.
mean=[0.485, 0.456, 0.406], ## Mean values used when pre-training the backbone models.
std=[0.229, 0.224, 0.225], ## Standard deviation values used when pre-training the backbone models.
to_rgb=True) ## The channel order of images used when pre-training the backbone models.
train_pipeline = [ ## Training data processing pipeline.
dict(
type='LoadImageFromFile', ## Load alpha matte from file.
key='alpha', ## Key of alpha matte in annotation file. The pipeline will read alpha matte from path `alpha_path`.
flag='grayscale'), ## Load as grayscale image which has shape (height, width).
dict(
type='LoadImageFromFile', ## Load image from file.
key='fg'), ## Key of image to load. The pipeline will read fg from path `fg_path`.
dict(
type='LoadImageFromFile', ## Load image from file.
key='bg'), ## Key of image to load. The pipeline will read bg from path `bg_path`.
dict(
type='LoadImageFromFile', ## Load image from file.
key='merged'), ## Key of image to load. The pipeline will read merged from path `merged_path`.
dict(
type='CropAroundUnknown', ## Crop images around unknown area (semi-transparent area).
keys=['alpha', 'merged', 'ori_merged', 'fg', 'bg'], ## Images to crop.
crop_sizes=[320, 480, 640]), ## Candidate crop size.
dict(
type='Flip', ## Augmentation pipeline that flips the images.
keys=['alpha', 'merged', 'ori_merged', 'fg', 'bg']), ## Images to be flipped.
dict(
type='Resize', ## Augmentation pipeline that resizes the images.
keys=['alpha', 'merged', 'ori_merged', 'fg', 'bg'], ## Images to be resized.
scale=(320, 320), ## Target size.
keep_ratio=False), ## Whether to keep the ratio between height and width.
dict(
type='GenerateTrimap', ## Generate trimap from alpha matte.
kernel_size=(1, 30)), ## Kernel size range of the erode/dilate kernel.
dict(
type='RescaleToZeroOne', ## Rescale images from [0, 255] to [0, 1].
keys=['merged', 'alpha', 'ori_merged', 'fg', 'bg']), ## Images to be rescaled.
dict(
type='Normalize', ## Augmentation pipeline that normalizes the input images.
keys=['merged'], ## Images to be normalized.
**img_norm_cfg), ## Normalization config. See above for definition of `img_norm_cfg`
dict(
type='Collect', ## Pipeline that decides which keys in the data should be passed to the model
keys=['merged', 'alpha', 'trimap', 'ori_merged', 'fg', 'bg'], ## Keys to pass to the model
meta_keys=[]), ## Meta information keys. In training, meta information is not needed.
dict(
type='ImageToTensor', ## Convert images to tensor.
keys=['merged', 'alpha', 'trimap', 'ori_merged', 'fg', 'bg']), ## Images to be converted to Tensor.
]
test_pipeline = [
dict(
type='LoadImageFromFile', ## Load alpha matte.
key='alpha', ## Key of alpha matte in annotation file. The pipeline will read alpha matte from path `alpha_path`.
flag='grayscale',
save_original_img=True),
dict(
type='LoadImageFromFile', ## Load image from file
key='trimap', ## Key of image to load. The pipeline will read trimap from path `trimap_path`.
flag='grayscale', ## Load as grayscale image which has shape (height, width).
save_original_img=True), ## Save a copy of trimap for calculating metrics. It will be saved with key `ori_trimap`
dict(
type='LoadImageFromFile', ## Load image from file
key='merged'), ## Key of image to load. The pipeline will read merged from path `merged_path`.
dict(
type='Pad', ## Pipeline to pad images to align with the downsample factor of the model.
keys=['trimap', 'merged'], ## Images to be padded.
mode='reflect'), ## Mode of the padding.
dict(
type='RescaleToZeroOne', ## Same as in train_pipeline.
keys=['merged', 'ori_alpha']), ## Images to be rescaled.
dict(
type='Normalize', ## Same as in train_pipeline.
keys=['merged'],
**img_norm_cfg),
dict(
type='Collect', ## Same as in train_pipeline.
keys=['merged', 'trimap'],
meta_keys=[
'merged_path', 'pad', 'merged_ori_shape', 'ori_alpha',
'ori_trimap'
]),
dict(
type='ImageToTensor', ## Same as in train_pipeline.
keys=['merged', 'trimap']),
]
data = dict(
samples_per_gpu=1, ## Batch size of a single GPU.
workers_per_gpu=4, ## Worker to pre-fetch data for each single GPU.
drop_last=True, ## Use drop_last in data_loader.
train=dict( ## Train dataset config.
type=dataset_type, ## Type of dataset.
ann_file=f'{data_root}/training_list.json', ## Path of annotation file
data_prefix=data_root, ## Prefix of image path.
pipeline=train_pipeline), ## See above for train_pipeline
val=dict( ## Validation dataset config.
type=dataset_type,
ann_file=f'{data_root}/test_list.json',
data_prefix=data_root,
pipeline=test_pipeline), ## See above for test_pipeline
test=dict( ## Test dataset config.
type=dataset_type,
ann_file=f'{data_root}/test_list.json',
data_prefix=data_root,
pipeline=test_pipeline)) ## See above for test_pipeline
## optimizer
optimizers = dict(type='Adam', lr=0.00001) ## Config used to build optimizer, support all the optimizers in PyTorch whose arguments are also the same as those in PyTorch.
## learning policy
lr_config = dict( ## Learning rate scheduler config used to register LrUpdater hook
policy='Fixed') ## The policy of scheduler, also support CosineAnnealing, Cyclic, etc. Refer to details of supported LrUpdater from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py#L9.
## checkpoint saving
checkpoint_config = dict( ## Config to set the checkpoint hook, Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for implementation.
interval=40000, ## The save interval is 40000 iterations.
by_epoch=False) ## Count by iterations.
evaluation = dict( ## The config to build the evaluation hook.
interval=40000) ## Evaluation interval.
log_config = dict( ## Config to register logger hook.
interval=10, ## Interval to print the log.
hooks=[
dict(type='TextLoggerHook', by_epoch=False), ## The logger used to record the training process.
## dict(type='TensorboardLoggerHook') # The Tensorboard logger is also supported.
])
## runtime settings
total_iters = 1000000 ## Total iterations to train the model.
dist_params = dict(backend='nccl') ## Parameters to setup distributed training, the port can also be set.
log_level = 'INFO' ## The level of logging.
work_dir = './work_dirs/dim_stage1' ## Directory to save the model checkpoints and logs for the current experiments.
load_from = None ## load models as a pre-trained model from a given path. This will not resume training.
resume_from = None ## Resume checkpoints from a given path, the training will be resumed from the iteration at which the checkpoint was saved.
workflow = [('train', 1)] ## Workflow for runner. [('train', 1)] means there is only one workflow and the workflow named 'train' is executed once. Keep this unchanged when training current matting models.
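The two losses in the DIM config compare the predicted alpha matte with the ground truth, either directly (loss_alpha) or after compositing an image with it (loss_comp). The per-pixel sketch below assumes the standard compositing equation merged = alpha * fg + (1 - alpha) * bg and a Charbonnier penalty of the form sqrt(diff^2 + eps). The eps value, the sample pixel values and the function names are assumptions made for illustration, not taken from MMEditing.

```python
import math


def charbonnier(pred, target, eps=1e-12):
    """Charbonnier penalty for one value; eps is an assumed constant."""
    return math.sqrt((pred - target) ** 2 + eps)


def composite(alpha, fg, bg):
    """Blend foreground and background with an alpha value in [0, 1]."""
    return alpha * fg + (1.0 - alpha) * bg


# Made-up single-pixel values, for illustration only.
pred_alpha, gt_alpha = 0.75, 0.5
fg, bg = 1.0, 0.0

# loss_weight=0.5 for both losses, as in the config above.
loss_alpha = 0.5 * charbonnier(pred_alpha, gt_alpha)
loss_comp = 0.5 * charbonnier(composite(pred_alpha, fg, bg),
                              composite(gt_alpha, fg, bg))
print(loss_alpha, loss_comp)
```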
Config System for Restoration
An Example - EDSR
To give users a basic idea of a complete config, we briefly comment on the config of the EDSR model we implemented below. For more detailed usage and the alternatives for each module, please refer to the API documentation.
exp_name = 'edsr_x2c64b16_1x16_300k_div2k' ## The experiment name
scale = 2 ## Scale factor for upsampling
## model settings
model = dict(
type='BasicRestorer', ## Name of the model
generator=dict( ## Config of the generator
type='EDSR', ## Type of the generator
in_channels=3, ## Channel number of inputs
out_channels=3, ## Channel number of outputs
mid_channels=64, ## Channel number of intermediate features
num_blocks=16, ## Block number in the trunk network
upscale_factor=scale, ## Upsampling factor
res_scale=1, ## Used to scale the residual in residual block
rgb_mean=(0.4488, 0.4371, 0.4040), ## Image mean in RGB orders
rgb_std=(1.0, 1.0, 1.0)), ## Image std in RGB orders
pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean')) ## Config for pixel loss
## model training and testing settings
train_cfg = None ## Training config
test_cfg = dict( ## Test config
metrics=['PSNR'], ## Metrics used during testing
crop_border=scale) ## Crop border during evaluation
## dataset settings
train_dataset_type = 'SRAnnotationDataset' ## Dataset type for training
val_dataset_type = 'SRFolderDataset' ## Dataset type for validation
train_pipeline = [ ## Training data processing pipeline
dict(type='LoadImageFromFile', ## Load images from files
io_backend='disk', ## io backend
key='lq', ## Keys in results to find corresponding path
flag='unchanged'), ## flag for reading images
dict(type='LoadImageFromFile', ## Load images from files
io_backend='disk', ## io backend
key='gt', ## Keys in results to find corresponding path
flag='unchanged'), ## flag for reading images
dict(type='RescaleToZeroOne', keys=['lq', 'gt']), ## Rescale images from [0, 255] to [0, 1]
dict(type='Normalize', ## Augmentation pipeline that normalizes the input images
keys=['lq', 'gt'], ## Images to be normalized
mean=[0, 0, 0], ## Mean values
std=[1, 1, 1], ## Standard deviation
to_rgb=True), ## Change to RGB channel
dict(type='PairedRandomCrop', gt_patch_size=96), ## Paired random crop
dict(type='Flip', ## Flip images
keys=['lq', 'gt'], ## Images to be flipped
flip_ratio=0.5, ## Flip ratio
direction='horizontal'), ## Flip direction
dict(type='Flip', ## Flip images
keys=['lq', 'gt'], ## Images to be flipped
flip_ratio=0.5, ## Flip ratio
direction='vertical'), ## Flip direction
dict(type='RandomTransposeHW', ## Random transpose h and w for images
keys=['lq', 'gt'], ## Images to be transposed
transpose_ratio=0.5 ## Transpose ratio
),
dict(type='Collect', ## Pipeline that decides which keys in the data should be passed to the model
keys=['lq', 'gt'], ## Keys to pass to the model
meta_keys=['lq_path', 'gt_path']), ## Meta information keys. In training, meta information is not needed
dict(type='ImageToTensor', ## Convert images to tensor
keys=['lq', 'gt']) ## Images to be converted to Tensor
]
test_pipeline = [ ## Test pipeline
dict(
type='LoadImageFromFile', ## Load images from files
io_backend='disk', ## io backend
key='lq', ## Keys in results to find corresponding path
flag='unchanged'), ## flag for reading images
dict(
type='LoadImageFromFile', ## Load images from files
io_backend='disk', ## io backend
key='gt', ## Keys in results to find corresponding path
flag='unchanged'), ## flag for reading images
dict(type='RescaleToZeroOne', keys=['lq', 'gt']), ## Rescale images from [0, 255] to [0, 1]
dict(
type='Normalize', ## Augmentation pipeline that normalizes the input images
keys=['lq', 'gt'], ## Images to be normalized
mean=[0, 0, 0], ## Mean values
std=[1, 1, 1], ## Standard deviation
to_rgb=True), ## Change to RGB channel
dict(type='Collect', ## Pipeline that decides which keys in the data should be passed to the model
keys=['lq', 'gt'], ## Keys to pass to the model
meta_keys=['lq_path', 'gt_path']), ## Meta information keys
dict(type='ImageToTensor', ## Convert images to tensor
keys=['lq', 'gt']) ## Images to be converted to Tensor
]
data = dict(
## train
samples_per_gpu=16, ## Batch size of a single GPU
workers_per_gpu=6, ## Worker to pre-fetch data for each single GPU
drop_last=True, ## Use drop_last in data_loader
train=dict( ## Train dataset config
type='RepeatDataset', ## Repeated dataset for iter-based model
times=1000, ## Repeated times for RepeatDataset
dataset=dict(
type=train_dataset_type, ## Type of dataset
lq_folder='data/DIV2K/DIV2K_train_LR_bicubic/X2_sub', ## Path for lq folder
gt_folder='data/DIV2K/DIV2K_train_HR_sub', ## Path for gt folder
ann_file='data/DIV2K/meta_info_DIV2K800sub_GT.txt', ## Path for annotation file
pipeline=train_pipeline, ## See above for train_pipeline
scale=scale)), ## Scale factor for upsampling
## val
val_samples_per_gpu=1, ## Batch size of a single GPU for validation
val_workers_per_gpu=1, ## Worker to pre-fetch data for each single GPU for validation
val=dict(
type=val_dataset_type, ## Type of dataset
lq_folder='data/val_set5/Set5_bicLRx2', ## Path for lq folder
gt_folder='data/val_set5/Set5_mod12', ## Path for gt folder
pipeline=test_pipeline, ## See above for test_pipeline
scale=scale, ## Scale factor for upsampling
filename_tmpl='{}'), ## filename template
## test
test=dict(
type=val_dataset_type, ## Type of dataset
lq_folder='data/val_set5/Set5_bicLRx2', ## Path for lq folder
gt_folder='data/val_set5/Set5_mod12', ## Path for gt folder
pipeline=test_pipeline, ## See above for test_pipeline
scale=scale, ## Scale factor for upsampling
filename_tmpl='{}')) ## filename template
## optimizer
optimizers = dict(generator=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999))) ## Config used to build optimizer, support all the optimizers in PyTorch whose arguments are also the same as those in PyTorch
## learning policy
total_iters = 300000 ## Total training iters
lr_config = dict( ## Learning rate scheduler config used to register LrUpdater hook
policy='Step', by_epoch=False, step=[200000], gamma=0.5) ## The policy of scheduler, also support CosineAnnealing, Cyclic, etc. Refer to details of supported LrUpdater from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py#L9.
checkpoint_config = dict( ## Config to set the checkpoint hook, Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for implementation.
interval=5000, ## The save interval is 5000 iterations
save_optimizer=True, ## Also save optimizers
by_epoch=False) ## Count by iterations
evaluation = dict( ## The config to build the evaluation hook
interval=5000, ## Evaluation interval
save_image=True, ## Save images during evaluation
gpu_collect=True) ## Use gpu collect
log_config = dict( ## Config to register logger hook
interval=100, ## Interval to print the log
hooks=[
dict(type='TextLoggerHook', by_epoch=False), ## The logger used to record the training process
dict(type='TensorboardLoggerHook'), ## The Tensorboard logger is also supported
])
visual_config = None ## Visual config, we do not use it.
## runtime settings
dist_params = dict(backend='nccl') ## Parameters to setup distributed training, the port can also be set
log_level = 'INFO' ## The level of logging
work_dir = f'./work_dirs/{exp_name}' ## Directory to save the model checkpoints and logs for the current experiments
load_from = None ## load models as a pre-trained model from a given path. This will not resume training
resume_from = None ## Resume checkpoints from a given path, the training will be resumed from the iteration at which the checkpoint was saved
workflow = [('train', 1)] ## Workflow for runner. [('train', 1)] means there is only one workflow and the workflow named 'train' is executed once. Keep this unchanged when training current restoration models
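With policy='Step', step=[200000] and gamma=0.5, the learning rate is expected to be halved once during the 300k iterations. The sketch below computes the rate at a given iteration, assuming that each milestone multiplies the base rate by gamma from that iteration onward. step_lr is a hypothetical helper written for this example, not the mmcv hook.

```python
from bisect import bisect_right


def step_lr(cur_iter, base_lr=1e-4, steps=(200000,), gamma=0.5):
    """Learning rate at `cur_iter` for a step schedule.

    Assumption: a milestone takes effect at the iteration equal to it.
    """
    passed = bisect_right(steps, cur_iter)  # number of milestones reached
    return base_lr * gamma ** passed


for it in (0, 199999, 200000, 299999):
    print(it, step_lr(it))
```

Adding more milestones to steps, for example (100000, 200000), would halve the rate once per milestone under the same assumption.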