import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import functools
from . import base_function
from .stylegan_ops import style_function
from .transformer_ops import transformer_function


##################################################################################
# Networks
##################################################################################
def define_D(opt, img_size):
    """Create a discriminator.

    Parameters:
        opt      -- option object; reads netD, norm, img_nc, ndf, n_layers_D, attn_D,
                    init_type and init_gain
        img_size -- input image size, only used by the style discriminator

    Returns the network passed through base_function.init_net, with
    initialize_weights=False when netD names a 'style' model.
    Raises NotImplementedError if netD contains neither 'patch' nor 'style'.
    """
    norm_value = base_function.get_norm_layer(opt.norm)
    if 'patch' in opt.netD:
        net = NLayerDiscriminator(opt.img_nc, opt.ndf, opt.n_layers_D, norm_value, use_attn=opt.attn_D)
    elif 'style' in opt.netD:
        # Pass img_nc so the style discriminator accepts the same inputs as the patch one.
        net = StyleDiscriminator(img_size, ndf=opt.ndf, use_attn=opt.attn_D, input_nc=opt.img_nc)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % opt.netD)

    return base_function.init_net(net, opt.init_type, opt.init_gain, initialize_weights=('style' not in opt.netD))


def define_G(opt):
    """Create a decoder (image generator).

    The first substring of opt.netG that matches selects the architecture, checked in
    this order: 'diff' -> base_function.DiffDecoder, 'linear' -> base_function.LinearDecoder,
    'refine' -> RefinedGenerator.

    Returns the network passed through base_function.init_net.
    Raises NotImplementedError for an unrecognized opt.netG.
    """
    if 'diff' in opt.netG:
        net = base_function.DiffDecoder(opt.img_nc, opt.ngf, opt.kernel_G, opt.embed_dim, opt.n_layers_G,
                                        opt.num_res_blocks, word_size=opt.word_size, activation=opt.activation,
                                        norm=opt.norm, add_noise=opt.add_noise, use_attn=opt.attn_G,
                                        use_pos=opt.use_pos_G)
    elif 'linear' in opt.netG:
        net = base_function.LinearDecoder(opt.img_nc, opt.ngf, opt.kernel_G, opt.embed_dim, opt.activation, opt.norm)
    elif 'refine' in opt.netG:
        net = RefinedGenerator(opt.img_nc, opt.ngf, opt.embed_dim, opt.down_layers, opt.mid_layers,
                               opt.num_res_blocks, activation=opt.activation, norm=opt.norm)
    else:
        raise NotImplementedError('Decoder model name [%s] is not recognized' % opt.netG)

    return base_function.init_net(net, opt.init_type, opt.init_gain, initialize_weights=('style' not in opt.netG))


def define_E(opt):
    """Create an encoder.

    'diff' in opt.netE selects base_function.DiffEncoder, 'linear' selects
    base_function.LinearEncoder.

    Returns the network passed through base_function.init_net.
    Raises NotImplementedError for an unrecognized opt.netE.
    """
    if 'diff' in opt.netE:
        net = base_function.DiffEncoder(opt.img_nc, opt.ngf, opt.kernel_E, opt.embed_dim, opt.n_layers_G,
                                        opt.num_res_blocks, activation=opt.activation, norm=opt.norm,
                                        use_attn=opt.attn_E)
    elif 'linear' in opt.netE:
        net = base_function.LinearEncoder(opt.img_nc, opt.kernel_E, opt.embed_dim)
    else:
        raise NotImplementedError('Encoder model name [%s] is not recognized' % opt.netE)

    return base_function.init_net(net, opt.init_type, opt.init_gain, initialize_weights=('style' not in opt.netE))


def define_T(opt):
    """Create a transformer.

    The transformer's input/output feature width is ngf * 2 ** n_layers_G.
    NOTE(review): unlike define_D/G/E, the network is returned WITHOUT
    base_function.init_net; presumably intentional — confirm.
    Raises NotImplementedError unless opt.netT contains 'original'.
    """
    if "original" in opt.netT:
        e_d_f = int(opt.ngf * (2 ** opt.n_layers_G))
        net = transformer_function.Transformer(e_d_f, opt.embed_dim, e_d_f, kernel=opt.kernel_T,
                                               n_encoders=opt.n_encoders, n_decoders=opt.n_decoders,
                                               embed_type=opt.embed_type)
    else:
        raise NotImplementedError('Transformer model name [%s] is not recognized' % opt.netT)

    return net


##################################################################################
# Discriminator
##################################################################################
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_attn=False):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the FIRST conv layer; later layers use
                               ndf * min(2 ** n, 8)
            n_layers (int)  -- the number of stride-2 conv layers in the discriminator
            norm_layer      -- normalization layer class, or a functools.partial wrapping one
            use_attn (bool) -- if True (and n_layers >= 3), insert a 1x1 conv followed by
                               base_function.AttnAware after the second intermediate conv layer
        """
        super(NLayerDiscriminator, self).__init__()
        # No need to use bias before BatchNorm2d, which has affine parameters;
        # InstanceNorm2d is non-affine by default, so the conv keeps its bias.
        base_norm = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = base_norm == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
            if n == 2 and use_attn:
                sequence += [
                    nn.Conv2d(ndf * nf_mult, ndf * nf_mult, kernel_size=1, stride=1, bias=use_bias),
                    base_function.AttnAware(ndf * nf_mult)
                ]

        # One more conv at stride 1 before the prediction head.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)


class StyleDiscriminator(nn.Module):
    """Discriminator built from style_function blocks with a minibatch-stddev feature.

    Parameters:
        img_size (int)    -- input resolution; must be a key of the channel table
                             (4, 8, ..., 1024), otherwise a KeyError is raised
        ndf (int)         -- width control; channels scale with ndf / 64
        blur_kernel       -- kernel passed to every style_function.StyleBlock
                             (defaults to [1, 3, 3, 1])
        use_attn (bool)   -- if True, insert base_function.AttnAware before the StyleBlock
                             at loop index i == log2(img_size) - 3
        input_nc (int)    -- number of input image channels (defaults to 3)
    """

    def __init__(self, img_size, ndf=32, blur_kernel=None, use_attn=False, input_nc=3):
        super(StyleDiscriminator, self).__init__()
        # Build a fresh list per instance instead of sharing a mutable default.
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]
        channel_multiplier = ndf / 64
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: int(512 * channel_multiplier),
            64: int(256 * channel_multiplier),
            128: int(128 * channel_multiplier),
            256: int(64 * channel_multiplier),
            512: int(32 * channel_multiplier),
            1024: int(16 * channel_multiplier),
        }

        convs = [style_function.ConvLayer(input_nc, channels[img_size], 1)]

        log_size = int(np.log2(img_size))

        in_channel = channels[img_size]

        # Channel widths are looked up at 2**(i-1), which suggests each StyleBlock
        # halves the resolution — confirm in style_function.
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            if i == log_size - 3 and use_attn:
                convs.append(base_function.AttnAware(in_channel))
            convs.append(style_function.StyleBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        # +1 input channel for the stddev feature concatenated in forward().
        self.final_conv = style_function.ConvLayer(in_channel + 1, channels[4], 3)
        # Expects a 4x4 map from final_conv once flattened.
        self.final_linear = nn.Sequential(
            style_function.EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
            style_function.EqualLinear(channels[4], 1),
        )

    def forward(self, x):
        """Return one score per sample, shape (batch, 1)."""
        out = self.convs(x)

        b, c, h, w = out.shape
        # The view below splits the batch into `group` rows, so the group size must
        # divide the batch. Shrink it until it does (group 1 always divides); this is
        # identical to the plain min() whenever the batch was already divisible.
        group = min(b, self.stddev_group)
        while b % group != 0:
            group -= 1
        stddev = out.view(group, -1, self.stddev_feat, c // self.stddev_feat, h, w)
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
        stddev = stddev.repeat(group, 1, h, w)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)

        out = out.view(b, -1)
        out = self.final_linear(out)

        return out


##################################################################################
# Generator
##################################################################################
class RefinedGenerator(nn.Module):
    """Encoder/decoder refinement generator.

    Layout:
      - a PartialConv2d stem;
      - `down_layers` stages of DownSample + ResnetBlocks (channels double, capped at embed_dim);
      - `mid_layers` ResnetBlocks;
      - `down_layers` up stages of ResnetBlocks + ToRGB skip output + UpSample;
      - a final ToRGB.

    When down_layers >= 3, an AttnAware layer runs in up stage (down_layers - 3). That
    stage mirrors down stage 2, whose downsampled features are passed to it as `pre`.
    """

    def __init__(self, input_nc, ngf=64, embed_dim=512, down_layers=3, mid_layers=6, num_res_blocks=1, dropout=0.0,
                 rample_with_conv=True, activation='gelu', norm='pixel'):
        super(RefinedGenerator, self).__init__()

        self.down_layers = down_layers
        self.mid_layers = mid_layers
        self.num_res_blocks = num_res_blocks
        out_dims = []
        # start
        self.encode = base_function.PartialConv2d(input_nc, ngf, kernel_size=3, stride=1, padding=1)
        # down
        self.down = nn.ModuleList()
        out_dim = ngf
        for i in range(self.down_layers):
            block = nn.ModuleList()
            down = nn.Module()
            in_dim = out_dim
            out_dims.append(out_dim)
            out_dim = min(int(in_dim * 2), embed_dim)
            down.downsample = base_function.DownSample(in_dim, rample_with_conv, kernel_size=3)
            for i_block in range(self.num_res_blocks):
                block.append(base_function.ResnetBlock(in_dim, out_dim, 3, dropout, activation, norm))
                in_dim = out_dim
            down.block = block
            self.down.append(down)
        # middle
        self.mid = nn.ModuleList()
        for i in range(self.mid_layers):
            self.mid.append(base_function.ResnetBlock(out_dim, out_dim, 3, dropout, activation, norm))
        # up: mirror the down path, walking the recorded widths backwards
        self.up = nn.ModuleList()
        for i in range(self.down_layers):
            block = nn.ModuleList()
            up = nn.Module()
            in_dim = out_dim
            out_dim = max(out_dims[-i - 1], ngf)
            for i_block in range(self.num_res_blocks):
                block.append(base_function.ResnetBlock(in_dim, out_dim, 3, dropout, activation, norm))
                in_dim = out_dim
            if i == self.down_layers - 3:
                up.attn = base_function.AttnAware(out_dim, activation, norm)
            up.block = block
            # The first up stage has no earlier RGB skip to upsample.
            upsample = True if i != 0 else False
            up.out = base_function.ToRGB(out_dim, input_nc, upsample, activation, norm)
            up.upsample = base_function.UpSample(out_dim, rample_with_conv, kernel_size=3)
            self.up.append(up)
        # end
        self.decode = base_function.ToRGB(out_dim, input_nc, True, activation, norm)

    def forward(self, x, mask=None):
        """Run the generator.

        Parameters:
            x    -- input tensor fed to the PartialConv2d stem
            mask -- optional mask. When given, it is resized bilinearly to the attention
                    stage's spatial size and passed to AttnAware. It is only used when
                    down_layers >= 3.

        Returns the output of the final ToRGB, which accumulates the per-stage RGB skips.
        """
        # start
        x = self.encode(x)
        pre = None
        # down
        for i in range(self.down_layers):
            x = self.down[i].downsample(x)
            # Down stage 2 mirrors up stage (down_layers - 3), where attention runs.
            if i == 2:
                pre = x
            for i_block in range(self.num_res_blocks):
                x = self.down[i].block[i_block](x)

        # middle
        for i in range(self.mid_layers):
            x = self.mid[i](x)

        # up
        skip = None
        for i in range(self.down_layers):
            for i_block in range(self.num_res_blocks):
                x = self.up[i].block[i_block](x)
            if i == self.down_layers - 3:
                mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear', align_corners=True) if mask is not None else None
                x = self.up[i].attn(x, pre=pre, mask=mask)
            skip = self.up[i].out(x, skip)
            x = self.up[i].upsample(x)

        # end
        x = self.decode(x, skip)

        return x