|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision.models as models |
|
|
|
|
|
|
|
|
|
|
|
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input, and dispatches to the loss formulation
    selected by ``gan_mode``. Call the module like any other ``nn.Module``:
    ``loss = criterion(prediction, target_is_real, is_dis=True)``.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective. Supported values are
                'vanilla', 'lsgan', 'hinge', 'wgangp' and 'nonsaturating'.
            target_real_label (float) -- label for a real image (used by lsgan / vanilla)
            target_fake_label (float) -- label for a fake image (used by lsgan / vanilla)

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.

        Raises:
            NotImplementedError -- if ``gan_mode`` is not one of the supported values.
        """
        super(GANLoss, self).__init__()
        # Registered as buffers so the labels follow the module across .to()/dtype casts.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'hinge':
            self.loss = nn.ReLU()
        elif gan_mode in ['wgangp', 'nonsaturating']:
            # These objectives are computed directly from the raw prediction.
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) -- typically the prediction from a discriminator
            target_is_real (bool) -- if the ground truth label is for real examples or fake examples

        Returns:
            A label tensor filled with the ground truth label, expanded (not copied)
            to the size of the input.
        """
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def calculate_loss(self, prediction, target_is_real, is_dis=False):
        """Calculate the loss for a single discriminator output.

        Parameters:
            prediction (tensor) -- typically the raw (pre-sigmoid) output of a discriminator
            target_is_real (bool) -- if the ground truth label is for real examples or fake examples
            is_dis (bool) -- True when computing the discriminator loss, False for the
                generator loss. Ignored by 'lsgan' and 'vanilla', which only use labels.

        Returns:
            the calculated loss (a scalar tensor).
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
            if self.gan_mode == 'lsgan':
                loss = loss * 0.5
        elif is_dis:
            # Negating real predictions lets the same expressions below serve both
            # terms: wgangp -> -D(x) / D(G(z)), nonsaturating -> softplus(-D(x)) /
            # softplus(D(G(z))), hinge -> relu(1 - D(x)) / relu(1 + D(G(z))).
            if target_is_real:
                prediction = -prediction
            if self.gan_mode == 'wgangp':
                loss = prediction.mean()
            elif self.gan_mode == 'nonsaturating':
                loss = F.softplus(prediction).mean()
            elif self.gan_mode == 'hinge':
                loss = self.loss(1 + prediction).mean()
        else:
            # Generator side: non-saturating uses softplus(-D(G(z))); wgangp and
            # hinge both maximise the critic score directly.
            if self.gan_mode == 'nonsaturating':
                loss = F.softplus(-prediction).mean()
            else:
                loss = -prediction.mean()
        return loss

    def forward(self, predictions, target_is_real, is_dis=False):
        """Calculate loss for single- or multi-scale GAN discriminators.

        Defined as ``forward`` (not an ``__call__`` override) so that
        ``nn.Module.__call__`` still runs registered forward hooks.

        Parameters:
            predictions (tensor or list of tensors) -- discriminator output(s);
                for a list, the per-scale losses are summed.
            target_is_real (bool) -- if the ground truth label is for real examples or fake examples
            is_dis (bool) -- True for the discriminator loss, False for the generator loss

        Returns:
            the (summed) loss.
        """
        if isinstance(predictions, list):
            return sum(self.calculate_loss(prediction, target_is_real, is_dis)
                       for prediction in predictions)
        return self.calculate_loss(predictions, target_is_real, is_dis)
|
|
|
|
|
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network) -- discriminator network; may return a tensor or a list of
            tensors (multi-scale), in which case the input gradients are summed.
        real_data (tensor array) -- real examples; first dimension is the batch
        fake_data (tensor array) -- generated examples from the generator
        device (str or torch.device) -- device on which the interpolation weights are sampled
        type (str) -- if we mix real and fake data or not [real | fake | mixed].
        constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2
        lambda_gp (float) -- weight for this loss

    Returns:
        (gradient_penalty, gradients) -- the weighted penalty and the input gradients
        flattened to shape (batch, -1); ``(0.0, None)`` when ``lambda_gp <= 0``.

    Raises:
        NotImplementedError -- for an unknown ``type``.
    """
    if lambda_gp <= 0.0:
        return 0.0, None

    batch_size = real_data.shape[0]
    if type == 'real':
        # clone() so requires_grad_ below never mutates the caller's tensor; clone is
        # differentiable, so gradients still reach real_data if it requires grad.
        interpolatesv = real_data.clone()
    elif type == 'fake':
        interpolatesv = fake_data.clone()
    elif type == 'mixed':
        # One mixing weight per sample, broadcast over all remaining dimensions.
        alpha = torch.rand(batch_size, 1, device=device)
        alpha = alpha.view(batch_size, *([1] * (real_data.dim() - 1)))
        interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
    else:
        raise NotImplementedError('{} not implemented'.format(type))
    interpolatesv.requires_grad_(True)

    disc_interpolates = netD(interpolatesv)
    outputs = disc_interpolates if isinstance(disc_interpolates, list) else [disc_interpolates]

    gradients = 0
    for output in outputs:
        # ones_like matches the output's shape, device and dtype; create_graph lets the
        # penalty itself be backpropagated into netD's parameters.
        gradients = gradients + torch.autograd.grad(outputs=output, inputs=interpolatesv,
                                                    grad_outputs=torch.ones_like(output),
                                                    create_graph=True, retain_graph=True)[0]

    # reshape (not view): the gradient w.r.t. the input need not be contiguous.
    gradients = gradients.reshape(batch_size, -1)
    # The tiny offset keeps the norm differentiable when a gradient row is all zeros.
    gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp
    return gradient_penalty, gradients
|
|
|
|
|
|
|
|
|
|
|
def normalize_tensor(x, eps=1e-10):
    """Scale ``x`` to unit L2 norm along dim 1.

    For NCHW feature maps (as used by LPIPSLoss) dim 1 is the channel axis.
    ``eps`` is added to the norm so all-zero vectors do not divide by zero.
    """
    squared_sum = (x ** 2).sum(dim=1, keepdim=True)
    return x / (squared_sum.sqrt() + eps)
|
|
|
|
|
def spatial_average(x, keepdim=True):
    """Average ``x`` over dimensions 2 and 3 (height and width for NCHW input)."""
    spatial_dims = (2, 3)
    return torch.mean(x, dim=spatial_dims, keepdim=keepdim)
|
|
|
|
|
class NetLinLayer(nn.Module):
    """ A single linear layer which does a 1x1 conv (optionally preceded by Dropout). """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        """
        Parameters:
            chn_in (int) -- number of input channels
            chn_out (int) -- number of output channels
            use_dropout (bool) -- insert an nn.Dropout before the convolution
        """
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout(), ] if (use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
        # Kept under ``model`` so state_dict keys stay ``model.<idx>.weight``.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the (dropout +) 1x1 convolution to ``x``."""
        return self.model(x)
|
|
|
|
|
class LPIPSLoss(nn.Module):
    """
    Learned perceptual metric
    https://github.com/richzhang/PerceptualSimilarity

    VGG16 features from five layers of both images are unit-normalised along
    dim 1; their squared difference is weighted by a learned 1x1 conv per layer,
    averaged over dims 2 and 3, and summed across layers. All parameters are
    frozen at construction.
    """

    def __init__(self, use_dropout=True, ckpt_path=None):
        """
        Parameters:
            use_dropout (bool) -- put an nn.Dropout in front of each 1x1 conv head
            ckpt_path (str) -- checkpoint with the pretrained LPIPS weights; required,
                load_from_pretrained() raises ValueError when it is None.
        """
        super(LPIPSLoss, self).__init__()
        self.path = ckpt_path
        self.net = VGG16()
        # Channel counts of the five VGG16 feature maps selected in _get_features.
        self.chns = [64, 128, 256, 512, 512]
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        # NOTE(review): with use_dropout=True the heads drop features whenever this
        # module is in train mode; callers presumably keep it in .eval() — confirm.
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self):
        """Load the LPIPS weights from ``self.path`` (mapped onto the CPU).

        Loading is non-strict, presumably because the checkpoint holds only the
        ``lin*`` heads and not the VGG backbone; a warning is emitted if any head
        weights are absent, since those heads would keep their random init.

        Raises:
            ValueError -- if no checkpoint path was given.
        """
        if self.path is None:
            raise ValueError("LPIPSLoss requires ckpt_path pointing to pretrained LPIPS weights")
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        state_dict = torch.load(self.path, map_location=torch.device("cpu"))
        incompatible = self.load_state_dict(state_dict, strict=False)
        missing_heads = [key for key in incompatible.missing_keys if key.startswith("lin")]
        if missing_heads:
            import warnings
            warnings.warn("LPIPS checkpoint {} lacks weights for {}; these heads keep their "
                          "random initialisation".format(self.path, missing_heads))
        print("loaded pretrained LPIPS loss from {}".format(self.path))

    def _get_features(self, vgg_f):
        """Pick the five LPIPS layers, in order, from the VGG16 output dict."""
        names = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
        return [vgg_f[name] for name in names]

    def forward(self, x, y):
        """Return the LPIPS distance between ``x`` and ``y``, shape (N, 1, 1, 1).

        NOTE(review): the reference LPIPS implementation shifts/scales inputs
        before VGG; no such step happens here or in VGG16 — confirm callers
        prepare the inputs accordingly.
        """
        x_vgg = self._get_features(self.net(x))
        y_vgg = self._get_features(self.net(y))
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]

        loss = 0
        for lin, x_feat, y_feat in zip(lins, x_vgg, y_vgg):
            diffs = (normalize_tensor(x_feat) - normalize_tensor(y_feat)) ** 2
            # .model is used directly so this works with a forward-less NetLinLayer too.
            loss = loss + spatial_average(lin.model(diffs))
        return loss
|
|
|
|
|
class PerceptualLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py

    Weighted sum of L1 distances between VGG16 activations of two images.
    """

    # Layers compared, in the order that ``weights`` refers to them.
    _LAYERS = ('relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3')

    def __init__(self, weights=(1.0, 1.0, 1.0, 1.0, 0.0)):
        """
        Parameters:
            weights (sequence of 5 floats) -- per-layer weights for relu1_2, relu2_2,
                relu3_3, relu4_3 and relu5_3; layers with weight <= 0 are skipped.
        """
        super(PerceptualLoss, self).__init__()
        self.add_module('vgg', VGG16())
        self.criterion = nn.L1Loss()
        # Immutable default + private copy: instances never share one mutable list.
        self.weights = list(weights)

    def forward(self, x, y):
        """Return the weighted L1 feature distance between ``x`` and ``y``.

        Defined as ``forward`` (not an ``__call__`` override) so nn.Module hooks run.
        Returns ``0.0`` (a float) when every weight is <= 0.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)

        content_loss = 0.0
        for i, layer in enumerate(self._LAYERS):
            weight = self.weights[i]
            if weight > 0:
                content_loss += weight * self.criterion(x_vgg[layer], y_vgg[layer])
        return content_loss
|
|
|
|
|
class Normalization(nn.Module):
    """Normalize images with fixed per-channel mean/std (the ImageNet statistics
    used by torchvision).

    mean/std are stored with shape (3, 1, 1), so inputs are expected channel-first
    with 3 channels, e.g. (3, H, W) or (N, 3, H, W).
    """

    def __init__(self, device):
        """
        Parameters:
            device (str or torch.device) -- initial device for the statistics
        """
        super(Normalization, self).__init__()
        mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
        std = torch.tensor([0.229, 0.224, 0.225]).to(device)
        # Non-persistent buffers follow module.to()/.half()/.double() but are kept
        # out of state_dict, so existing checkpoints keep the same keys.
        self.register_buffer('mean', mean.view(-1, 1, 1), persistent=False)
        self.register_buffer('std', std.view(-1, 1, 1), persistent=False)

    def forward(self, img):
        """Return ``(img - mean) / std``, broadcast over the channel dimension."""
        return (img - self.mean) / self.std
|
|
|
|
|
class VGG16(nn.Module):
    """Frozen, ImageNet-pretrained torchvision VGG16 feature extractor.

    forward() returns a dict holding the activation after each ReLU of
    ``vgg16().features`` under keys 'relu1_1' ... 'relu5_3'. The input is fed to
    VGG as-is; no mean/std normalisation is applied here.
    """

    # (name, start, end): slice [start, end) of torchvision's vgg16().features that
    # ends at the named ReLU. Indices: 16 = pool3, 17 = conv4_1, 18 = relu4_1.
    _SLICES = (
        ('relu1_1', 0, 2),
        ('relu1_2', 2, 4),
        ('relu2_1', 4, 7),
        ('relu2_2', 7, 9),
        ('relu3_1', 9, 12),
        ('relu3_2', 12, 14),
        ('relu3_3', 14, 16),
        # Previously (16, 18)/(18, 21): relu4_1 ended before its ReLU and was only
        # correct because relu4_2's leading in-place ReLU mutated that tensor.
        ('relu4_1', 16, 19),
        ('relu4_2', 19, 21),
        ('relu4_3', 21, 23),
        ('relu5_1', 23, 26),
        ('relu5_2', 26, 28),
        ('relu5_3', 28, 30),
    )

    def __init__(self):
        super(VGG16, self).__init__()
        try:
            vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        except AttributeError:
            # torchvision < 0.13 has no weight enums; the legacy flag loads the same weights.
            vgg = models.vgg16(pretrained=True)
        features = vgg.features

        for name, start, end in self._SLICES:
            stage = torch.nn.Sequential()
            for idx in range(start, end):
                # The original layer index is the child name, so parameter keys such
                # as 'relu4_1.17.weight' are unchanged (ReLU/pool carry no params).
                stage.add_module(str(idx), features[idx])
            setattr(self, name, stage)

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through all stages; return {stage name: its output} in order."""
        out = {}
        h = x
        for name, _, _ in self._SLICES:
            h = getattr(self, name)(h)
            out[name] = h
        return out