Source code for mmedit.models.components.stylegan2.generator_discriminator

# Copyright (c) OpenMMLab. All rights reserved.
import random

import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner.checkpoint import _load_checkpoint_with_prefix

from mmedit.models.registry import COMPONENTS
from .common import get_mean_latent, get_module_device, style_mixing
from .modules import (ConstantInput, ConvDownLayer, EqualLinearActModule,
                      ModMBStddevLayer, ModulatedStyleConv, ModulatedToRGB,
                      PixelNorm, ResBlock)


@COMPONENTS.register_module()
class StyleGANv2Generator(nn.Module):
    r"""StyleGAN2 Generator.

    This module comes from MMGeneration. In the future, this code will be
    removed and StyleGANv2 will be directly imported from mmgeneration.

    In StyleGAN2, we use a static architecture composed of a style mapping
    module and a number of convolutional style blocks. More details can be
    found in: Analyzing and Improving the Image Quality of StyleGAN, CVPR 2020.

    You can load a pretrained model by passing information to the
    ``pretrained`` argument. We have already offered official weights as
    follows:

    - stylegan2-ffhq-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth  # noqa
    - stylegan2-horse-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth  # noqa
    - stylegan2-car-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth  # noqa
    - stylegan2-cat-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth  # noqa
    - stylegan2-church-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth  # noqa

    If you want to load the EMA model, you can just use the following code:

    .. code-block:: python

        # ckpt_http is one of the valid paths from the http source
        generator = StyleGANv2Generator(1024, 512,
                                        pretrained=dict(
                                            ckpt_path=ckpt_http,
                                            prefix='generator_ema'))

    Of course, you can also download the checkpoint in advance and set
    ``ckpt_path`` to the local path. If you just want to load the original
    generator (not the EMA model), please set the prefix to 'generator'.

    Note that our implementation allows generating BGR images, while the
    original StyleGAN2 outputs RGB images by default. Thus, we provide the
    ``bgr2rgb`` argument to convert the image space.

    Args:
        out_size (int): The output size of the StyleGAN2 generator.
        style_channels (int): The number of channels for the style code.
        num_mlps (int, optional): The number of MLP layers. Defaults to 8.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurring kernel. Defaults to
            [1, 3, 3, 1].
        lr_mlp (float, optional): The learning rate for the style mapping
            layer. Defaults to 0.01.
        default_style_mode (str, optional): The default mode of style mixing.
            In training, we adopt the mixing style mode by default, while in
            evaluation we use the 'single' style mode. ``['mix', 'single']``
            are currently supported. Defaults to 'mix'.
        eval_style_mode (str, optional): The evaluation mode of style mixing.
            Defaults to 'single'.
        mix_prob (float, optional): Mixing probability. The value should be
            in the range of [0, 1]. Defaults to 0.9.
        pretrained (dict | None, optional): Information for pretrained models.
            The necessary key is 'ckpt_path'. Besides, you can also provide
            'prefix' to load the generator part from the whole state dict.
            Defaults to None.
        bgr2rgb (bool, optional): Whether to flip the image channel dimension.
            Defaults to False.
    """

    def __init__(self,
                 out_size,
                 style_channels,
                 num_mlps=8,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 lr_mlp=0.01,
                 default_style_mode='mix',
                 eval_style_mode='single',
                 mix_prob=0.9,
                 pretrained=None,
                 bgr2rgb=False):
        super(StyleGANv2Generator, self).__init__()
        self.out_size = out_size
        self.style_channels = style_channels
        self.num_mlps = num_mlps
        self.channel_multiplier = channel_multiplier
        self.lr_mlp = lr_mlp
        self._default_style_mode = default_style_mode
        self.default_style_mode = default_style_mode
        self.eval_style_mode = eval_style_mode
        self.mix_prob = mix_prob
        self.bgr2rgb = bgr2rgb

        # define style mapping layers
        mapping_layers = [PixelNorm()]

        for _ in range(num_mlps):
            mapping_layers.append(
                EqualLinearActModule(
                    style_channels,
                    style_channels,
                    equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
                    act_cfg=dict(type='fused_bias')))

        self.style_mapping = nn.Sequential(*mapping_layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        # constant input layer
        self.constant_input = ConstantInput(self.channels[4])
        # 4x4 stage
        self.conv1 = ModulatedStyleConv(
            self.channels[4],
            self.channels[4],
            kernel_size=3,
            style_channels=style_channels,
            blur_kernel=blur_kernel)
        self.to_rgb1 = ModulatedToRGB(
            self.channels[4], style_channels, upsample=False)

        # generator backbone (8x8 --> higher resolutions)
        self.log_size = int(np.log2(self.out_size))

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()

        in_channels_ = self.channels[4]

        for i in range(3, self.log_size + 1):
            out_channels_ = self.channels[2**i]

            self.convs.append(
                ModulatedStyleConv(
                    in_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    upsample=True,
                    blur_kernel=blur_kernel))
            self.convs.append(
                ModulatedStyleConv(
                    out_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    upsample=False,
                    blur_kernel=blur_kernel))
            self.to_rgbs.append(
                ModulatedToRGB(out_channels_, style_channels, upsample=True))

            in_channels_ = out_channels_

        self.num_latents = self.log_size * 2 - 2
        self.num_injected_noises = self.num_latents - 1

        # register buffers for injected noises
        for layer_idx in range(self.num_injected_noises):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.register_buffer(f'injected_noise_{layer_idx}',
                                 torch.randn(*shape))

        if pretrained is not None:
            self._load_pretrained_model(**pretrained)

    def _load_pretrained_model(self,
                               ckpt_path,
                               prefix='',
                               map_location='cpu',
                               strict=True):
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                  map_location)
        self.load_state_dict(state_dict, strict=strict)
        mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmedit')
    def train(self, mode=True):
        """Set train/eval mode and switch the style mixing mode accordingly."""
        if mode:
            if self.default_style_mode != self._default_style_mode:
                mmcv.print_log(
                    f'Switch to train style mode: {self._default_style_mode}',
                    'mmgen')
            self.default_style_mode = self._default_style_mode

        else:
            if self.default_style_mode != self.eval_style_mode:
                mmcv.print_log(
                    f'Switch to evaluation style mode: {self.eval_style_mode}',
                    'mmgen')
            self.default_style_mode = self.eval_style_mode

        return super(StyleGANv2Generator, self).train(mode)
    def make_injected_noise(self):
        device = get_module_device(self)

        noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_mean_latent(self, num_samples=4096, **kwargs):
        return get_mean_latent(self, num_samples, **kwargs)

    def style_mixing(self,
                     n_source,
                     n_target,
                     inject_index=1,
                     truncation_latent=None,
                     truncation=0.7):
        return style_mixing(
            self,
            n_source=n_source,
            n_target=n_target,
            inject_index=inject_index,
            truncation=truncation,
            truncation_latent=truncation_latent,
            style_channels=self.style_channels)
    def forward(self,
                styles,
                num_batches=-1,
                return_noise=False,
                return_latents=False,
                inject_index=None,
                truncation=1,
                truncation_latent=None,
                input_is_latent=False,
                injected_noise=None,
                randomize_noise=True):
        """Forward function.

        This function has been integrated with the truncation trick. Please
        refer to the usage of ``truncation`` and ``truncation_latent``.

        Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN2, you can provide a noise tensor or a latent tensor.
                Given a list containing more than one noise or latent tensor,
                the style-mixing trick will be used in training. You can also
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, ``None`` indicates that the default noise sampler
                will be used.
            num_batches (int, optional): The batch size of sampled noise.
                Defaults to -1.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            inject_index (int | None, optional): The index number for mixing
                style codes. Defaults to None.
            truncation (float, optional): Truncation factor. If a value less
                than 1. is given, the truncation trick will be adopted.
                Defaults to 1.
            truncation_latent (torch.Tensor, optional): Mean truncation
                latent. Defaults to None.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            injected_noise (torch.Tensor | None, optional): Given a tensor,
                the random noise will be fixed as this input injected noise.
                Defaults to None.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensors injected into the style conv
                blocks. Defaults to True.

        Returns:
            torch.Tensor | dict: Generated image tensor or dictionary
                containing more data.
        """
        # receive noise and conduct sanity check.
        if isinstance(styles, torch.Tensor):
            assert styles.shape[1] == self.style_channels
            styles = [styles]
        elif mmcv.is_seq_of(styles, torch.Tensor):
            for t in styles:
                assert t.shape[-1] == self.style_channels
        # receive a noise generator and sample noise.
        elif callable(styles):
            device = get_module_device(self)
            noise_generator = styles
            assert num_batches > 0
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    noise_generator((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [noise_generator((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]
        # otherwise, adopt the default noise sampler.
        else:
            device = get_module_device(self)
            assert num_batches > 0 and not input_is_latent
            if self.default_style_mode == 'mix' and random.random(
            ) < self.mix_prob:
                styles = [
                    torch.randn((num_batches, self.style_channels))
                    for _ in range(2)
                ]
            else:
                styles = [torch.randn((num_batches, self.style_channels))]
            styles = [s.to(device) for s in styles]

        if not input_is_latent:
            noise_batch = styles
            styles = [self.style_mapping(s) for s in styles]
        else:
            noise_batch = None

        if injected_noise is None:
            if randomize_noise:
                injected_noise = [None] * self.num_injected_noises
            else:
                injected_noise = [
                    getattr(self, f'injected_noise_{i}')
                    for i in range(self.num_injected_noises)
                ]
        # use truncation trick
        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(truncation_latent + truncation *
                               (style - truncation_latent))

            styles = style_t

        # no style mixing
        if len(styles) < 2:
            inject_index = self.num_latents

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]
        # style mixing
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.num_latents - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(
                1, self.num_latents - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        # 4x4 stage
        out = self.constant_input(latent)
        out = self.conv1(out, latent[:, 0], noise=injected_noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        _index = 1

        # 8x8 ---> higher resolutions
        for up_conv, conv, noise1, noise2, to_rgb in zip(
                self.convs[::2], self.convs[1::2], injected_noise[1::2],
                injected_noise[2::2], self.to_rgbs):
            out = up_conv(out, latent[:, _index], noise=noise1)
            out = conv(out, latent[:, _index + 1], noise=noise2)
            skip = to_rgb(out, latent[:, _index + 2], skip)
            _index += 2

        img = skip

        if self.bgr2rgb:
            # ``torch.flip`` expects a list/tuple of dims; flip the channel
            # dimension to convert between BGR and RGB.
            img = torch.flip(img, dims=[1])

        if return_latents or return_noise:
            output_dict = dict(
                fake_img=img,
                latent=latent,
                inject_index=inject_index,
                noise_batch=noise_batch)
            return output_dict

        return img
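

# A minimal usage sketch (not part of the original module). It assumes a
# hypothetical 256x256 output size with 512 style channels and samples a
# batch of images from randomly initialized weights; in practice you would
# pass ``pretrained=dict(ckpt_path=..., prefix='generator_ema')`` as shown
# in the class docstring above.
def _demo_generator_sampling():
    generator = StyleGANv2Generator(out_size=256, style_channels=512)
    generator.eval()
    with torch.no_grad():
        # ``styles=None`` triggers the default noise sampler, so
        # ``num_batches`` must be positive.
        fake_imgs = generator(None, num_batches=4)
    # (4, 3, 256, 256) image tensor in the generator's output color order
    return fake_imgs.shape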
@COMPONENTS.register_module()
class StyleGAN2Discriminator(nn.Module):
    """StyleGAN2 Discriminator.

    This module comes from MMGeneration. In the future, this code will be
    removed and StyleGANv2 will be directly imported from mmgeneration.

    The architecture of this discriminator is proposed in StyleGAN2. More
    details can be found in: Analyzing and Improving the Image Quality of
    StyleGAN, CVPR 2020.

    You can load a pretrained model by passing information to the
    ``pretrained`` argument. We have already offered official weights as
    follows:

    - stylegan2-ffhq-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth  # noqa
    - stylegan2-horse-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth  # noqa
    - stylegan2-car-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth  # noqa
    - stylegan2-cat-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth  # noqa
    - stylegan2-church-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth  # noqa

    To load the official discriminator weights, you can just use the
    following code:

    .. code-block:: python

        # ckpt_http is one of the valid paths from the http source
        discriminator = StyleGAN2Discriminator(1024, 512,
                                               pretrained=dict(
                                                   ckpt_path=ckpt_http,
                                                   prefix='discriminator'))

    Of course, you can also download the checkpoint in advance and set
    ``ckpt_path`` to the local path.

    Note that our implementation adopts BGR images as input, while the
    original StyleGAN2 provides RGB images to the discriminator. Thus, we
    provide the ``bgr2rgb`` argument to convert the image space. If your
    images follow the RGB order, please set it to ``True`` accordingly.

    Args:
        in_size (int): The input size of images.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurring kernel. Defaults to
            [1, 3, 3, 1].
        mbstd_cfg (dict, optional): Configs for the minibatch-stddev layer.
            Defaults to dict(group_size=4, channel_groups=1).
        pretrained (dict | None, optional): Information for pretrained
            models. The necessary key is 'ckpt_path'. Besides, you can also
            provide 'prefix' to load the discriminator part from the whole
            state dict. Defaults to None.
        bgr2rgb (bool, optional): Whether to flip the image channel dimension.
            Defaults to False.
    """

    def __init__(self,
                 in_size,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 mbstd_cfg=dict(group_size=4, channel_groups=1),
                 pretrained=None,
                 bgr2rgb=False):
        super(StyleGAN2Discriminator, self).__init__()
        self.bgr2rgb = bgr2rgb

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        log_size = int(np.log2(in_size))

        in_channels = channels[in_size]

        convs = [ConvDownLayer(3, channels[in_size], 1)]

        for i in range(log_size, 2, -1):
            out_channel = channels[2**(i - 1)]

            convs.append(ResBlock(in_channels, out_channel, blur_kernel))

            in_channels = out_channel

        self.convs = nn.Sequential(*convs)

        self.mbstd_layer = ModMBStddevLayer(**mbstd_cfg)

        self.final_conv = ConvDownLayer(in_channels + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinearActModule(
                channels[4] * 4 * 4,
                channels[4],
                act_cfg=dict(type='fused_bias')),
            EqualLinearActModule(channels[4], 1),
        )

        if pretrained is not None:
            self._load_pretrained_model(**pretrained)

    def _load_pretrained_model(self,
                               ckpt_path,
                               prefix='',
                               map_location='cpu',
                               strict=True):
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                  map_location)
        self.load_state_dict(state_dict, strict=strict)
        mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmedit')
    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input image tensor.

        Returns:
            torch.Tensor: Predicted score for the input image.
        """
        if self.bgr2rgb:
            # ``torch.flip`` expects a list/tuple of dims; flip the channel
            # dimension to convert between RGB and BGR.
            x = torch.flip(x, dims=[1])

        x = self.convs(x)

        x = self.mbstd_layer(x)
        x = self.final_conv(x)
        x = x.view(x.shape[0], -1)
        x = self.final_linear(x)

        return x
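

# A minimal usage sketch (not part of the original module). It assumes a
# hypothetical 256x256 input size and feeds a random image batch through the
# discriminator to obtain one realness score per image; with real data you
# would typically load official weights via
# ``pretrained=dict(ckpt_path=..., prefix='discriminator')``.
def _demo_discriminator_scoring():
    discriminator = StyleGAN2Discriminator(in_size=256)
    imgs = torch.randn(4, 3, 256, 256)  # stand-in for a batch of BGR images
    with torch.no_grad():
        scores = discriminator(imgs)
    # (4, 1) tensor of unnormalized realness scores
    return scores.shape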