import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import functools
from . import base_function
from .stylegan_ops import style_function
from .transformer_ops import transformer_function
# Networks
def define_D(opt, img_size):
    """Create a discriminator selected by ``opt.netD``.

    Args:
        opt: options namespace. Reads ``netD``, ``norm``, ``img_nc``, ``ndf``,
            ``n_layers_D``, ``attn_D``, ``init_type`` and ``init_gain``.
        img_size (int): image resolution, forwarded to ``StyleDiscriminator``.

    Returns:
        The network as returned by ``base_function.init_net``.

    Raises:
        NotImplementedError: if ``opt.netD`` contains neither 'patch' nor 'style'.
    """
    norm_value = base_function.get_norm_layer(opt.norm)
    # Substring match; 'patch' is tested first, so it wins if both appear.
    if 'patch' in opt.netD:
        net = NLayerDiscriminator(opt.img_nc, opt.ndf, opt.n_layers_D, norm_value, use_attn=opt.attn_D)
    elif 'style' in opt.netD:
        net = StyleDiscriminator(img_size, ndf=opt.ndf, use_attn=opt.attn_D)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % opt.netD)
    # initialize_weights is False for 'style' models; presumably they rely on
    # their own initialisation — confirm in base_function.init_net.
    return base_function.init_net(net, opt.init_type, opt.init_gain, initialize_weights=('style' not in opt.netD))
def define_G(opt):
    """Create a decoder/generator selected by ``opt.netG``.

    Args:
        opt: options namespace. Besides ``netG``, ``init_type`` and
            ``init_gain``, the fields read depend on the selected network
            (see each branch below).

    Returns:
        The network as returned by ``base_function.init_net``.

    Raises:
        NotImplementedError: if ``opt.netG`` contains none of 'diff',
            'linear' or 'refine'.
    """
    # Substring match, tested in this order.
    if 'diff' in opt.netG:
        net = base_function.DiffDecoder(opt.img_nc, opt.ngf, opt.kernel_G, opt.embed_dim, opt.n_layers_G,
                                        opt.num_res_blocks, word_size=opt.word_size, activation=opt.activation,
                                        norm=opt.norm, add_noise=opt.add_noise, use_attn=opt.attn_G,
                                        use_pos=opt.use_pos_G)
    elif 'linear' in opt.netG:
        net = base_function.LinearDecoder(opt.img_nc, opt.ngf, opt.kernel_G, opt.embed_dim, opt.activation, opt.norm)
    elif 'refine' in opt.netG:
        net = RefinedGenerator(opt.img_nc, opt.ngf, opt.embed_dim, opt.down_layers, opt.mid_layers,
                               opt.num_res_blocks, activation=opt.activation, norm=opt.norm)
    else:
        raise NotImplementedError('Decoder model name [%s] is not recognized' % opt.netG)
    return base_function.init_net(net, opt.init_type, opt.init_gain, initialize_weights=('style' not in opt.netG))
def define_E(opt):
    """Create an encoder selected by ``opt.netE``.

    Args:
        opt: options namespace. Besides ``netE``, ``init_type`` and
            ``init_gain``, the fields read depend on the selected network.
            Note the 'diff' encoder takes its depth from ``opt.n_layers_G``.

    Returns:
        The network as returned by ``base_function.init_net``.

    Raises:
        NotImplementedError: if ``opt.netE`` contains neither 'diff' nor 'linear'.
    """
    if 'diff' in opt.netE:
        net = base_function.DiffEncoder(opt.img_nc, opt.ngf, opt.kernel_E, opt.embed_dim, opt.n_layers_G,
                                        opt.num_res_blocks, activation=opt.activation, norm=opt.norm,
                                        use_attn=opt.attn_E)
    elif 'linear' in opt.netE:
        net = base_function.LinearEncoder(opt.img_nc, opt.kernel_E, opt.embed_dim)
    else:
        raise NotImplementedError('Encoder model name [%s] is not recognized' % opt.netE)
    return base_function.init_net(net, opt.init_type, opt.init_gain, initialize_weights=('style' not in opt.netE))
def define_T(opt):
    """Create a transformer selected by ``opt.netT``.

    Args:
        opt: options namespace. Reads ``netT``, ``ngf``, ``n_layers_G``,
            ``embed_dim``, ``kernel_T``, ``n_encoders``, ``n_decoders`` and
            ``embed_type``.

    Returns:
        The ``transformer_function.Transformer`` instance (not passed
        through ``init_net``).

    Raises:
        NotImplementedError: if ``opt.netT`` does not contain 'original'.
    """
    if "original" in opt.netT:
        # ngf * 2**n_layers_G: presumably the channel width of the encoder
        # output feature map — confirm against base_function.DiffEncoder.
        e_d_f = int(opt.ngf * (2 ** opt.n_layers_G))
        net = transformer_function.Transformer(e_d_f, opt.embed_dim, e_d_f, kernel=opt.kernel_T,
                                               n_encoders=opt.n_encoders, n_decoders=opt.n_decoders,
                                               embed_type=opt.embed_type)
    else:
        raise NotImplementedError('Transformer model name [%s] is not recognized' % opt.netT)
    return net
# Discriminator
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_attn=False):
        """Construct a PatchGAN discriminator.

        Parameters:
            input_nc (int)  -- the number of channels in input examples
            ndf (int)       -- the number of filters in the FIRST conv layer; later
                               layers use ndf * min(2**n, 8)
            n_layers (int)  -- the number of stride-2 conv layers (two stride-1
                               convs follow them)
            norm_layer      -- normalization layer class, or a functools.partial of one
            use_attn (bool) -- if True, append a 1x1 conv and base_function.AttnAware
                               after the third stride-2 conv (needs n_layers >= 3)
        """
        super(NLayerDiscriminator, self).__init__()
        # No need to use bias as BatchNorm2d has affine parameters; InstanceNorm2d
        # (affine=False by default) does need it. Unwrap partials to compare classes.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]
            if n == 2 and use_attn:
                sequence += [
                    nn.Conv2d(ndf * nf_mult, ndf * nf_mult, kernel_size=1, stride=1, bias=use_bias),
                    base_function.AttnAware(ndf * nf_mult),
                ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        # output 1 channel prediction map
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward: run the input through the conv stack."""
        return self.model(input)
class StyleDiscriminator(nn.Module):
    """Discriminator built from ``style_function`` blocks.

    The network ends in a minibatch standard-deviation feature map, a 3x3 conv
    and two ``EqualLinear`` layers that produce one score per sample.
    """

    def __init__(self, img_size, ndf=32, blur_kernel=[1, 3, 3, 1], use_attn=False):
        """
        Args:
            img_size (int): input resolution; must be a key of the channel table
                below (a power of two from 4 to 1024).
            ndf (int): width control. Channels at resolutions >= 32 are scaled by
                ndf / 64; resolutions 4-16 always use 512.
            blur_kernel (list[int]): forwarded unchanged to every StyleBlock
                (never mutated here).
            use_attn (bool): if True, insert ``base_function.AttnAware`` before
                the StyleBlock at loop index ``log2(img_size) - 3``.
        """
        super(StyleDiscriminator, self).__init__()

        channel_multiplier = ndf / 64
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: int(512 * channel_multiplier),
            64: int(256 * channel_multiplier),
            128: int(128 * channel_multiplier),
            256: int(64 * channel_multiplier),
            512: int(32 * channel_multiplier),
            1024: int(16 * channel_multiplier),
        }

        # NOTE: the first layer hard-codes 3 input channels.
        convs = [style_function.ConvLayer(3, channels[img_size], 1)]
        log_size = int(np.log2(img_size))
        in_channel = channels[img_size]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            if i == log_size - 3 and use_attn:
                # NOTE(review): the body of this branch was missing in the received
                # source; reconstructed to mirror NLayerDiscriminator's use_attn
                # (base_function.AttnAware on the current width) — confirm.
                convs.append(base_function.AttnAware(in_channel))
            convs.append(style_function.StyleBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel
        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        # +1 input channel for the stddev feature map concatenated in forward().
        self.final_conv = style_function.ConvLayer(in_channel + 1, channels[4], 3)
        # Assumes the conv stack ends at 4x4 spatial resolution (channels[4] * 4 * 4).
        self.final_linear = nn.Sequential(
            style_function.EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
            style_function.EqualLinear(channels[4], 1),
        )

    def forward(self, x):
        """Score a batch of images; returns a tensor of shape (b, 1)."""
        out = self.convs(x)

        # Minibatch standard deviation. The view below requires the batch size b
        # to be divisible by `group` (b < 4 or b a multiple of 4).
        b, c, h, w = out.shape
        group = min(b, self.stddev_group)
        stddev = out.view(group, -1, self.stddev_feat, c // self.stddev_feat, h, w)
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
        stddev = stddev.repeat(group, 1, h, w)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)
        out = out.view(b, -1)
        out = self.final_linear(out)
        return out
# Generator
class RefinedGenerator(nn.Module):
    """Encoder / residual bottleneck / decoder generator with one attention level.

    The down path stores the feature map right after the downsample of level
    ``_attn_level`` as ``pre``. The up level at the matching depth
    (``down_layers - 1 - _attn_level``) runs ``base_function.AttnAware`` with
    ``pre`` and the (resized) mask. Each up level also emits a ``ToRGB`` skip
    that is chained into the next one.
    """

    # Down level whose post-downsample features feed the attention layer.
    # The up path mirrors it at index down_layers - 1 - _attn_level.
    _attn_level = 2

    def __init__(self, input_nc, ngf=64, embed_dim=512, down_layers=3, mid_layers=6, num_res_blocks=1, dropout=0.0,
                 rample_with_conv=True, activation='gelu', norm='pixel'):
        """
        Args:
            input_nc (int): channels of the input image (and of the RGB outputs).
            ngf (int): channels after the first PartialConv2d.
            embed_dim (int): upper bound on channel width in the down path.
            down_layers (int): number of down (and up) levels. Attention is only
                built when down_layers > _attn_level.
            mid_layers (int): number of ResnetBlocks at the bottleneck.
            num_res_blocks (int): ResnetBlocks per down/up level.
            dropout (float): forwarded to every ResnetBlock.
            rample_with_conv (bool): forwarded to DownSample / UpSample.
            activation (str), norm (str): forwarded to the base_function blocks.
        """
        super(RefinedGenerator, self).__init__()
        self.down_layers = down_layers
        self.mid_layers = mid_layers
        self.num_res_blocks = num_res_blocks
        attn_up_index = self.down_layers - 1 - self._attn_level
        # out_dims[i] = channel width entering down level i; the up path
        # consumes it in reverse so widths mirror the down path.
        out_dims = []
        # start
        self.encode = base_function.PartialConv2d(input_nc, ngf, kernel_size=3, stride=1, padding=1)
        # down
        self.down = nn.ModuleList()
        out_dim = ngf
        for i in range(self.down_layers):
            block = nn.ModuleList()
            down = nn.Module()
            in_dim = out_dim
            out_dims.append(in_dim)
            out_dim = min(int(in_dim * 2), embed_dim)
            down.downsample = base_function.DownSample(in_dim, rample_with_conv, kernel_size=3)
            for i_block in range(self.num_res_blocks):
                block.append(base_function.ResnetBlock(in_dim, out_dim, 3, dropout, activation, norm))
                in_dim = out_dim
            down.block = block
            self.down.append(down)
        # middle
        self.mid = nn.ModuleList()
        for i in range(self.mid_layers):
            self.mid.append(base_function.ResnetBlock(out_dim, out_dim, 3, dropout, activation, norm))
        # up
        self.up = nn.ModuleList()
        for i in range(self.down_layers):
            block = nn.ModuleList()
            up = nn.Module()
            in_dim = out_dim
            out_dim = max(out_dims[-i - 1], ngf)
            for i_block in range(self.num_res_blocks):
                block.append(base_function.ResnetBlock(in_dim, out_dim, 3, dropout, activation, norm))
                in_dim = out_dim
            if i == attn_up_index:
                up.attn = base_function.AttnAware(out_dim, activation, norm)
            up.block = block
            # The first (deepest) up level has no incoming skip to upsample.
            upsample = i != 0
            up.out = base_function.ToRGB(out_dim, input_nc, upsample, activation, norm)
            up.upsample = base_function.UpSample(out_dim, rample_with_conv, kernel_size=3)
            self.up.append(up)
        # end
        self.decode = base_function.ToRGB(out_dim, input_nc, True, activation, norm)

    def forward(self, x, mask=None):
        """Run the generator.

        Args:
            x: input tensor with ``input_nc`` channels.
            mask: optional 4-D mask (bilinear F.interpolate requires 4-D input).
                It is resized to the attention level's spatial size and passed
                to AttnAware; it is unused when no attention level exists.

        Returns:
            The output of the final ``ToRGB`` layer.
        """
        attn_up_index = self.down_layers - 1 - self._attn_level
        # start
        x = self.encode(x)
        pre = None
        # down
        for i in range(self.down_layers):
            x = self.down[i].downsample(x)
            if i == self._attn_level:
                pre = x
            for i_block in range(self.num_res_blocks):
                x = self.down[i].block[i_block](x)
        # middle
        for i in range(self.mid_layers):
            x = self.mid[i](x)
        # up
        skip = None
        for i in range(self.down_layers):
            for i_block in range(self.num_res_blocks):
                x = self.up[i].block[i_block](x)
            if i == attn_up_index:
                if mask is not None:
                    mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear', align_corners=True)
                x = self.up[i].attn(x, pre=pre, mask=mask)
            skip = self.up[i].out(x, skip)
            x = self.up[i].upsample(x)
        # end
        x = self.decode(x, skip)
        return x