# TextureScraping/libs/losses.py
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.nn import init

from .normalization import get_nonspade_norm_layer
class BaseNetwork(nn.Module):
def __init__(self):
super(BaseNetwork, self).__init__()
def print_network(self):
if isinstance(self, list):
self = self[0]
num_params = 0
for param in self.parameters():
num_params += param.numel()
print('Network [%s] was created. Total number of parameters: %.1f million. '
'To see the architecture, do print(network).'
% (type(self).__name__, num_params / 1000000))
def init_weights(self, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if classname.find('BatchNorm2d') != -1:
if hasattr(m, 'weight') and m.weight is not None:
init.normal_(m.weight.data, 1.0, gain)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'xavier_uniform':
init.xavier_uniform_(m.weight.data, gain=1.0)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
elif init_type == 'none': # uses pytorch's default init method
m.reset_parameters()
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
self.apply(init_func)
# propagate to children
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights(init_type, gain)
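# A minimal usage sketch (illustrative, not part of this module): any subclass
# of BaseNetwork inherits print_network() and init_weights(), e.g.
#
#     net = NLayerDiscriminator()
#     net.init_weights(init_type='xavier', gain=0.02)
#     net.print_network()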
class NLayerDiscriminator(BaseNetwork):
def __init__(self):
super().__init__()
kw = 4
padw = int(np.ceil((kw - 1.0) / 2))
nf = 64
n_layers_D = 4
input_nc = 3
norm_layer = get_nonspade_norm_layer('spectralinstance')
sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, False)]]
for n in range(1, n_layers_D):
nf_prev = nf
nf = min(nf * 2, 512)
stride = 1 if n == n_layers_D - 1 else 2
sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,
stride=stride, padding=padw)),
nn.LeakyReLU(0.2, False)
]]
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
# We divide the layers into groups to extract intermediate layer outputs
for n in range(len(sequence)):
self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
    def forward(self, input, get_intermediate_features=True):
results = [input]
for submodel in self.children():
intermediate_output = submodel(results[-1])
results.append(intermediate_output)
if get_intermediate_features:
return results[1:]
else:
return results[-1]
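# Sketch of the discriminator's output structure (input size is illustrative;
# exact spatial sizes follow from the kernel=4, padding=2 convolutions):
#
#     D = NLayerDiscriminator()
#     feats = D(torch.randn(1, 3, 256, 256))   # list of 5 tensors
#     logits = feats[-1]                        # patch-level predictions, 1 channel
#     # feats[:-1] are the intermediate activations used for feature matching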
class VGG19(torch.nn.Module):
def __init__(self, requires_grad=False):
super().__init__()
vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
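# For reference, the five slices end at relu1_1, relu2_1, relu3_1, relu4_1 and
# relu5_1 of torchvision's VGG-19, so a call returns that feature pyramid (sketch):
#
#     vgg = VGG19()
#     relu1_1, relu2_1, relu3_1, relu4_1, relu5_1 = vgg(torch.randn(1, 3, 224, 224))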
class encoder5(nn.Module):
def __init__(self):
        super(encoder5, self).__init__()
        # VGG-19-style encoder (up to relu5_1) with reflection padding
        # 224 x 224
self.conv1 = nn.Conv2d(3,3,1,1,0)
self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
# 226 x 226
self.conv2 = nn.Conv2d(3,64,3,1,0)
self.relu2 = nn.ReLU(inplace=True)
# 224 x 224
self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
self.conv3 = nn.Conv2d(64,64,3,1,0)
self.relu3 = nn.ReLU(inplace=True)
# 224 x 224
self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
# 112 x 112
self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
self.conv4 = nn.Conv2d(64,128,3,1,0)
self.relu4 = nn.ReLU(inplace=True)
# 112 x 112
self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
self.conv5 = nn.Conv2d(128,128,3,1,0)
self.relu5 = nn.ReLU(inplace=True)
# 112 x 112
self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
# 56 x 56
self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
self.conv6 = nn.Conv2d(128,256,3,1,0)
self.relu6 = nn.ReLU(inplace=True)
# 56 x 56
self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
self.conv7 = nn.Conv2d(256,256,3,1,0)
self.relu7 = nn.ReLU(inplace=True)
# 56 x 56
self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
self.conv8 = nn.Conv2d(256,256,3,1,0)
self.relu8 = nn.ReLU(inplace=True)
# 56 x 56
self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
self.conv9 = nn.Conv2d(256,256,3,1,0)
self.relu9 = nn.ReLU(inplace=True)
# 56 x 56
self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
# 28 x 28
self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
self.conv10 = nn.Conv2d(256,512,3,1,0)
self.relu10 = nn.ReLU(inplace=True)
self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
self.conv11 = nn.Conv2d(512,512,3,1,0)
self.relu11 = nn.ReLU(inplace=True)
self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
self.conv12 = nn.Conv2d(512,512,3,1,0)
self.relu12 = nn.ReLU(inplace=True)
self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
self.conv13 = nn.Conv2d(512,512,3,1,0)
self.relu13 = nn.ReLU(inplace=True)
self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2)
self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
self.conv14 = nn.Conv2d(512,512,3,1,0)
self.relu14 = nn.ReLU(inplace=True)
def forward(self,x):
output = []
out = self.conv1(x)
out = self.reflecPad1(out)
out = self.conv2(out)
out = self.relu2(out)
output.append(out)
out = self.reflecPad3(out)
out = self.conv3(out)
out = self.relu3(out)
out = self.maxPool(out)
out = self.reflecPad4(out)
out = self.conv4(out)
out = self.relu4(out)
output.append(out)
out = self.reflecPad5(out)
out = self.conv5(out)
out = self.relu5(out)
out = self.maxPool2(out)
out = self.reflecPad6(out)
out = self.conv6(out)
out = self.relu6(out)
output.append(out)
out = self.reflecPad7(out)
out = self.conv7(out)
out = self.relu7(out)
out = self.reflecPad8(out)
out = self.conv8(out)
out = self.relu8(out)
out = self.reflecPad9(out)
out = self.conv9(out)
out = self.relu9(out)
out = self.maxPool3(out)
out = self.reflecPad10(out)
out = self.conv10(out)
out = self.relu10(out)
output.append(out)
out = self.reflecPad11(out)
out = self.conv11(out)
out = self.relu11(out)
out = self.reflecPad12(out)
out = self.conv12(out)
out = self.relu12(out)
out = self.reflecPad13(out)
out = self.conv13(out)
out = self.relu13(out)
out = self.maxPool4(out)
out = self.reflecPad14(out)
out = self.conv14(out)
out = self.relu14(out)
output.append(out)
return output
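# encoder5 mirrors the relu1_1..relu5_1 feature pyramid of VGG19 above, but with
# reflection padding; a shape sketch for a 224x224 input:
#
#     enc = encoder5()
#     f1, f2, f3, f4, f5 = enc(torch.randn(1, 3, 224, 224))
#     # f1: 64x224x224, f2: 128x112x112, f3: 256x56x56, f4: 512x28x28, f5: 512x14x14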
class VGGLoss(nn.Module):
def __init__(self, model_path):
super(VGGLoss, self).__init__()
self.vgg = encoder5().cuda()
self.vgg.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))
self.criterion = nn.MSELoss()
self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
def forward(self, x, y):
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        # only the first four feature levels are compared here, so the last
        # weight (1.0, for the deepest features) goes unused
        for i in range(4):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss
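# Usage sketch (assumes model_path contains the 'vgg_r51.pth' checkpoint and a
# CUDA device is available, since the encoder is moved to the GPU):
#
#     perceptual = VGGLoss('/path/to/models')   # hypothetical path
#     loss = perceptual(fake.cuda(), real.cuda())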
class GANLoss(nn.Module):
    def __init__(self, gan_mode='hinge', target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.cuda.FloatTensor):
super(GANLoss, self).__init__()
self.real_label = target_real_label
self.fake_label = target_fake_label
self.real_label_tensor = None
self.fake_label_tensor = None
self.zero_tensor = None
self.Tensor = tensor
self.gan_mode = gan_mode
        if gan_mode not in ('ls', 'original', 'w', 'hinge'):
            raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
def get_target_tensor(self, input, target_is_real):
if target_is_real:
if self.real_label_tensor is None:
self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
self.real_label_tensor.requires_grad_(False)
return self.real_label_tensor.expand_as(input)
else:
if self.fake_label_tensor is None:
self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
self.fake_label_tensor.requires_grad_(False)
return self.fake_label_tensor.expand_as(input)
def get_zero_tensor(self, input):
if self.zero_tensor is None:
self.zero_tensor = self.Tensor(1).fill_(0)
self.zero_tensor.requires_grad_(False)
return self.zero_tensor.expand_as(input)
def loss(self, input, target_is_real, for_discriminator=True):
if self.gan_mode == 'original': # cross entropy loss
target_tensor = self.get_target_tensor(input, target_is_real)
loss = F.binary_cross_entropy_with_logits(input, target_tensor)
return loss
elif self.gan_mode == 'ls':
target_tensor = self.get_target_tensor(input, target_is_real)
return F.mse_loss(input, target_tensor)
elif self.gan_mode == 'hinge':
if for_discriminator:
if target_is_real:
minval = torch.min(input - 1, self.get_zero_tensor(input))
loss = -torch.mean(minval)
else:
minval = torch.min(-input - 1, self.get_zero_tensor(input))
loss = -torch.mean(minval)
else:
assert target_is_real, "The generator's hinge loss must be aiming for real"
loss = -torch.mean(input)
return loss
else:
# wgan
if target_is_real:
return -input.mean()
else:
return input.mean()
def __call__(self, input, target_is_real, for_discriminator=True):
        # computing the loss is a bit complicated because |input| may not be
        # a tensor but a list of tensors, in the case of a multiscale discriminator
if isinstance(input, list):
loss = 0
for pred_i in input:
if isinstance(pred_i, list):
pred_i = pred_i[-1]
loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
loss += new_loss
return loss / len(input)
else:
return self.loss(input, target_is_real, for_discriminator)
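# Usage sketch for the hinge variant: the discriminator is trained on both real
# and fake logits, while the generator maximizes the fake logits directly:
#
#     gan_loss = GANLoss('hinge')
#     d_loss = gan_loss(pred_fake, False, for_discriminator=True) + \
#              gan_loss(pred_real, True, for_discriminator=True)
#     g_loss = gan_loss(pred_fake, True, for_discriminator=False)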
class SPADE_LOSS(nn.Module):
    def __init__(self, model_path, lambda_feat=1):
super(SPADE_LOSS, self).__init__()
self.criterionVGG = VGGLoss(model_path)
self.criterionGAN = GANLoss('hinge')
self.criterionL1 = nn.L1Loss()
self.discriminator = NLayerDiscriminator()
self.lambda_feat = lambda_feat
    def forward(self, x, y, for_discriminator=False):
        pred_real = self.discriminator(y)
        if not for_discriminator:
            pred_fake = self.discriminator(x)
            vgg_loss = self.criterionVGG(x, y)
            gan_loss = self.criterionGAN(pred_fake, True, for_discriminator=False)
            # feature matching loss
            # the last output is the final prediction, so we exclude it
            num_intermediate_outputs = len(pred_fake) - 1
            gan_feat_loss = 0
            for j in range(num_intermediate_outputs):  # for each layer output
                unweighted_loss = self.criterionL1(pred_fake[j], pred_real[j].detach())
                gan_feat_loss += unweighted_loss * self.lambda_feat
            l1_loss = self.criterionL1(x, y)
            return vgg_loss, gan_loss, gan_feat_loss, l1_loss
        else:
            pred_fake = self.discriminator(x.detach())
            gan_loss = self.criterionGAN(pred_fake, False, for_discriminator=True)
            gan_loss += self.criterionGAN(pred_real, True, for_discriminator=True)
            return gan_loss
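# Sketch of the intended two-step optimization (names are illustrative):
#
#     spade_loss = SPADE_LOSS('/path/to/models')   # hypothetical path
#     # generator step: combine the four returned terms
#     vgg_l, gan_l, feat_l, l1_l = spade_loss(fake, real, for_discriminator=False)
#     (vgg_l + gan_l + feat_l + l1_l).backward()
#     # discriminator step: fake is detached inside the loss
#     d_l = spade_loss(fake, real, for_discriminator=True)
#     d_l.backward()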
class ContrastiveLoss(nn.Module):
"""
Contrastive loss
Takes embeddings of two samples and a target label == 1 if samples are from the same class and label == 0 otherwise
"""
def __init__(self, margin):
super(ContrastiveLoss, self).__init__()
self.margin = margin
self.eps = 1e-9
    def forward(self, out1, out2, target, size_average=True, norm=True):
        if norm:
            # L2-normalize each embedding along the feature dimension
            output1 = out1 / out1.pow(2).sum(1, keepdim=True).sqrt()
            output2 = out2 / out2.pow(2).sum(1, keepdim=True).sqrt()
        else:
            output1, output2 = out1, out2
        distances = (output2 - output1).pow(2).sum(1)  # squared distances
        losses = 0.5 * (target.float() * distances +
                        (1 - target).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2))
        return losses.mean() if size_average else losses.sum()
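# Usage sketch: target is 1 for pairs from the same class and 0 otherwise
# (the margin value is illustrative):
#
#     contrastive = ContrastiveLoss(margin=1.0)
#     loss = contrastive(emb_a, emb_b, target)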