|
import utility
|
|
from types import SimpleNamespace
|
|
|
|
from model import common
|
|
from loss import discriminator
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
|
|
class Adversarial(nn.Module):
    """Adversarial loss that also trains its own discriminator.

    Every ``forward`` call first performs ``gan_k`` optimisation steps on the
    discriminator ``self.dis`` (using the detached generator output), then
    returns the generator-side adversarial loss, still connected to ``fake``,
    for the caller to back-propagate.

    Supported ``gan_type`` values:
        * ``'GAN'``  -- logits + binary cross-entropy.
        * any name containing ``'WGAN'`` -- Wasserstein critic loss.  Exactly
          ``'WGAN'`` clips discriminator weights to [-1, 1] after each step;
          a name also containing ``'GP'`` (e.g. ``'WGAN_GP'``) adds a
          gradient penalty instead.
        * ``'RGAN'`` -- relativistic GAN.
    """

    def __init__(self, args, gan_type):
        """Build the discriminator and its optimizer.

        Args:
            args: experiment options.  ``args.gan_k`` is read here; ``args``
                is also handed to ``discriminator.Discriminator`` and, unless
                ``gan_type == 'WGAN_GP'``, to ``utility.make_optimizer``.
            gan_type: loss variant name (see the class docstring).
        """
        super().__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.dis = discriminator.Discriminator(args)
        if gan_type == 'WGAN_GP':
            # WGAN-GP uses its own fixed Adam settings, but still inherits the
            # weight-decay / LR-schedule options from the experiment args.
            optim_dict = {
                'optimizer': 'ADAM',
                'betas': (0, 0.9),
                'epsilon': 1e-8,
                'lr': 1e-5,
                'weight_decay': args.weight_decay,
                'decay': args.decay,
                'gamma': args.gamma
            }
            optim_args = SimpleNamespace(**optim_dict)
        else:
            optim_args = args

        self.optimizer = utility.make_optimizer(optim_args, self.dis)

    def forward(self, fake, real):
        """Update the discriminator, then return the generator loss.

        Args:
            fake: generator output; gradients reach it through the returned
                loss (the discriminator updates only see a detached copy).
            real: reference batch.  Assumed to have the same shape as
                ``fake`` (required by the WGAN-GP interpolation) -- confirm
                against callers.

        Returns:
            Scalar tensor: the adversarial loss for the generator.

        Side effects:
            Runs ``self.gan_k`` optimizer steps on ``self.dis`` and stores the
            mean discriminator loss (a Python float) in ``self.loss``.

        Raises:
            ValueError: if ``self.gan_k < 1`` or ``self.gan_type`` is not a
                supported variant.
        """
        if self.gan_k < 1:
            raise ValueError('gan_k must be >= 1, got {}'.format(self.gan_k))

        self.loss = 0
        fake_detach = fake.detach()
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            d_fake = self.dis(fake_detach)
            d_real = self.dis(real)
            loss_d = self._discriminator_loss(d_real, d_fake, fake_detach, real)

            self.loss += loss_d.item()
            loss_d.backward()
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                # Plain WGAN: clip discriminator weights after every step.
                with torch.no_grad():
                    for p in self.dis.parameters():
                        p.clamp_(-1, 1)

        self.loss /= self.gan_k

        # d_real from the last discriminator step is reused by RGAN.
        return self._generator_loss(fake, d_real)

    def _discriminator_loss(self, d_real, d_fake, fake_detach, real):
        """Return the discriminator loss for one update step."""
        if self.gan_type == 'GAN':
            return self.bce(d_real, d_fake)
        if 'WGAN' in self.gan_type:
            loss_d = (d_fake - d_real).mean()
            if 'GP' in self.gan_type:
                loss_d = loss_d + self._gradient_penalty(fake_detach, real)
            return loss_d
        if self.gan_type == 'RGAN':
            better_real = d_real - d_fake.mean(dim=0, keepdim=True)
            better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
            return self.bce(better_real, better_fake)
        raise ValueError('Unsupported gan_type: {}'.format(self.gan_type))

    def _gradient_penalty(self, fake, real, weight=10):
        """Return ``weight * mean((||grad D(hat)||_2 - 1) ** 2)``.

        ``hat`` mixes ``fake`` and ``real`` with one random coefficient per
        sample (dimension 0), broadcast over all remaining dimensions.
        """
        # Shape (batch, 1, ..., 1): one epsilon per sample, any input rank.
        eps_shape = (fake.size(0),) + (1,) * (fake.dim() - 1)
        epsilon = torch.rand(eps_shape, dtype=fake.dtype, device=fake.device)
        # Detach both inputs so `hat` is a fresh leaf we can differentiate to.
        hat = fake.detach() * (1 - epsilon) + real.detach() * epsilon
        hat.requires_grad_(True)
        d_hat = self.dis(hat)
        # create_graph=True so the penalty itself can be back-propagated into
        # the discriminator parameters.
        gradients = torch.autograd.grad(
            outputs=d_hat.sum(), inputs=hat,
            retain_graph=True, create_graph=True,
        )[0]
        gradients = gradients.reshape(gradients.size(0), -1)
        gradient_norm = gradients.norm(2, dim=1)
        return weight * gradient_norm.sub(1).pow(2).mean()

    def _generator_loss(self, fake, d_real):
        """Return the generator-side adversarial loss for ``fake``.

        The discriminator parameters also receive gradients from this loss;
        the next ``forward`` call clears them with ``zero_grad`` before it
        updates the discriminator.
        """
        d_fake_bp = self.dis(fake)
        if self.gan_type == 'GAN':
            label_real = torch.ones_like(d_fake_bp)
            return F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
        if 'WGAN' in self.gan_type:
            return -d_fake_bp.mean()
        if self.gan_type == 'RGAN':
            # `real` does not depend on the generator, and d_real's graph
            # refers to discriminator weights that optimizer.step() has since
            # modified in place -- back-propagating through it would fail.
            d_real = d_real.detach()
            better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
            return self.bce(better_fake, better_real)
        raise ValueError('Unsupported gan_type: {}'.format(self.gan_type))

    def state_dict(self, *args, **kwargs):
        """Return the discriminator state merged with the optimizer state.

        The result is one flat dict: the entries of ``self.dis.state_dict``
        plus the entries of ``self.optimizer.state_dict()``.  A key present
        in both raises ``TypeError`` (duplicate keyword to ``dict``).
        """
        # NOTE(review): when a parent module's state_dict() calls this with
        # destination/prefix arguments, the discriminator entries are written
        # under this module's prefix (no 'dis.' segment) and the merged dict
        # returned here is likely discarded by the caller -- confirm that the
        # checkpoint save/load code relies on this before changing it.
        state_discriminator = self.dis.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()

        return dict(**state_discriminator, **state_optimizer)

    def bce(self, real, fake):
        """Binary cross-entropy on logits: ``real`` toward 1, ``fake`` toward 0.

        Returns the sum of the two mean-reduced BCE terms.
        """
        label_real = torch.ones_like(real)
        label_fake = torch.zeros_like(fake)
        bce_real = F.binary_cross_entropy_with_logits(real, label_real)
        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
        bce_loss = bce_real + bce_fake
        return bce_loss
|
|
|
|
|
|
|
|
|
|
|
|
|