|
import torch |
|
import torch.nn as nn |
|
from torch.nn import init |
|
import functools |
|
from torch.optim import lr_scheduler |
|
import numpy as np |
|
import torch.nn.functional as F |
|
from torch.nn.modules.normalization import LayerNorm |
|
import os |
|
from torch.nn.utils import spectral_norm |
|
from torchvision import models |
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_weights(net, init_type='normal', init_gain=0.02): |
|
"""Initialize network weights. |
|
Parameters: |
|
net (network) -- network to be initialized |
|
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal |
|
init_gain (float) -- scaling factor for normal, xavier and orthogonal. |
|
    We use 'normal' in the original pix2pix and CycleGAN paper, but xavier and kaiming might

    work better for some applications. Feel free to experiment.
|
""" |
|
def init_func(m): |
|
classname = m.__class__.__name__ |
|
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): |
|
if init_type == 'normal': |
|
init.normal_(m.weight.data, 0.0, init_gain) |
|
elif init_type == 'xavier': |
|
init.xavier_normal_(m.weight.data, gain=init_gain) |
|
elif init_type == 'kaiming': |
|
|
|
init.kaiming_normal_(m.weight.data, a=0.2, mode='fan_in', nonlinearity='leaky_relu') |
|
elif init_type == 'orthogonal': |
|
init.orthogonal_(m.weight.data, gain=init_gain) |
|
else: |
|
raise NotImplementedError('initialization method [%s] is not implemented' % init_type) |
|
if hasattr(m, 'bias') and m.bias is not None: |
|
init.constant_(m.bias.data, 0.0) |
|
elif classname.find('BatchNorm2d') != -1: |
|
init.normal_(m.weight.data, 1.0, init_gain) |
|
init.constant_(m.bias.data, 0.0) |
|
|
|
print('initialize network with %s' % init_type) |
|
net.apply(init_func) |
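
# Example (illustrative sketch): initializing a small throwaway model; any nn.Module
# containing Conv/Linear/BatchNorm2d layers is handled the same way.
#
#     model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
#     init_weights(model, init_type='xavier', init_gain=0.02)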
|
|
|
|
|
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init=True): |
|
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights |
|
Parameters: |
|
net (network) -- the network to be initialized |
|
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal |
|
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

        init (bool)       -- whether to run init_weights on the network (set False to keep existing weights)
|
Return an initialized network. |
|
""" |
|
if len(gpu_ids) > 0: |
|
assert(torch.cuda.is_available()) |
|
net.to(gpu_ids[0]) |
|
if init: |
|
init_weights(net, init_type, init_gain=init_gain) |
|
return net |
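
# Example (illustrative sketch): `my_net` is a placeholder module; with an empty
# gpu_ids list the network stays on the CPU before its weights are initialized.
#
#     my_net = nn.Conv2d(3, 3, 3, padding=1)
#     my_net = init_net(my_net, init_type='kaiming', init_gain=0.02, gpu_ids=[])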
|
|
|
|
|
def get_scheduler(optimizer, opt): |
|
"""Return a learning rate scheduler |
|
Parameters: |
|
optimizer -- the optimizer of the network |
|
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
|
opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine |
|
For 'linear', we keep the same learning rate for the first <opt.niter> epochs |
|
and linearly decay the rate to zero over the next <opt.niter_decay> epochs. |
|
For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. |
|
See https://pytorch.org/docs/stable/optim.html for more details. |
|
""" |
|
if opt.lr_policy == 'linear': |
|
def lambda_rule(epoch): |
|
lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) |
|
return lr_l |
|
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) |
|
elif opt.lr_policy == 'step': |
|
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) |
|
elif opt.lr_policy == 'plateau': |
|
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) |
|
elif opt.lr_policy == 'cosine': |
|
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) |
|
else: |
|
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
|
return scheduler |
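
# Example (illustrative sketch): `opt` is normally a parsed BaseOptions object; the
# SimpleNamespace below only mimics the fields read by the 'linear' policy.
#
#     from types import SimpleNamespace
#     opt = SimpleNamespace(lr_policy='linear', epoch_count=1, niter=100, niter_decay=100)
#     optimizer = torch.optim.Adam(net.parameters(), lr=0.0002)   # `net` defined elsewhere
#     scheduler = get_scheduler(optimizer, opt)
#     # call scheduler.step() once per epoch; after the first opt.niter epochs the lr decays linearly to zero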
|
|
|
class LayerNormWarpper(nn.Module): |
|
def __init__(self, num_features): |
|
super(LayerNormWarpper, self).__init__() |
|
self.num_features = int(num_features) |
|
|
|
def forward(self, x): |
|
        x = F.layer_norm(x, [self.num_features, x.size(2), x.size(3)])  # same normalization, but device-agnostic (no hard-coded .cuda())
|
return x |
|
|
|
def get_norm_layer(norm_type='instance'): |
|
"""Return a normalization layer |
|
Parameters: |
|
        norm_type (str) -- the name of the normalization layer: batch | instance | layer | none
|
For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). |
|
For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. |
|
""" |
|
if norm_type == 'batch': |
|
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) |
|
elif norm_type == 'instance': |
|
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) |
|
elif norm_type == 'layer': |
|
norm_layer = functools.partial(LayerNormWarpper) |
|
elif norm_type == 'none': |
|
norm_layer = None |
|
else: |
|
raise NotImplementedError('normalization layer [%s] is not found' % norm_type) |
|
return norm_layer |
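
# Example (illustrative sketch): the returned callable acts as a layer factory.
#
#     norm_layer = get_norm_layer('instance')
#     block = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), norm_layer(64))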
|
|
|
|
|
def get_non_linearity(layer_type='relu'): |
|
if layer_type == 'relu': |
|
nl_layer = functools.partial(nn.ReLU, inplace=True) |
|
elif layer_type == 'lrelu': |
|
nl_layer = functools.partial( |
|
nn.LeakyReLU, negative_slope=0.2, inplace=True) |
|
elif layer_type == 'elu': |
|
nl_layer = functools.partial(nn.ELU, inplace=True) |
|
elif layer_type == 'selu': |
|
nl_layer = functools.partial(nn.SELU, inplace=True) |
|
elif layer_type == 'prelu': |
|
nl_layer = functools.partial(nn.PReLU) |
|
else: |
|
raise NotImplementedError( |
|
            'nonlinearity activation [%s] is not found' % layer_type)
|
return nl_layer |
|
|
|
|
|
def define_G(input_nc, output_nc, nz, ngf, netG='unet_128', norm='batch', nl='relu', use_noise=False, |
|
use_dropout=False, init_type='xavier', init_gain=0.02, gpu_ids=[], where_add='input', upsample='bilinear'): |
|
net = None |
|
norm_layer = get_norm_layer(norm_type=norm) |
|
nl_layer = get_non_linearity(layer_type=nl) |
|
|
|
|
|
if nz == 0: |
|
where_add = 'input' |
|
|
|
if netG == 'unet_128' and where_add == 'input': |
|
net = G_Unet_add_input(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise, |
|
use_dropout=use_dropout, upsample=upsample, device=gpu_ids) |
|
elif netG == 'unet_128_G' and where_add == 'input': |
|
net = G_Unet_add_input_G(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise, |
|
use_dropout=use_dropout, upsample=upsample, device=gpu_ids) |
|
elif netG == 'unet_256' and where_add == 'input': |
|
net = G_Unet_add_input(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise, |
|
use_dropout=use_dropout, upsample=upsample, device=gpu_ids) |
|
elif netG == 'unet_256_G' and where_add == 'input': |
|
net = G_Unet_add_input_G(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise, |
|
use_dropout=use_dropout, upsample=upsample, device=gpu_ids) |
|
elif netG == 'unet_128' and where_add == 'all': |
|
net = G_Unet_add_all(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise, |
|
use_dropout=use_dropout, upsample=upsample) |
|
elif netG == 'unet_256' and where_add == 'all': |
|
net = G_Unet_add_all(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise, |
|
use_dropout=use_dropout, upsample=upsample) |
|
else: |
|
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
|
|
|
return init_net(net, init_type, init_gain, gpu_ids) |
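
# Example (illustrative sketch): a 256x256 UNet generator with an 8-dimensional latent
# code injected at the input; the argument values are arbitrary.
#
#     netG = define_G(input_nc=3, output_nc=3, nz=8, ngf=64, netG='unet_256',
#                     norm='instance', nl='relu', where_add='input', upsample='bilinear')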
|
|
|
|
|
def define_C(input_nc, output_nc, nz, ngf, netC='unet_128', norm='instance', nl='relu', |
|
use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], upsample='basic'): |
|
net = None |
|
norm_layer = get_norm_layer(norm_type=norm) |
|
nl_layer = get_non_linearity(layer_type=nl) |
|
|
|
if netC == 'resnet_9blocks': |
|
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) |
|
elif netC == 'resnet_6blocks': |
|
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) |
|
elif netC == 'unet_128': |
|
net = G_Unet_add_input_C(input_nc, output_nc, 0, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, |
|
use_dropout=use_dropout, upsample=upsample) |
|
elif netC == 'unet_256': |
|
net = G_Unet_add_input(input_nc, output_nc, 0, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, |
|
use_dropout=use_dropout, upsample=upsample) |
|
elif netC == 'unet_32': |
|
net = G_Unet_add_input(input_nc, output_nc, 0, 5, ngf, norm_layer=norm_layer, nl_layer=nl_layer, |
|
use_dropout=use_dropout, upsample=upsample) |
|
else: |
|
        raise NotImplementedError('Generator model name [%s] is not recognized' % netC)
|
|
|
return init_net(net, init_type, init_gain, gpu_ids) |
|
|
|
|
|
def define_D(input_nc, ndf, netD, norm='batch', nl='lrelu', init_type='xavier', init_gain=0.02, num_Ds=1, gpu_ids=[]): |
|
net = None |
|
norm_layer = get_norm_layer(norm_type=norm) |
|
    nl = 'lrelu'  # discriminators always use LeakyReLU, overriding the nl argument
|
nl_layer = get_non_linearity(layer_type=nl) |
|
|
|
if netD == 'basic_128': |
|
net = D_NLayers(input_nc, ndf, n_layers=2, norm_layer=norm_layer, nl_layer=nl_layer) |
|
elif netD == 'basic_256': |
|
net = D_NLayers(input_nc, ndf, n_layers=3, norm_layer=norm_layer, nl_layer=nl_layer) |
|
elif netD == 'basic_128_multi': |
|
net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=2, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer) |
|
elif netD == 'basic_256_multi': |
|
net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=3, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer) |
|
else: |
|
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
|
return init_net(net, init_type, init_gain, gpu_ids) |
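
# Example (illustrative sketch): a two-scale multi-resolution PatchGAN discriminator
# for 256x256 inputs; the argument values are arbitrary.
#
#     netD = define_D(input_nc=3, ndf=64, netD='basic_256_multi', norm='instance', num_Ds=2)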
|
|
|
|
|
def define_E(input_nc, output_nc, ndf, netE, norm='batch', nl='lrelu', |
|
init_type='xavier', init_gain=0.02, gpu_ids=[], vaeLike=False): |
|
net = None |
|
norm_layer = get_norm_layer(norm_type=norm) |
|
    nl = 'lrelu'  # encoders always use LeakyReLU, overriding the nl argument
|
nl_layer = get_non_linearity(layer_type=nl) |
|
if netE == 'resnet_128': |
|
net = E_ResNet(input_nc, output_nc, ndf, n_blocks=4, norm_layer=norm_layer, |
|
nl_layer=nl_layer, vaeLike=vaeLike) |
|
elif netE == 'resnet_256': |
|
net = E_ResNet(input_nc, output_nc, ndf, n_blocks=5, norm_layer=norm_layer, |
|
nl_layer=nl_layer, vaeLike=vaeLike) |
|
elif netE == 'conv_128': |
|
net = E_NLayers(input_nc, output_nc, ndf, n_layers=4, norm_layer=norm_layer, |
|
nl_layer=nl_layer, vaeLike=vaeLike) |
|
elif netE == 'conv_256': |
|
net = E_NLayers(input_nc, output_nc, ndf, n_layers=5, norm_layer=norm_layer, |
|
nl_layer=nl_layer, vaeLike=vaeLike) |
|
else: |
|
        raise NotImplementedError('Encoder model name [%s] is not recognized' % netE)
|
|
|
return init_net(net, init_type, init_gain, gpu_ids, False) |
|
|
|
|
|
class ResnetGenerator(nn.Module): |
|
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, norm_layer=None, use_dropout=False, n_blocks=6, padding_type='replicate'): |
|
assert(n_blocks >= 0) |
|
super(ResnetGenerator, self).__init__() |
|
self.input_nc = input_nc |
|
self.output_nc = output_nc |
|
self.ngf = ngf |
|
if type(norm_layer) == functools.partial: |
|
use_bias = norm_layer.func != nn.BatchNorm2d |
|
else: |
|
use_bias = norm_layer != nn.BatchNorm2d |
|
|
|
model = [nn.ReplicationPad2d(3), |
|
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, |
|
bias=use_bias)] |
|
if norm_layer is not None: |
|
model += [norm_layer(ngf)] |
|
model += [nn.ReLU(True)] |
|
|
|
|
|
for i in range(n_downsampling): |
|
mult = 2**i |
|
model += [nn.ReplicationPad2d(1),nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, |
|
stride=2, padding=0, bias=use_bias)] |
|
|
|
|
|
if norm_layer is not None: |
|
model += [norm_layer(ngf * mult * 2)] |
|
model += [nn.ReLU(True)] |
|
|
|
mult = 2**n_downsampling |
|
for i in range(n_blocks): |
|
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] |
|
|
|
for i in range(n_downsampling): |
|
mult = 2**(n_downsampling - i) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model += upsampleLayer(ngf * mult, int(ngf * mult / 2), upsample='bilinear', padding_type=padding_type) |
|
if norm_layer is not None: |
|
model += [norm_layer(int(ngf * mult / 2))] |
|
model += [nn.ReLU(True)] |
|
model +=[nn.ReplicationPad2d(1), |
|
nn.Conv2d(int(ngf * mult / 2), int(ngf * mult / 2), kernel_size=3, padding=0)] |
|
if norm_layer is not None: |
|
                model += [norm_layer(int(ngf * mult / 2))]
|
model += [nn.ReLU(True)] |
|
model += [nn.ReplicationPad2d(3)] |
|
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] |
|
|
|
|
|
self.model = nn.Sequential(*model) |
|
|
|
def forward(self, input): |
|
return self.model(input) |
|
|
|
|
|
|
|
class ResnetBlock(nn.Module): |
|
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): |
|
super(ResnetBlock, self).__init__() |
|
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) |
|
|
|
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): |
|
conv_block = [] |
|
p = 0 |
|
if padding_type == 'reflect': |
|
conv_block += [nn.ReflectionPad2d(1)] |
|
elif padding_type == 'replicate': |
|
conv_block += [nn.ReplicationPad2d(1)] |
|
elif padding_type == 'zero': |
|
p = 1 |
|
else: |
|
raise NotImplementedError('padding [%s] is not implemented' % padding_type) |
|
|
|
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] |
|
if norm_layer is not None: |
|
conv_block += [norm_layer(dim)] |
|
conv_block += [nn.ReLU(True)] |
|
|
|
|
|
|
|
p = 0 |
|
if padding_type == 'reflect': |
|
conv_block += [nn.ReflectionPad2d(1)] |
|
elif padding_type == 'replicate': |
|
conv_block += [nn.ReplicationPad2d(1)] |
|
elif padding_type == 'zero': |
|
p = 1 |
|
else: |
|
raise NotImplementedError('padding [%s] is not implemented' % padding_type) |
|
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)] |
|
if norm_layer is not None: |
|
conv_block += [norm_layer(dim)] |
|
|
|
return nn.Sequential(*conv_block) |
|
|
|
def forward(self, x): |
|
out = x + self.conv_block(x) |
|
return out |
|
|
|
|
|
class D_NLayersMulti(nn.Module): |
|
def __init__(self, input_nc, ndf=64, n_layers=3, |
|
norm_layer=nn.BatchNorm2d, num_D=1, nl_layer=None): |
|
super(D_NLayersMulti, self).__init__() |
|
|
|
self.num_D = num_D |
|
self.nl_layer=nl_layer |
|
if num_D == 1: |
|
layers = self.get_layers(input_nc, ndf, n_layers, norm_layer) |
|
self.model = nn.Sequential(*layers) |
|
else: |
|
layers = self.get_layers(input_nc, ndf, n_layers, norm_layer) |
|
self.add_module("model_0", nn.Sequential(*layers)) |
|
self.down = nn.functional.interpolate |
|
for i in range(1, num_D): |
|
ndf_i = int(round(ndf / (2**i))) |
|
layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer) |
|
self.add_module("model_%d" % i, nn.Sequential(*layers)) |
|
|
|
def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): |
|
kw = 3 |
|
padw = 1 |
|
sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, |
|
stride=2, padding=padw)), nn.LeakyReLU(0.2, True)] |
|
|
|
nf_mult = 1 |
|
nf_mult_prev = 1 |
|
for n in range(1, n_layers): |
|
nf_mult_prev = nf_mult |
|
nf_mult = min(2**n, 8) |
|
sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, |
|
kernel_size=kw, stride=2, padding=padw))] |
|
if norm_layer: |
|
sequence += [norm_layer(ndf * nf_mult)] |
|
|
|
sequence += [self.nl_layer()] |
|
|
|
nf_mult_prev = nf_mult |
|
nf_mult = min(2**n_layers, 8) |
|
sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, |
|
kernel_size=kw, stride=1, padding=padw))] |
|
if norm_layer: |
|
sequence += [norm_layer(ndf * nf_mult)] |
|
sequence += [self.nl_layer()] |
|
|
|
sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1, |
|
kernel_size=kw, stride=1, padding=padw))] |
|
|
|
return sequence |
|
|
|
def forward(self, input): |
|
if self.num_D == 1: |
|
return self.model(input) |
|
result = [] |
|
down = input |
|
for i in range(self.num_D): |
|
model = getattr(self, "model_%d" % i) |
|
result.append(model(down)) |
|
if i != self.num_D - 1: |
|
down = self.down(down, scale_factor=0.5, mode='bilinear') |
|
return result |
|
|
|
class D_NLayers(nn.Module): |
|
"""Defines a PatchGAN discriminator""" |
|
|
|
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, nl_layer=None):
|
"""Construct a PatchGAN discriminator |
|
Parameters: |
|
input_nc (int) -- the number of channels in input images |
|
ndf (int) -- the number of filters in the last conv layer |
|
n_layers (int) -- the number of conv layers in the discriminator |
|
            norm_layer      -- normalization layer

            nl_layer        -- accepted for interface compatibility with define_D; LeakyReLU is used internally
|
""" |
|
super(D_NLayers, self).__init__() |
|
if type(norm_layer) == functools.partial: |
|
use_bias = norm_layer.func != nn.BatchNorm2d |
|
else: |
|
use_bias = norm_layer != nn.BatchNorm2d |
|
|
|
kw = 3 |
|
padw = 1 |
|
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] |
|
nf_mult = 1 |
|
nf_mult_prev = 1 |
|
for n in range(1, n_layers): |
|
nf_mult_prev = nf_mult |
|
nf_mult = min(2 ** n, 8) |
|
sequence += [ |
|
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), |
|
norm_layer(ndf * nf_mult), |
|
nn.LeakyReLU(0.2, True) |
|
] |
|
|
|
nf_mult_prev = nf_mult |
|
nf_mult = min(2 ** n_layers, 8) |
|
sequence += [ |
|
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), |
|
norm_layer(ndf * nf_mult), |
|
nn.LeakyReLU(0.2, True) |
|
] |
|
|
|
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] |
|
self.model = nn.Sequential(*sequence) |
|
|
|
def forward(self, input): |
|
"""Standard forward.""" |
|
return self.model(input) |
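
# Note: as a PatchGAN, D_NLayers outputs an N x 1 x H' x W' map of per-patch realness
# scores rather than a single scalar per image; the GAN loss is typically applied to
# every spatial position of this map.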
|
|
|
|
|
class G_Unet_add_input(nn.Module): |
|
def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64, |
|
norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False, |
|
upsample='basic', device=0): |
|
super(G_Unet_add_input, self).__init__() |
|
self.nz = nz |
|
max_nchn = 8 |
|
noise = [] |
|
for i in range(num_downs+1): |
|
if use_noise: |
|
noise.append(True) |
|
else: |
|
noise.append(False) |
|
|
|
|
|
|
|
unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=noise[num_downs-1], |
|
innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
for i in range(num_downs - 5): |
|
unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise[num_downs-i-3], |
|
norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample) |
|
unet_block = UnetBlock_A(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2], |
|
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
unet_block = UnetBlock_A(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1], |
|
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
unet_block = UnetBlock_A(ngf, ngf, ngf * 2, unet_block, noise[0], |
|
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
unet_block = UnetBlock_A(input_nc + nz, output_nc, ngf, unet_block, None, |
|
outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
|
|
self.model = unet_block |
|
|
|
def forward(self, x, z=None): |
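        # When nz > 0, the latent code z is spatially replicated to the resolution of x
        # and concatenated as extra input channels before entering the UNet.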
|
if self.nz > 0: |
|
z_img = z.view(z.size(0), z.size(1), 1, 1).expand( |
|
z.size(0), z.size(1), x.size(2), x.size(3)) |
|
x_with_z = torch.cat([x, z_img], 1) |
|
else: |
|
x_with_z = x |
|
|
|
|
|
return torch.tanh(self.model(x_with_z)) |
|
|
|
|
|
class G_Unet_add_input_G(nn.Module): |
|
def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64, |
|
norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False, |
|
upsample='basic', device=0): |
|
super(G_Unet_add_input_G, self).__init__() |
|
self.nz = nz |
|
max_nchn = 8 |
|
noise = [] |
|
for i in range(num_downs+1): |
|
if use_noise: |
|
noise.append(True) |
|
else: |
|
noise.append(False) |
|
|
|
|
|
unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False, |
|
innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
for i in range(num_downs - 5): |
|
unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False, |
|
norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample) |
|
unet_block = UnetBlock_G(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2], |
|
norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic') |
|
unet_block = UnetBlock_G(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1], |
|
norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic') |
|
unet_block = UnetBlock_G(ngf, ngf, ngf * 2, unet_block, noise[0], |
|
norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic') |
|
unet_block = UnetBlock_G(input_nc + nz, output_nc, ngf, unet_block, None, |
|
outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic') |
|
|
|
self.model = unet_block |
|
|
|
def forward(self, x, z=None): |
|
if self.nz > 0: |
|
z_img = z.view(z.size(0), z.size(1), 1, 1).expand( |
|
z.size(0), z.size(1), x.size(2), x.size(3)) |
|
x_with_z = torch.cat([x, z_img], 1) |
|
else: |
|
x_with_z = x |
|
|
|
|
|
return self.model(x_with_z) |
|
|
|
class G_Unet_add_input_C(nn.Module): |
|
def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64, |
|
norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False, |
|
upsample='basic', device=0): |
|
super(G_Unet_add_input_C, self).__init__() |
|
self.nz = nz |
|
max_nchn = 8 |
|
|
|
|
|
unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False, |
|
innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
for i in range(num_downs - 5): |
|
unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False, |
|
norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample) |
|
unet_block = UnetBlock(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise=False, |
|
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
unet_block = UnetBlock(ngf * 2, ngf * 2, ngf * 4, unet_block, noise=False, |
|
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
unet_block = UnetBlock(ngf, ngf, ngf * 2, unet_block, noise=False, |
|
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
unet_block = UnetBlock(input_nc + nz, output_nc, ngf, unet_block, noise=False, |
|
outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
|
|
self.model = unet_block |
|
|
|
def forward(self, x, z=None): |
|
if self.nz > 0: |
|
z_img = z.view(z.size(0), z.size(1), 1, 1).expand( |
|
z.size(0), z.size(1), x.size(2), x.size(3)) |
|
x_with_z = torch.cat([x, z_img], 1) |
|
else: |
|
x_with_z = x |
|
|
|
|
|
return self.model(x_with_z) |
|
|
|
def upsampleLayer(inplanes, outplanes, kw=1, upsample='basic', padding_type='replicate'): |
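    # Two upsampling flavours: 'basic' uses a stride-2 transposed convolution, while
    # 'bilinear' / 'nearest' / 'linear' use nn.Upsample followed by a 1x1 convolution,
    # which tends to avoid checkerboard artifacts.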
|
|
|
if upsample == 'basic': |
|
upconv = [nn.ConvTranspose2d(inplanes, outplanes, kernel_size=4, stride=2, padding=1)] |
|
    elif upsample in ('bilinear', 'nearest', 'linear'):

        # align_corners is only valid for the interpolating modes, so 'nearest' must not set it.
        align_corners = None if upsample == 'nearest' else True

        upconv = [nn.Upsample(scale_factor=2, mode=upsample, align_corners=align_corners),

                  nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)]
|
|
|
|
|
|
|
else: |
|
raise NotImplementedError( |
|
'upsample layer [%s] not implemented' % upsample) |
|
return upconv |
|
|
|
class UnetBlock_G(nn.Module): |
|
def __init__(self, input_nc, outer_nc, inner_nc, |
|
submodule=None, noise=None, outermost=False, innermost=False, |
|
norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'): |
|
super(UnetBlock_G, self).__init__() |
|
self.outermost = outermost |
|
p = 0 |
|
downconv = [] |
|
if padding_type == 'reflect': |
|
downconv += [nn.ReflectionPad2d(1)] |
|
elif padding_type == 'replicate': |
|
downconv += [nn.ReplicationPad2d(1)] |
|
elif padding_type == 'zero': |
|
p = 1 |
|
else: |
|
raise NotImplementedError( |
|
'padding [%s] is not implemented' % padding_type) |
|
|
|
downconv += [nn.Conv2d(input_nc, inner_nc, |
|
kernel_size=3, stride=2, padding=p)] |
|
|
|
downrelu = nn.LeakyReLU(0.2, True) |
|
downnorm = norm_layer(inner_nc) if norm_layer is not None else None |
|
uprelu = nl_layer() |
|
uprelu2 = nl_layer() |
|
uppad = nn.ReplicationPad2d(1) |
|
upnorm = norm_layer(outer_nc) if norm_layer is not None else None |
|
upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None |
|
self.noiseblock = ApplyNoise(outer_nc) |
|
self.noise = noise |
|
|
|
if outermost: |
|
upconv = upsampleLayer(inner_nc * 2, inner_nc, upsample=upsample, padding_type=padding_type) |
|
uppad = nn.ReplicationPad2d(3) |
|
upconv2 = nn.Conv2d(inner_nc, outer_nc, kernel_size=7, padding=0) |
|
down = downconv |
|
up = [uprelu] + upconv |
|
if upnorm is not None: |
|
up += [norm_layer(inner_nc)] |
|
|
|
|
|
|
|
|
|
|
|
|
|
up +=[uprelu2, uppad, upconv2] |
|
model = down + [submodule] + up |
|
elif innermost: |
|
upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type) |
|
upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p) |
|
down = [downrelu] + downconv |
|
up = [uprelu] + upconv |
|
if upnorm is not None: |
|
up += [upnorm] |
|
up += [uprelu2, uppad, upconv2] |
|
if upnorm2 is not None: |
|
up += [upnorm2] |
|
model = down + up |
|
else: |
|
upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type) |
|
upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p) |
|
down = [downrelu] + downconv |
|
if downnorm is not None: |
|
down += [downnorm] |
|
up = [uprelu] + upconv |
|
if upnorm is not None: |
|
up += [upnorm] |
|
up += [uprelu2, uppad, upconv2] |
|
if upnorm2 is not None: |
|
up += [upnorm2] |
|
|
|
if use_dropout: |
|
model = down + [submodule] + up + [nn.Dropout(0.5)] |
|
else: |
|
model = down + [submodule] + up |
|
|
|
self.model = nn.Sequential(*model) |
|
|
|
def forward(self, x): |
|
if self.outermost: |
|
return self.model(x) |
|
else: |
|
x2 = self.model(x) |
|
if self.noise: |
|
x2 = self.noiseblock(x2, self.noise) |
|
return torch.cat([x2, x], 1) |
|
|
|
|
|
class UnetBlock(nn.Module): |
|
def __init__(self, input_nc, outer_nc, inner_nc, |
|
submodule=None, noise=None, outermost=False, innermost=False, |
|
norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'): |
|
super(UnetBlock, self).__init__() |
|
self.outermost = outermost |
|
p = 0 |
|
downconv = [] |
|
if padding_type == 'reflect': |
|
downconv += [nn.ReflectionPad2d(1)] |
|
elif padding_type == 'replicate': |
|
downconv += [nn.ReplicationPad2d(1)] |
|
elif padding_type == 'zero': |
|
p = 1 |
|
else: |
|
raise NotImplementedError( |
|
'padding [%s] is not implemented' % padding_type) |
|
|
|
downconv += [nn.Conv2d(input_nc, inner_nc, |
|
kernel_size=3, stride=2, padding=p)] |
|
|
|
downrelu = nn.LeakyReLU(0.2, True) |
|
downnorm = norm_layer(inner_nc) if norm_layer is not None else None |
|
uprelu = nl_layer() |
|
uprelu2 = nl_layer() |
|
uppad = nn.ReplicationPad2d(1) |
|
upnorm = norm_layer(outer_nc) if norm_layer is not None else None |
|
upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None |
|
self.noiseblock = ApplyNoise(outer_nc) |
|
self.noise = noise |
|
|
|
if outermost: |
|
upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type) |
|
upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p) |
|
down = downconv |
|
up = [uprelu] + upconv |
|
if upnorm is not None: |
|
up += [upnorm] |
|
up +=[uprelu2, uppad, upconv2] |
|
model = down + [submodule] + up |
|
elif innermost: |
|
upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type) |
|
upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p) |
|
down = [downrelu] + downconv |
|
up = [uprelu] + upconv |
|
if upnorm is not None: |
|
up += [upnorm] |
|
up += [uprelu2, uppad, upconv2] |
|
if upnorm2 is not None: |
|
up += [upnorm2] |
|
model = down + up |
|
else: |
|
upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type) |
|
upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p) |
|
down = [downrelu] + downconv |
|
if downnorm is not None: |
|
down += [downnorm] |
|
up = [uprelu] + upconv |
|
if upnorm is not None: |
|
up += [upnorm] |
|
up += [uprelu2, uppad, upconv2] |
|
if upnorm2 is not None: |
|
up += [upnorm2] |
|
|
|
if use_dropout: |
|
model = down + [submodule] + up + [nn.Dropout(0.5)] |
|
else: |
|
model = down + [submodule] + up |
|
|
|
self.model = nn.Sequential(*model) |
|
|
|
def forward(self, x): |
|
if self.outermost: |
|
return self.model(x) |
|
else: |
|
x2 = self.model(x) |
|
if self.noise: |
|
x2 = self.noiseblock(x2, self.noise) |
|
return torch.cat([x2, x], 1) |
|
|
|
|
|
|
|
|
|
class UnetBlock_A(nn.Module): |
|
def __init__(self, input_nc, outer_nc, inner_nc, |
|
submodule=None, noise=None, outermost=False, innermost=False, |
|
norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'): |
|
super(UnetBlock_A, self).__init__() |
|
self.outermost = outermost |
|
p = 0 |
|
downconv = [] |
|
if padding_type == 'reflect': |
|
downconv += [nn.ReflectionPad2d(1)] |
|
elif padding_type == 'replicate': |
|
downconv += [nn.ReplicationPad2d(1)] |
|
elif padding_type == 'zero': |
|
p = 1 |
|
else: |
|
raise NotImplementedError( |
|
'padding [%s] is not implemented' % padding_type) |
|
|
|
downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc, |
|
kernel_size=3, stride=2, padding=p))] |
|
|
|
downrelu = nn.LeakyReLU(0.2, True) |
|
downnorm = norm_layer(inner_nc) if norm_layer is not None else None |
|
uprelu = nl_layer() |
|
uprelu2 = nl_layer() |
|
uppad = nn.ReplicationPad2d(1) |
|
upnorm = norm_layer(outer_nc) if norm_layer is not None else None |
|
upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None |
|
self.noiseblock = ApplyNoise(outer_nc) |
|
self.noise = noise |
|
|
|
if outermost: |
|
upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type) |
|
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)) |
|
down = downconv |
|
up = [uprelu] + upconv |
|
if upnorm is not None: |
|
up += [upnorm] |
|
up +=[uprelu2, uppad, upconv2] |
|
model = down + [submodule] + up |
|
elif innermost: |
|
upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type) |
|
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)) |
|
down = [downrelu] + downconv |
|
up = [uprelu] + upconv |
|
if upnorm is not None: |
|
up += [upnorm] |
|
up += [uprelu2, uppad, upconv2] |
|
if upnorm2 is not None: |
|
up += [upnorm2] |
|
model = down + up |
|
else: |
|
upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type) |
|
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)) |
|
down = [downrelu] + downconv |
|
if downnorm is not None: |
|
down += [downnorm] |
|
up = [uprelu] + upconv |
|
if upnorm is not None: |
|
up += [upnorm] |
|
up += [uprelu2, uppad, upconv2] |
|
if upnorm2 is not None: |
|
up += [upnorm2] |
|
|
|
if use_dropout: |
|
model = down + [submodule] + up + [nn.Dropout(0.5)] |
|
else: |
|
model = down + [submodule] + up |
|
|
|
self.model = nn.Sequential(*model) |
|
|
|
def forward(self, x): |
|
if self.outermost: |
|
return self.model(x) |
|
else: |
|
x2 = self.model(x) |
|
if self.noise: |
|
x2 = self.noiseblock(x2, self.noise) |
|
if x2.shape[-1]==x.shape[-1]: |
|
return x2 + x |
|
else: |
|
x2 = F.interpolate(x2, x.shape[2:]) |
|
return x2 + x |
|
|
|
|
|
class E_ResNet(nn.Module): |
|
def __init__(self, input_nc=3, output_nc=1, ndf=64, n_blocks=4, |
|
norm_layer=None, nl_layer=None, vaeLike=False): |
|
super(E_ResNet, self).__init__() |
|
self.vaeLike = vaeLike |
|
max_ndf = 4 |
|
conv_layers = [ |
|
nn.Conv2d(input_nc, ndf, kernel_size=3, stride=2, padding=1, bias=True)] |
|
for n in range(1, n_blocks): |
|
input_ndf = ndf * min(max_ndf, n) |
|
output_ndf = ndf * min(max_ndf, n + 1) |
|
conv_layers += [BasicBlock(input_ndf, |
|
output_ndf, norm_layer, nl_layer)] |
|
conv_layers += [nl_layer(), nn.AdaptiveAvgPool2d(4)] |
|
if vaeLike: |
|
self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)]) |
|
self.fcVar = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)]) |
|
else: |
|
self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)]) |
|
self.conv = nn.Sequential(*conv_layers) |
|
|
|
def forward(self, x): |
|
x_conv = self.conv(x) |
|
conv_flat = x_conv.view(x.size(0), -1) |
|
output = self.fc(conv_flat) |
|
if self.vaeLike: |
|
outputVar = self.fcVar(conv_flat) |
|
return output, outputVar |
|
        else:

            return output
|
|
|
|
|
|
|
|
|
|
|
|
|
class G_Unet_add_all(nn.Module): |
|
def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64, |
|
norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False, upsample='basic'): |
|
super(G_Unet_add_all, self).__init__() |
|
self.nz = nz |
|
self.mapping = G_mapping(self.nz, self.nz, 512, normalize_latents=False, lrmul=1) |
|
self.truncation_psi = 0 |
|
self.truncation_cutoff = 0 |
|
|
|
|
|
|
|
num_layers = int(np.log2(512)) * 2 - 2 |
|
|
|
self.noise_inputs = [] |
|
for layer_idx in range(num_layers): |
|
res = layer_idx // 2 + 2 |
|
shape = [1, 1, 2 ** res, 2 ** res] |
|
            self.noise_inputs.append(torch.randn(*shape).to("cuda"))  # fixed per-layer noise maps; note this assumes CUDA is available
|
|
|
|
|
unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=None, innermost=True, |
|
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block, |
|
norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample) |
|
for i in range(num_downs - 6): |
|
unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block, |
|
norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample) |
|
unet_block = UnetBlock_with_z(ngf * 4, ngf * 4, ngf * 8, nz, submodule=unet_block, |
|
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
unet_block = UnetBlock_with_z(ngf * 2, ngf * 2, ngf * 4, nz, submodule=unet_block, |
|
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
unet_block = UnetBlock_with_z(ngf, ngf, ngf * 2, nz, submodule=unet_block, |
|
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
unet_block = UnetBlock_with_z(input_nc, output_nc, ngf, nz, submodule=unet_block, |
|
outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) |
|
self.model = unet_block |
|
|
|
def forward(self, x, z): |
|
|
|
dlatents1, num_layers = self.mapping(z) |
|
dlatents1 = dlatents1.unsqueeze(1) |
|
dlatents1 = dlatents1.expand(-1, int(num_layers), -1) |
|
|
|
|
|
if self.truncation_psi and self.truncation_cutoff: |
|
coefs = np.ones([1, num_layers, 1], dtype=np.float32) |
|
for i in range(num_layers): |
|
if i < self.truncation_cutoff: |
|
coefs[:, i, :] *= self.truncation_psi |
|
"""Linear interpolation. |
|
a + (b - a) * t (a = 0) |
|
reduce to |
|
b * t |
|
""" |
|
dlatents1 = dlatents1 * torch.Tensor(coefs).to(dlatents1.device) |
|
|
|
return torch.tanh(self.model(x, dlatents1, self.noise_inputs)) |
|
|
|
|
|
class ApplyNoise(nn.Module): |
|
def __init__(self, channels): |
|
super().__init__() |
|
self.channels = channels |
|
self.weight = nn.Parameter(torch.randn(channels), requires_grad=True) |
|
self.bias = nn.Parameter(torch.zeros(channels), requires_grad=True) |
|
|
|
def forward(self, x, noise): |
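        # Only the first half of the channels receive noise: the learned scale and bias
        # for the second half are replaced with zeros before broadcasting over x.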
|
W,_ = torch.split(self.weight.view(1, -1, 1, 1), self.channels // 2, dim=1) |
|
B,_ = torch.split(self.bias.view(1, -1, 1, 1), self.channels // 2, dim=1) |
|
Z = torch.zeros_like(W) |
|
w = torch.cat([W,Z], dim=1).to(x.device) |
|
b = torch.cat([B,Z], dim=1).to(x.device) |
|
adds = w * torch.randn_like(x) + b |
|
return x + adds.type_as(x) |
|
|
|
|
|
class FC(nn.Module): |
|
def __init__(self, |
|
in_channels, |
|
out_channels, |
|
gain=2**(0.5), |
|
use_wscale=False, |
|
lrmul=1.0, |
|
bias=True): |
|
""" |
|
        PyTorch port of the Dense/FC/Linear layer from the original TensorFlow implementation.
|
""" |
|
super(FC, self).__init__() |
|
he_std = gain * in_channels ** (-0.5) |
|
if use_wscale: |
|
init_std = 1.0 / lrmul |
|
self.w_lrmul = he_std * lrmul |
|
else: |
|
init_std = he_std / lrmul |
|
self.w_lrmul = lrmul |
|
|
|
self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels) * init_std) |
|
if bias: |
|
self.bias = torch.nn.Parameter(torch.zeros(out_channels)) |
|
self.b_lrmul = lrmul |
|
else: |
|
self.bias = None |
|
|
|
def forward(self, x): |
|
if self.bias is not None: |
|
out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul) |
|
else: |
|
out = F.linear(x, self.weight * self.w_lrmul) |
|
out = F.leaky_relu(out, 0.2, inplace=True) |
|
return out |
|
|
|
|
|
class ApplyStyle(nn.Module): |
|
""" |
|
@ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb |
|
""" |
|
def __init__(self, latent_size, channels, use_wscale, nl_layer): |
|
super(ApplyStyle, self).__init__() |
|
modules = [nn.Linear(latent_size, channels*2)] |
|
if nl_layer: |
|
modules += [nl_layer()] |
|
self.linear = nn.Sequential(*modules) |
|
|
|
def forward(self, x, latent): |
|
style = self.linear(latent) |
|
shape = [-1, 2, x.size(1), 1, 1] |
|
style = style.view(shape) |
|
x = x * (style[:, 0] + 1.) + style[:, 1] |
|
return x |
|
|
|
class PixelNorm(nn.Module): |
|
def __init__(self, epsilon=1e-8): |
|
""" |
|
@notice: avoid in-place ops. |
|
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 |
|
""" |
|
super(PixelNorm, self).__init__() |
|
self.epsilon = epsilon |
|
|
|
def forward(self, x): |
|
tmp = torch.mul(x, x) |
|
tmp1 = torch.rsqrt(torch.mean(tmp, dim=1, keepdim=True) + self.epsilon) |
|
|
|
return x * tmp1 |
|
|
|
|
|
class InstanceNorm(nn.Module): |
|
def __init__(self, epsilon=1e-8): |
|
""" |
|
@notice: avoid in-place ops. |
|
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 |
|
""" |
|
super(InstanceNorm, self).__init__() |
|
self.epsilon = epsilon |
|
|
|
def forward(self, x): |
|
x = x - torch.mean(x, (2, 3), True) |
|
tmp = torch.mul(x, x) |
|
tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon) |
|
return x * tmp |
|
|
|
|
|
class LayerEpilogue(nn.Module): |
|
def __init__(self, channels, dlatent_size, use_wscale, use_noise, |
|
use_pixel_norm, use_instance_norm, use_styles, nl_layer=None): |
|
super(LayerEpilogue, self).__init__() |
|
self.use_noise = use_noise |
|
if use_noise: |
|
self.noise = ApplyNoise(channels) |
|
self.act = nn.LeakyReLU(negative_slope=0.2) |
|
|
|
if use_pixel_norm: |
|
self.pixel_norm = PixelNorm() |
|
else: |
|
self.pixel_norm = None |
|
|
|
if use_instance_norm: |
|
self.instance_norm = InstanceNorm() |
|
else: |
|
self.instance_norm = None |
|
|
|
if use_styles: |
|
self.style_mod = ApplyStyle(dlatent_size, channels, use_wscale=use_wscale, nl_layer=nl_layer) |
|
else: |
|
self.style_mod = None |
|
|
|
def forward(self, x, noise, dlatents_in_slice=None): |
|
|
|
if self.use_noise: |
|
x = self.noise(x, noise) |
|
x = self.act(x) |
|
if self.pixel_norm is not None: |
|
x = self.pixel_norm(x) |
|
if self.instance_norm is not None: |
|
x = self.instance_norm(x) |
|
if self.style_mod is not None: |
|
x = self.style_mod(x, dlatents_in_slice) |
|
|
|
return x |
|
|
|
class G_mapping(nn.Module): |
|
def __init__(self, |
|
mapping_fmaps=512, |
|
dlatent_size=512, |
|
resolution=512, |
|
normalize_latents=True, |
|
use_wscale=True, |
|
lrmul=0.01, |
|
gain=2**(0.5), |
|
nl_layer=None |
|
): |
|
super(G_mapping, self).__init__() |
|
self.mapping_fmaps = mapping_fmaps |
|
func = [ |
|
nn.Linear(self.mapping_fmaps, dlatent_size) |
|
] |
|
if nl_layer: |
|
func += [nl_layer()] |
|
|
|
for j in range(0,4): |
|
func += [ |
|
nn.Linear(dlatent_size, dlatent_size) |
|
] |
|
if nl_layer: |
|
func += [nl_layer()] |
|
|
|
self.func = nn.Sequential(*func) |
|
|
|
|
|
|
|
self.normalize_latents = normalize_latents |
|
self.resolution_log2 = int(np.log2(resolution)) |
|
self.num_layers = self.resolution_log2 * 2 - 2 |
|
self.pixel_norm = PixelNorm() |
|
|
|
|
|
|
|
def forward(self, x): |
|
if self.normalize_latents: |
|
x = self.pixel_norm(x) |
|
out = self.func(x) |
|
return out, self.num_layers |
|
|
|
class UnetBlock_with_z(nn.Module): |
|
def __init__(self, input_nc, outer_nc, inner_nc, nz=0, |
|
submodule=None, outermost=False, innermost=False, |
|
norm_layer=None, nl_layer=None, use_dropout=False, |
|
upsample='basic', padding_type='replicate'): |
|
super(UnetBlock_with_z, self).__init__() |
|
p = 0 |
|
downconv = [] |
|
if padding_type == 'reflect': |
|
downconv += [nn.ReflectionPad2d(1)] |
|
elif padding_type == 'replicate': |
|
downconv += [nn.ReplicationPad2d(1)] |
|
elif padding_type == 'zero': |
|
p = 1 |
|
else: |
|
raise NotImplementedError( |
|
'padding [%s] is not implemented' % padding_type) |
|
|
|
self.outermost = outermost |
|
self.innermost = innermost |
|
self.nz = nz |
|
|
|
|
|
downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc, |
|
kernel_size=3, stride=2, padding=p))] |
|
|
|
downrelu = nn.LeakyReLU(0.2, True) |
|
downnorm = norm_layer(inner_nc) if norm_layer is not None else None |
|
uprelu = nl_layer() |
|
uprelu2 = nl_layer() |
|
uppad = nn.ReplicationPad2d(1) |
|
upnorm = norm_layer(outer_nc) if norm_layer is not None else None |
|
upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None |
|
|
|
use_styles=False |
|
uprelu = nl_layer() |
|
if self.nz >0: |
|
use_styles=True |
|
|
|
if outermost: |
|
self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False, |
|
use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer) |
|
upconv = upsampleLayer( |
|
inner_nc , outer_nc, upsample=upsample, padding_type=padding_type) |
|
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)) |
|
down = downconv |
|
up = [uprelu] + upconv |
|
if upnorm is not None: |
|
up += [upnorm] |
|
up +=[uprelu2, uppad, upconv2] |
|
elif innermost: |
|
self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=True, |
|
use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer) |
|
upconv = upsampleLayer( |
|
inner_nc, outer_nc, upsample=upsample, padding_type=padding_type) |
|
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)) |
|
down = [downrelu] + downconv |
|
up = [uprelu] + upconv |
|
if norm_layer is not None: |
|
up += [norm_layer(outer_nc)] |
|
up += [uprelu2, uppad, upconv2] |
|
if upnorm2 is not None: |
|
up += [upnorm2] |
|
else: |
|
self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False, |
|
use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer) |
|
upconv = upsampleLayer( |
|
inner_nc , outer_nc, upsample=upsample, padding_type=padding_type) |
|
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)) |
|
down = [downrelu] + downconv |
|
if norm_layer is not None: |
|
down += [norm_layer(inner_nc)] |
|
up = [uprelu] + upconv |
|
|
|
if norm_layer is not None: |
|
up += [norm_layer(outer_nc)] |
|
up += [uprelu2, uppad, upconv2] |
|
if upnorm2 is not None: |
|
up += [upnorm2] |
|
|
|
if use_dropout: |
|
up += [nn.Dropout(0.5)] |
|
self.down = nn.Sequential(*down) |
|
self.submodule = submodule |
|
self.up = nn.Sequential(*up) |
|
|
|
|
|
def forward(self, x, z, noise): |
|
if self.outermost: |
|
x1 = self.down(x) |
|
x2 = self.submodule(x1, z[:,2:], noise[2:]) |
|
return self.up(x2) |
|
|
|
elif self.innermost: |
|
x1 = self.down(x) |
|
x_and_z = self.adaIn(x1, noise[0], z[:,0]) |
|
x2 = self.up(x_and_z) |
|
x2 = F.interpolate(x2, x.shape[2:]) |
|
return x2 + x |
|
|
|
else: |
|
x1 = self.down(x) |
|
x2 = self.submodule(x1, z[:,2:], noise[2:]) |
|
x_and_z = self.adaIn(x2, noise[0], z[:,0]) |
|
return self.up(x_and_z) + x |
|
|
|
|
|
class E_NLayers(nn.Module): |
|
def __init__(self, input_nc, output_nc=1, ndf=64, n_layers=4, |
|
norm_layer=None, nl_layer=None, vaeLike=False): |
|
super(E_NLayers, self).__init__() |
|
self.vaeLike = vaeLike |
|
|
|
kw, padw = 3, 1 |
|
sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, |
|
stride=2, padding=padw, padding_mode='replicate')), nl_layer()] |
|
|
|
nf_mult = 1 |
|
nf_mult_prev = 1 |
|
for n in range(1, n_layers): |
|
nf_mult_prev = nf_mult |
|
nf_mult = min(2**n, 8) |
|
sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, |
|
kernel_size=kw, stride=2, padding=padw, padding_mode='replicate'))] |
|
if norm_layer is not None: |
|
sequence += [norm_layer(ndf * nf_mult)] |
|
sequence += [nl_layer()] |
|
sequence += [nn.AdaptiveAvgPool2d(4)] |
|
self.conv = nn.Sequential(*sequence) |
|
self.fc = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))]) |
|
if vaeLike: |
|
self.fcVar = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))]) |
|
|
|
def forward(self, x): |
|
x_conv = self.conv(x) |
|
conv_flat = x_conv.view(x.size(0), -1) |
|
output = self.fc(conv_flat) |
|
if self.vaeLike: |
|
outputVar = self.fcVar(conv_flat) |
|
return output, outputVar |
|
return output |
|
|
|
class BasicBlock(nn.Module):

    def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None):

        super(BasicBlock, self).__init__()

        layers = []

        # Fall back to layer normalization and ReLU when no layers are supplied
        # (E_ResNet passes its own norm_layer and nl_layer here).
        if norm_layer is None:

            norm_layer = get_norm_layer(norm_type='layer')

        if nl_layer is None:

            nl_layer = nn.ReLU

        if norm_layer is not None:

            layers += [norm_layer(inplanes)]

        layers += [nl_layer()]

        layers += [nn.ReplicationPad2d(1),

                   nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=1,

                             padding=0, bias=True)]

        self.conv = nn.Sequential(*layers)
|
|
|
def forward(self, x): |
|
return self.conv(x) |
|
|
|
|
|
def define_SVAE(inc=96, outc=3, outplanes=64, blocks=1, netVAE='SVAE', model_name='', load_ext=True, save_dir='', |
|
init_type="normal", init_gain=0.02, gpu_ids=[]): |
|
if netVAE == 'SVAE': |
|
net = ScreenVAE(inc=inc, outc=outc, outplanes=outplanes, blocks=blocks, save_dir=save_dir, |
|
init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids) |
|
else: |
|
        raise NotImplementedError('VAE model name [%s] is not recognized' % netVAE)
|
init_net(net, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids) |
|
net.load_networks('latest') |
|
return net |
|
|
|
|
|
class ScreenVAE(nn.Module): |
|
def __init__(self,inc=1,outc=4, outplanes=64, downs=5, blocks=2,load_ext=True, save_dir='',init_type="normal", init_gain=0.02, gpu_ids=[]): |
|
super(ScreenVAE, self).__init__() |
|
self.inc = inc |
|
self.outc = outc |
|
self.save_dir = save_dir |
|
norm_layer=functools.partial(LayerNormWarpper) |
|
nl_layer=nn.LeakyReLU |
|
|
|
self.model_names=['enc','dec'] |
|
self.enc=define_C(inc+1, outc*2, 0, 24, netC='resnet_6blocks', |
|
norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming', |
|
gpu_ids=gpu_ids, upsample='bilinear') |
|
self.dec=define_G(outc, inc, 0, 48, netG='unet_128_G', |
|
norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming', |
|
gpu_ids=gpu_ids, where_add='input', upsample='bilinear', use_noise=True) |
|
|
|
for param in self.parameters(): |
|
param.requires_grad = False |
|
|
|
def load_networks(self, epoch): |
|
"""Load all the networks from the disk. |
|
|
|
Parameters: |
|
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) |
|
""" |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
load_filename = '%s_net_%s.pth' % (epoch, name) |
|
load_path = os.path.join(self.save_dir, load_filename) |
|
net = getattr(self, name) |
|
if isinstance(net, torch.nn.DataParallel): |
|
net = net.module |
|
print('loading the model from %s' % load_path) |
|
state_dict = torch.load( |
|
load_path, map_location=lambda storage, loc: storage.cuda()) |
|
if hasattr(state_dict, '_metadata'): |
|
del state_dict._metadata |
|
|
|
net.load_state_dict(state_dict) |
|
del state_dict |
|
|
|
def npad(self, im, pad=128): |
|
h,w = im.shape[-2:] |
|
        hp = h // pad * pad + pad

        wp = w // pad * pad + pad

        return F.pad(im, (0, wp - w, 0, hp - h), mode='replicate')
|
|
|
def forward(self, x, line=None, img_input=True, output_screen_only=True): |
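        # Two modes: with img_input=True the image (plus optional line art) is encoded into
        # screen-tone maps (and optionally reconstructed); otherwise x is assumed to already
        # be screen-tone maps and is decoded back to an image.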
|
if img_input: |
|
if line is None: |
|
line = torch.ones_like(x) |
|
else: |
|
line = torch.sign(line) |
|
x = torch.clamp(x + (1-line),-1,1) |
|
h,w = x.shape[-2:] |
|
input = torch.cat([x, line], 1) |
|
input = self.npad(input) |
|
inter = self.enc(input)[:,:,:h,:w] |
|
scr, logvar = torch.split(inter, (self.outc, self.outc), dim=1) |
|
if output_screen_only: |
|
return scr |
|
recons = self.dec(scr) |
|
return recons, scr, logvar |
|
else: |
|
h,w = x.shape[-2:] |
|
x = self.npad(x) |
|
recons = self.dec(x)[:,:,:h,:w] |
|
recons = (recons+1)*(line+1)/2-1 |
|
return torch.clamp(recons,-1,1) |
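
# Example (illustrative sketch): define_SVAE expects pretrained ScreenVAE weights
# ('latest_net_enc.pth' / 'latest_net_dec.pth') in `save_dir`; the path below is a placeholder.
#
#     svae = define_SVAE(inc=1, outc=4, save_dir='./checkpoints/ScreenVAE', gpu_ids=[])
#     screen_maps = svae(manga_img, line=line_img, img_input=True)   # input tensors defined elsewhere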
|
|