# EnglishToucan / Architectures / Vocoder / AdversarialLoss.py
# Author: Flux9665 -- "initial commit" (revision 6faeba1, 4.66 kB)
# -*- coding: utf-8 -*-
# Copyright 2021 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
import torch
import torch.nn.functional as F
def discriminator_adv_loss(disc_real_outputs, disc_generated_outputs):
    """Compute the two-headed discriminator loss, averaged over discriminators.

    Each element of both inputs is unpacked into a pair of tensors
    ``(fun, dir)``. For one discriminator the loss is

        [mean(softplus(1 - real_fun) ** 2) + mean(softplus(1 - real_dir) ** 2)]
      + [mean(softplus(fake_fun) ** 2)     - mean(softplus(1 - fake_dir) ** 2)]

    The per-discriminator losses are summed and divided by
    ``len(disc_generated_outputs)``.

    NOTE(review): the (fun, dir) split and the negative fake-dir term appear to
    follow the SAN / BigVSAN formulation; the negative term is unbounded below
    on its own -- confirm this is intended before changing it.

    Args:
        disc_real_outputs: Iterable of ``(fun, dir)`` tensor pairs computed on
            real data (presumably -- confirm against the caller).
        disc_generated_outputs: Sized iterable of ``(fun, dir)`` tensor pairs
            computed on generated data.

    Returns:
        Tensor: the averaged discriminator loss.

    Raises:
        ZeroDivisionError: if ``disc_generated_outputs`` is empty.
    """

    def _pair_loss(real_pair, fake_pair):
        # Loss contributed by a single discriminator.
        real_fun, real_dir = real_pair
        fake_fun, fake_dir = fake_pair
        real_term = (torch.mean(F.softplus(1 - real_fun) ** 2)
                     + torch.mean(F.softplus(1 - real_dir) ** 2))
        # ``-a ** 2`` parses as ``-(a ** 2)``; parenthesised here for clarity.
        fake_term = (torch.mean(F.softplus(fake_fun) ** 2)
                     + torch.mean(-(F.softplus(1 - fake_dir) ** 2)))
        return real_term + fake_term

    total = sum(_pair_loss(real_pair, fake_pair)
                for real_pair, fake_pair in zip(disc_real_outputs, disc_generated_outputs))
    return total / len(disc_generated_outputs)
def generator_adv_loss(disc_outputs):
    """Compute the generator adversarial loss, averaged over discriminators.

    For every discriminator output ``out`` the term
    ``mean(softplus(1 - out) ** 2)`` is computed. The terms are summed and
    divided by ``len(disc_outputs)``.

    Args:
        disc_outputs: Sized iterable of discriminator output tensors computed
            on generated data.

    Returns:
        Tensor: the averaged generator loss.

    Raises:
        ZeroDivisionError: if ``disc_outputs`` is empty.
    """
    per_discriminator = [torch.mean(F.softplus(1 - out) ** 2) for out in disc_outputs]
    return sum(per_discriminator) / len(disc_outputs)
class GeneratorAdversarialLoss(torch.nn.Module):
    """Adversarial loss for the generator ("mse" least-squares or "hinge").

    Args:
        average_by_discriminators (bool): When a list/tuple of discriminator
            outputs is given, divide the summed loss by the number of outputs.
        loss_type (str): Either ``"mse"`` or ``"hinge"``.

    Raises:
        ValueError: If ``loss_type`` is not supported.
    """

    def __init__(self,
                 average_by_discriminators=True,
                 loss_type="mse"):
        """Initialize GeneratorAdversarialLoss module."""
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if loss_type not in ("mse", "hinge"):
            raise ValueError(f"{loss_type} is not supported.")
        if loss_type == "mse":
            self.criterion = self._mse_loss
        else:
            self.criterion = self._hinge_loss

    def forward(self, outputs):
        """Calculate generator adversarial loss.

        Args:
            outputs (Tensor or list): A discriminator output, or a list/tuple of
                discriminator outputs. If an element of the list is itself a
                list/tuple, only its last entry is used (e.g. when feature maps
                precede the final output).

        Returns:
            Tensor: Generator adversarial loss value. For an empty list with
            averaging disabled, the float ``0.0`` is returned.

        Raises:
            ValueError: If ``outputs`` is an empty list/tuple and averaging is
                enabled (there is nothing to average over).
        """
        if not isinstance(outputs, (tuple, list)):
            return self.criterion(outputs)

        adv_loss = 0.0
        for outputs_ in outputs:
            if isinstance(outputs_, (tuple, list)):
                # Case including feature maps: the final output comes last.
                outputs_ = outputs_[-1]
            adv_loss = adv_loss + self.criterion(outputs_)

        if self.average_by_discriminators:
            # The original divided by a loop index that is unbound for empty input.
            if not outputs:
                raise ValueError("Cannot average the adversarial loss over zero discriminator outputs.")
            adv_loss = adv_loss / len(outputs)
        return adv_loss

    def _mse_loss(self, x):
        """Least-squares loss pushing ``x`` towards 1."""
        return F.mse_loss(x, x.new_ones(x.size()))

    def _hinge_loss(self, x):
        """Hinge-GAN generator loss: the negated mean of ``x``."""
        return -x.mean()
class DiscriminatorAdversarialLoss(torch.nn.Module):
    """Adversarial loss for the discriminator ("mse" least-squares or "hinge").

    Args:
        average_by_discriminators (bool): When lists/tuples of discriminator
            outputs are given, divide the summed real and fake losses by the
            number of (generated, real) output pairs.
        loss_type (str): Either ``"mse"`` or ``"hinge"``.

    Raises:
        ValueError: If ``loss_type`` is not supported.
    """

    def __init__(self,
                 average_by_discriminators=True,
                 loss_type="mse"):
        """Initialize DiscriminatorAdversarialLoss module."""
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if loss_type not in ("mse", "hinge"):
            raise ValueError(f"{loss_type} is not supported.")
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss
        else:
            self.fake_criterion = self._hinge_fake_loss
            self.real_criterion = self._hinge_real_loss

    def forward(self, outputs_hat, outputs):
        """Calculate discriminator adversarial loss.

        Args:
            outputs_hat (Tensor or list): Discriminator output(s) calculated from
                generator outputs.
            outputs (Tensor or list): Discriminator output(s) calculated from
                groundtruth.

            When lists/tuples are given, elements are paired with ``zip``
            (extra elements of the longer list are ignored). An element that is
            itself a list/tuple contributes only its last entry (e.g. when
            feature maps precede the final output); each side is checked
            independently.

        Returns:
            Tensor: The sum of the discriminator real loss and fake loss.

        Raises:
            ValueError: If averaging is enabled and no (generated, real) pairs
                are given.
        """
        if not isinstance(outputs, (tuple, list)):
            return self.real_criterion(outputs) + self.fake_criterion(outputs_hat)

        real_loss = 0.0
        fake_loss = 0.0
        num_pairs = 0
        for outputs_hat_, outputs_ in zip(outputs_hat, outputs):
            # Case including feature maps: the final output comes last.
            if isinstance(outputs_hat_, (tuple, list)):
                outputs_hat_ = outputs_hat_[-1]
            if isinstance(outputs_, (tuple, list)):
                outputs_ = outputs_[-1]
            real_loss = real_loss + self.real_criterion(outputs_)
            fake_loss = fake_loss + self.fake_criterion(outputs_hat_)
            num_pairs += 1

        if self.average_by_discriminators:
            # The original divided by a loop index that is unbound for empty input.
            if num_pairs == 0:
                raise ValueError("Cannot average the adversarial loss over zero discriminator outputs.")
            fake_loss = fake_loss / num_pairs
            real_loss = real_loss / num_pairs
        return real_loss + fake_loss

    def _mse_real_loss(self, x):
        """Least-squares loss pushing real-data outputs towards 1."""
        return F.mse_loss(x, x.new_ones(x.size()))

    def _mse_fake_loss(self, x):
        """Least-squares loss pushing generated-data outputs towards 0."""
        return F.mse_loss(x, x.new_zeros(x.size()))

    def _hinge_real_loss(self, x):
        """Hinge loss for real-data outputs: mean of max(0, 1 - x)."""
        return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))

    def _hinge_fake_loss(self, x):
        """Hinge loss for generated-data outputs: mean of max(0, 1 + x)."""
        return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))