# -*- coding: utf-8 -*-
# Copyright 2021 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

import torch
import torch.nn.functional as F


def discriminator_adv_loss(disc_real_outputs, disc_generated_outputs):
    """Compute the adversarial loss for discriminators whose per-scale outputs
    are (fun, dir) tensor pairs, averaged over the number of discriminators."""
    loss = 0
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        dr_fun, dr_dir = dr
        dg_fun, dg_dir = dg
        # Push real scores toward large values and generated scores toward
        # negative values.
        r_loss_fun = torch.mean(F.softplus(1 - dr_fun) ** 2)
        g_loss_fun = torch.mean(F.softplus(dg_fun) ** 2)
        r_loss_dir = torch.mean(F.softplus(1 - dr_dir) ** 2)
        # NOTE: this negated term is asymmetric with g_loss_fun above;
        # kept as in the original source.
        g_loss_dir = torch.mean(-F.softplus(1 - dg_dir) ** 2)
        r_loss = r_loss_fun + r_loss_dir
        g_loss = g_loss_fun + g_loss_dir
        loss += r_loss + g_loss
    return loss / len(disc_generated_outputs)


def generator_adv_loss(disc_outputs):
    """Compute the generator adversarial loss, averaged over discriminators."""
    loss = 0
    for dg in disc_outputs:
        # Encourage the discriminator score on generated samples to be large.
        loss += torch.mean(F.softplus(1 - dg) ** 2)
    return loss / len(disc_outputs)
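

# --- Illustrative usage sketch (not part of the original file) ---------------
# The functions above assume each discriminator returns a (fun, dir) pair of
# score tensors; the generator loss takes one plain score tensor per
# discriminator. The batch/shape choices and the two-discriminator setup below
# are assumptions made only to keep the example self-contained.
def _example_functional_losses():
    # Two hypothetical discriminators, each producing a (fun, dir) output pair.
    real_outputs = [(torch.randn(4, 1, 100), torch.randn(4, 1, 100)) for _ in range(2)]
    fake_outputs = [(torch.randn(4, 1, 100), torch.randn(4, 1, 100)) for _ in range(2)]
    d_loss = discriminator_adv_loss(real_outputs, fake_outputs)
    # The generator loss is fed one score tensor per discriminator.
    g_loss = generator_adv_loss([torch.randn(4, 1, 100) for _ in range(2)])
    return d_loss, g_loss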


class GeneratorAdversarialLoss(torch.nn.Module):
    """Generator adversarial loss module."""

    def __init__(self,
                 average_by_discriminators=True,
                 loss_type="mse"):
        """Initialize GeneratorAdversarialLoss module."""
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.criterion = self._mse_loss
        else:
            self.criterion = self._hinge_loss

    def forward(self, outputs):
        """Calculate generator adversarial loss.

        Args:
            outputs (Tensor or list): Discriminator outputs or list of
                discriminator outputs.

        Returns:
            Tensor: Generator adversarial loss value.

        """
        if isinstance(outputs, (tuple, list)):
            adv_loss = 0.0
            for i, outputs_ in enumerate(outputs):
                if isinstance(outputs_, (tuple, list)):
                    # Discriminators may return (feature maps, final output);
                    # only the final output is used here.
                    outputs_ = outputs_[-1]
                adv_loss = adv_loss + self.criterion(outputs_)
            if self.average_by_discriminators:
                adv_loss /= i + 1
        else:
            adv_loss = self.criterion(outputs)

        return adv_loss

    def _mse_loss(self, x):
        # Least-squares GAN loss: push generated scores toward 1.
        return F.mse_loss(x, x.new_ones(x.size()))

    def _hinge_loss(self, x):
        # Hinge GAN generator loss: maximize the discriminator score.
        return -x.mean()
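

# --- Illustrative usage sketch (not part of the original file) ---------------
# GeneratorAdversarialLoss accepts either a single tensor or a list of
# per-discriminator outputs; when an entry is itself a tuple/list, only its
# last element is used (e.g. when discriminators also return feature maps).
# The shapes and the three-scale setup below are assumptions for demonstration.
def _example_generator_adv_loss():
    criterion = GeneratorAdversarialLoss(loss_type="mse")
    # Multi-scale case: each entry is (feature_maps, final_score).
    multi_scale_outputs = [
        ([torch.randn(2, 8, 50)], torch.randn(2, 1, 50)) for _ in range(3)
    ]
    return criterion(multi_scale_outputs)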


class DiscriminatorAdversarialLoss(torch.nn.Module):
    """Discriminator adversarial loss module."""

    def __init__(self,
                 average_by_discriminators=True,
                 loss_type="mse"):
        """Initialize DiscriminatorAdversarialLoss module."""
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss
        else:
            self.fake_criterion = self._hinge_fake_loss
            self.real_criterion = self._hinge_real_loss

    def forward(self, outputs_hat, outputs):
        """Calculate discriminator adversarial loss.

        Args:
            outputs_hat (Tensor or list): Discriminator outputs or list of
                discriminator outputs calculated from generator outputs.
            outputs (Tensor or list): Discriminator outputs or list of
                discriminator outputs calculated from ground truth.

        Returns:
            Tensor: Sum of the discriminator real and fake loss values.

        """
        if isinstance(outputs, (tuple, list)):
            real_loss = 0.0
            fake_loss = 0.0
            for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
                if isinstance(outputs_hat_, (tuple, list)):
                    # Discriminators may return (feature maps, final output);
                    # only the final output is used here.
                    outputs_hat_ = outputs_hat_[-1]
                    outputs_ = outputs_[-1]
                real_loss = real_loss + self.real_criterion(outputs_)
                fake_loss = fake_loss + self.fake_criterion(outputs_hat_)
            if self.average_by_discriminators:
                fake_loss /= i + 1
                real_loss /= i + 1
        else:
            real_loss = self.real_criterion(outputs)
            fake_loss = self.fake_criterion(outputs_hat)

        return real_loss + fake_loss

    def _mse_real_loss(self, x):
        # Least-squares loss: real scores should be close to 1.
        return F.mse_loss(x, x.new_ones(x.size()))

    def _mse_fake_loss(self, x):
        # Least-squares loss: fake scores should be close to 0.
        return F.mse_loss(x, x.new_zeros(x.size()))

    def _hinge_real_loss(self, x):
        # Hinge loss: penalize real scores below 1.
        return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))

    def _hinge_fake_loss(self, x):
        # Hinge loss: penalize fake scores above -1.
        return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
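

# --- Illustrative usage sketch (not part of the original file) ---------------
# A minimal end-to-end check of both loss modules, assuming discriminators
# that return one score tensor per scale. All shapes and the toy score tensors
# below are assumptions, not part of the original code.
if __name__ == "__main__":
    gen_criterion = GeneratorAdversarialLoss(loss_type="hinge")
    dis_criterion = DiscriminatorAdversarialLoss(loss_type="hinge")

    # Toy multi-scale scores for real and generated audio.
    real_scores = [torch.randn(2, 1, 100) for _ in range(3)]
    fake_scores = [torch.randn(2, 1, 100) for _ in range(3)]

    d_loss = dis_criterion(fake_scores, real_scores)  # real + fake terms
    g_loss = gen_criterion(fake_scores)
    print(f"discriminator loss: {d_loss.item():.4f}, generator loss: {g_loss.item():.4f}")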