Anonymous-123's picture
Add application file
ec0fdfd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
####################################################################################################
# adversarial loss for different gan mode
####################################################################################################
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective. Supported values are
                              'vanilla', 'lsgan', 'hinge', 'wgangp' and 'nonsaturating'.
            target_real_label (float) -- label for a real image (used by 'vanilla' / 'lsgan')
            target_fake_label (float) -- label for a fake image (used by 'vanilla' / 'lsgan')

        Raises:
            NotImplementedError -- if gan_mode is not one of the supported values.

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        # Labels are registered as floating-point buffers so they follow the module
        # across devices; casting guards against integer labels (e.g. 1 / 0)
        # producing integer target tensors for the floating-point criteria.
        self.register_buffer('real_label', torch.tensor(float(target_real_label)))
        self.register_buffer('fake_label', torch.tensor(float(target_fake_label)))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'hinge':
            # Used as relu(1 +/- prediction) for the discriminator hinge terms.
            self.loss = nn.ReLU()
        elif gan_mode in ['wgangp', 'nonsaturating']:
            # These objectives are computed directly from the raw prediction.
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) -- typically the prediction from a discriminator
            target_is_real (bool) -- if the ground truth label is for real examples or fake examples

        Returns:
            A label tensor filled with ground truth label, and with the size of the input
            (an expanded view of the label buffer, no copy).
        """
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def calculate_loss(self, prediction, target_is_real, is_dis=False):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) -- typically the prediction output from a discriminator
            target_is_real (bool) -- if the ground truth label is for real examples or fake examples
            is_dis (bool) -- True for the discriminator objective, False for the generator
                             objective; only relevant for 'hinge', 'wgangp' and 'nonsaturating'.

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
            if self.gan_mode == 'lsgan':
                loss = loss * 0.5
        elif is_dis:
            # Negate real predictions so each branch below penalizes the score
            # as if it came from a fake sample.
            if target_is_real:
                prediction = -prediction
            if self.gan_mode == 'wgangp':
                loss = prediction.mean()
            elif self.gan_mode == 'nonsaturating':
                loss = F.softplus(prediction).mean()
            elif self.gan_mode == 'hinge':
                loss = self.loss(1 + prediction).mean()
        elif self.gan_mode == 'nonsaturating':
            loss = F.softplus(-prediction).mean()
        else:
            # Generator side of 'wgangp' / 'hinge': push the critic score up.
            # target_is_real is not used on this path.
            loss = -prediction.mean()
        return loss

    def forward(self, predictions, target_is_real, is_dis=False):
        """Calculate loss for multi-scales gan.

        Implemented as ``forward`` (not an ``__call__`` override) so that
        nn.Module hooks run when the module is called.

        Parameters:
            predictions (tensor or list of tensors) -- discriminator output(s)
            target_is_real (bool) -- if the ground truth label is for real examples or fake examples
            is_dis (bool) -- True for the discriminator objective, False for the generator

        Returns:
            the loss; for a list of predictions, the sum of the per-scale losses.
        """
        if isinstance(predictions, list):
            return sum(self.calculate_loss(p, target_is_real, is_dis) for p in predictions)
        return self.calculate_loss(predictions, target_is_real, is_dis)
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)         -- discriminator network; may return a tensor or a list of tensors
        real_data (tensor)     -- real examples; dimension 0 is the batch
        fake_data (tensor)     -- generated examples from the generator
        device (str or torch.device) -- device on which the interpolation weights are sampled
        type (str)             -- if we mix real and fake data or not [real | fake | mixed].
        constant (float)       -- the constant used in formula ( ||gradient||_2 - constant)^2
        lambda_gp (float)      -- weight for this loss

    Returns:
        (gradient_penalty, gradients), where gradients are the input gradients
        flattened to (batch, -1); (0.0, None) when lambda_gp is not positive.

    Raises:
        NotImplementedError -- if type is not one of real | fake | mixed.
    """
    if not lambda_gp > 0.0:
        return 0.0, None

    # either use real examples, fake examples, or a linear interpolation of two.
    if type == 'real':
        interpolatesv = real_data
    elif type == 'fake':
        interpolatesv = fake_data
    elif type == 'mixed':
        # One mixing weight per sample, broadcast over all remaining dimensions.
        alpha = torch.rand((real_data.shape[0],) + (1,) * (real_data.dim() - 1), device=device)
        interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    # Differentiate w.r.t. a fresh leaf instead of calling requires_grad_ on the
    # caller's tensor: that would permanently flip real_data / fake_data and let
    # the penalty's backward pass accumulate .grad into them. Inputs that already
    # require grad are used as-is, so gradient flow through them is unchanged.
    if not interpolatesv.requires_grad:
        interpolatesv = interpolatesv.detach().requires_grad_(True)

    disc_interpolates = netD(interpolatesv)
    outputs = disc_interpolates if isinstance(disc_interpolates, list) else [disc_interpolates]
    # For multi-scale discriminators, sum the input-gradients of every output.
    gradients = 0
    for output in outputs:
        gradients = gradients + torch.autograd.grad(
            outputs=output, inputs=interpolatesv,
            grad_outputs=torch.ones_like(output),  # same dtype/device as the output
            create_graph=True, retain_graph=True)[0]
    gradients = gradients.reshape(real_data.size(0), -1)  # flat the data
    gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp  # added eps
    return gradient_penalty, gradients
####################################################################################################
# trained LPIPS loss
####################################################################################################
def normalize_tensor(x, eps=1e-10):
    """Scale every vector along dim 1 of ``x`` to unit L2 length.

    Parameters:
        x (tensor) -- input; normalization is over dim 1 (the channel axis for
                      (B, C, H, W) feature maps)
        eps (float) -- added to the norm to avoid division by zero

    Returns:
        x / (||x||_2 over dim 1 + eps), same shape as x.
    """
    # Tensor.norm gives the same value as sqrt(sum(x**2)), but its backward
    # defines the gradient at a zero vector as 0; the sqrt form yields 0 * inf =
    # NaN there, which would poison gradients wherever a feature vector is all zero.
    norm_factor = x.norm(p=2, dim=1, keepdim=True)
    return x / (norm_factor + eps)
def spatial_average(x, keepdim=True):
    """Return the mean of ``x`` over dims 2 and 3 (the spatial axes of a 4-D map).

    With ``keepdim=True`` the reduced dims are kept as size-1 axes.
    """
    spatial_dims = (2, 3)
    return torch.mean(x, dim=spatial_dims, keepdim=keepdim)
class NetLinLayer(nn.Module):
    """ A single linear layer which does a 1x1 conv (optionally preceded by dropout) """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        """
        Parameters:
            chn_in (int) -- number of input channels
            chn_out (int) -- number of output channels
            use_dropout (bool) -- prepend an nn.Dropout() before the 1x1 conv
        """
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout()] if use_dropout else []
        layers.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        # Kept as a Sequential named `model` so parameter keys remain
        # `model.<index>.weight` (state_dict compatibility).
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the (dropout +) 1x1 convolution to ``x``."""
        return self.model(x)
class LPIPSLoss(nn.Module):
    """
    Learned perceptual metric
    https://github.com/richzhang/PerceptualSimilarity

    Both inputs go through VGG16; at five activations the features are
    unit-normalized along dim 1, their squared difference is weighted by a
    1x1 conv head, spatially averaged, and the five results are summed.
    """

    def __init__(self, use_dropout=True, ckpt_path=None):
        """
        Parameters:
            use_dropout (bool) -- put an nn.Dropout in front of each 1x1 head
            ckpt_path (str) -- checkpoint with the pretrained LPIPS weights (required)

        Raises:
            ValueError -- if ckpt_path is None.
        """
        super(LPIPSLoss, self).__init__()
        self.path = ckpt_path
        self.net = VGG16()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channels per stage
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        # NOTE(review): with use_dropout=True the heads apply dropout whenever this
        # module is in train mode; callers presumably want .eval() here — confirm.
        self.load_from_pretrained()
        # The metric is frozen: neither the backbone nor the heads are trained.
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self):
        """Load weights from ``self.path`` into this module (non-strict).

        Raises:
            ValueError -- if no checkpoint path was configured.
        """
        if self.path is None:
            raise ValueError("LPIPSLoss requires ckpt_path pointing to the pretrained LPIPS weights")
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(self.path, map_location=torch.device("cpu"))
        # strict=False: parameters missing from the checkpoint keep their current
        # values and unexpected checkpoint keys are ignored.
        self.load_state_dict(state_dict, strict=False)
        print("loaded pretrained LPIPS loss from {}".format(self.path))

    def _get_features(self, vgg_f):
        """Pick the five activations used by LPIPS from the VGG16 output dict."""
        names = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
        return [vgg_f[name] for name in names]

    def forward(self, x, y):
        """Return the LPIPS distance between ``x`` and ``y``.

        The result keeps singleton spatial dims (spatial_average with keepdim=True).
        """
        x_vgg, y_vgg = self._get_features(self.net(x)), self._get_features(self.net(y))
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        loss = 0
        for lin, x_feat, y_feat in zip(lins, x_vgg, y_vgg):
            diffs = (normalize_tensor(x_feat) - normalize_tensor(y_feat)) ** 2
            loss += spatial_average(lin.model(diffs))
        return loss
class PerceptualLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py

    Weighted sum of L1 distances between the VGG16 activations of the two inputs
    at relu1_2, relu2_2, relu3_3, relu4_3 and relu5_3. Activations whose weight
    is not positive are skipped entirely.
    """

    # Compared activations, in the same order as ``weights``.
    _layers = ('relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3')

    def __init__(self, weights=(1.0, 1.0, 1.0, 1.0, 0.0)):
        """
        Parameters:
            weights (sequence of 5 floats) -- per-activation weights, see ``_layers``
        """
        super(PerceptualLoss, self).__init__()
        self.add_module('vgg', VGG16())
        self.criterion = nn.L1Loss()
        # Copy into a per-instance list: a shared mutable default would let one
        # instance's edits leak into every later instance.
        self.weights = list(weights)

    def forward(self, x, y):
        """Return the weighted perceptual (content) loss between ``x`` and ``y``.

        Implemented as ``forward`` so nn.Module hooks run when the module is called.
        """
        # Compute features
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        content_loss = 0.0
        for idx, layer in enumerate(self._layers):
            weight = self.weights[idx]
            if weight > 0:
                content_loss += weight * self.criterion(x_vgg[layer], y_vgg[layer])
        return content_loss
class Normalization(nn.Module):
    """Normalize images with fixed per-channel mean and std (the ImageNet statistics)."""

    def __init__(self, device):
        """
        Parameters:
            device -- device on which the mean/std tensors are initially created
        """
        super(Normalization, self).__init__()
        # .view the mean and std to make them [C x 1 x 1] so that they can
        # directly work with image Tensor of shape [B x C x H x W].
        # B is batch size. C is number of channels. H is height and W is width.
        mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
        std = torch.tensor([0.229, 0.224, 0.225]).to(device)
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)

    def forward(self, img):
        """Return (img - mean) / std, broadcast over the channel axis."""
        # mean/std are plain attributes rather than buffers, so Module.to()/.cuda()
        # does not move them; align them with the input's device on every call.
        mean = self.mean.to(img.device)
        std = self.std.to(img.device)
        return (img - mean) / std
class VGG16(nn.Module):
    """Frozen VGG16 feature extractor exposing every ReLU activation.

    The convolutional trunk of torchvision's pretrained VGG16 is split into
    consecutive stages, each ending right after the ReLU it is named after.
    ``forward`` runs the stages in order and returns every stage output in a dict.
    """

    # (stage name, exclusive end index into ``vgg16().features``). Stages are
    # contiguous: each starts where the previous one ends. In torchvision's vgg16
    # trunk the convs sit at 0,2,5,7,10,12,14,17,19,21,24,26,28, each followed by
    # its ReLU, with max-pools at 4,9,16,23. relu4_1 therefore ends at 19 so the
    # stage contains its own ReLU (index 18) rather than leaving it in relu4_2.
    _stages = (
        ('relu1_1', 2), ('relu1_2', 4),
        ('relu2_1', 7), ('relu2_2', 9),
        ('relu3_1', 12), ('relu3_2', 14), ('relu3_3', 16),
        ('relu4_1', 19), ('relu4_2', 21), ('relu4_3', 23),
        ('relu5_1', 26), ('relu5_2', 28), ('relu5_3', 30),
    )

    def __init__(self):
        super(VGG16, self).__init__()
        # NOTE(review): `pretrained=` is deprecated in newer torchvision in favour
        # of `weights=`; kept for compatibility with older versions.
        features = models.vgg16(pretrained=True).features
        start = 0
        for name, end in self._stages:
            stage = torch.nn.Sequential()
            # Child names are the original feature indices, so parameter keys
            # look like `relu4_1.17.weight`.
            for idx in range(start, end):
                stage.add_module(str(idx), features[idx])
            setattr(self, name, stage)
            start = end
        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x,):
        """Run all stages on ``x``; return {stage name: activation} in network order."""
        out = {}
        h = x
        for name, _ in self._stages:
            h = getattr(self, name)(h)
            out[name] = h
        return out