antoniomae1234's picture
changes in flenema
2493d72 verified
import torch
from torch import nn
from torch.nn import functional as F
class TorchSTFT(nn.Module):
def __init__(self, n_fft, hop_length, win_length, window='hann_window'):
""" Torch based STFT operation """
super(TorchSTFT, self).__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.window = nn.Parameter(getattr(torch, window)(win_length),
requires_grad=False)
def __call__(self, x):
# B x D x T x 2
o = torch.stft(x,
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
pad_mode="reflect", # compatible with audio.py
normalized=False,
onesided=True,
return_complex=False)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
return torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
#################################
# GENERATOR LOSSES
#################################
class STFTLoss(nn.Module):
""" Single scale STFT Loss """
def __init__(self, n_fft, hop_length, win_length):
super(STFTLoss, self).__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.stft = TorchSTFT(n_fft, hop_length, win_length)
def forward(self, y_hat, y):
y_hat_M = self.stft(y_hat)
y_M = self.stft(y)
# magnitude loss
loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
# spectral convergence loss
loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
return loss_mag, loss_sc
class MultiScaleSTFTLoss(torch.nn.Module):
""" Multi scale STFT loss """
def __init__(self,
n_ffts=(1024, 2048, 512),
hop_lengths=(120, 240, 50),
win_lengths=(600, 1200, 240)):
super(MultiScaleSTFTLoss, self).__init__()
self.loss_funcs = torch.nn.ModuleList()
for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length))
def forward(self, y_hat, y):
N = len(self.loss_funcs)
loss_sc = 0
loss_mag = 0
for f in self.loss_funcs:
lm, lsc = f(y_hat, y)
loss_mag += lm
loss_sc += lsc
loss_sc /= N
loss_mag /= N
return loss_mag, loss_sc
class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
""" Multiscale STFT loss for multi band model outputs """
# pylint: disable=no-self-use
def forward(self, y_hat, y):
y_hat = y_hat.view(-1, 1, y_hat.shape[2])
y = y.view(-1, 1, y.shape[2])
return super().forward(y_hat.squeeze(1), y.squeeze(1))
class MSEGLoss(nn.Module):
""" Mean Squared Generator Loss """
# pylint: disable=no-self-use
def forward(self, score_real):
loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape))
return loss_fake
class HingeGLoss(nn.Module):
""" Hinge Discriminator Loss """
# pylint: disable=no-self-use
def forward(self, score_real):
# TODO: this might be wrong
loss_fake = torch.mean(F.relu(1. - score_real))
return loss_fake
##################################
# DISCRIMINATOR LOSSES
##################################
class MSEDLoss(nn.Module):
""" Mean Squared Discriminator Loss """
def __init__(self,):
super(MSEDLoss, self).__init__()
self.loss_func = nn.MSELoss()
# pylint: disable=no-self-use
def forward(self, score_fake, score_real):
loss_real = self.loss_func(score_real, score_real.new_ones(score_real.shape))
loss_fake = self.loss_func(score_fake, score_fake.new_zeros(score_fake.shape))
loss_d = loss_real + loss_fake
return loss_d, loss_real, loss_fake
class HingeDLoss(nn.Module):
""" Hinge Discriminator Loss """
# pylint: disable=no-self-use
def forward(self, score_fake, score_real):
loss_real = torch.mean(F.relu(1. - score_real))
loss_fake = torch.mean(F.relu(1. + score_fake))
loss_d = loss_real + loss_fake
return loss_d, loss_real, loss_fake
class MelganFeatureLoss(nn.Module):
def __init__(self,):
super(MelganFeatureLoss, self).__init__()
self.loss_func = nn.L1Loss()
# pylint: disable=no-self-use
def forward(self, fake_feats, real_feats):
loss_feats = 0
for fake_feat, real_feat in zip(fake_feats, real_feats):
loss_feats += self.loss_func(fake_feat, real_feat)
loss_feats /= len(fake_feats) + len(real_feats)
return loss_feats
#####################################
# LOSS WRAPPERS
#####################################
def _apply_G_adv_loss(scores_fake, loss_func):
""" Compute G adversarial loss function
and normalize values """
adv_loss = 0
if isinstance(scores_fake, list):
for score_fake in scores_fake:
fake_loss = loss_func(score_fake)
adv_loss += fake_loss
adv_loss /= len(scores_fake)
else:
fake_loss = loss_func(scores_fake)
adv_loss = fake_loss
return adv_loss
def _apply_D_loss(scores_fake, scores_real, loss_func):
""" Compute D loss func and normalize loss values """
loss = 0
real_loss = 0
fake_loss = 0
if isinstance(scores_fake, list):
# multi-scale loss
for score_fake, score_real in zip(scores_fake, scores_real):
total_loss, real_loss, fake_loss = loss_func(score_fake=score_fake, score_real=score_real)
loss += total_loss
real_loss += real_loss
fake_loss += fake_loss
# normalize loss values with number of scales
loss /= len(scores_fake)
real_loss /= len(scores_real)
fake_loss /= len(scores_fake)
else:
# single scale loss
total_loss, real_loss, fake_loss = loss_func(scores_fake, scores_real)
loss = total_loss
return loss, real_loss, fake_loss
##################################
# MODEL LOSSES
##################################
class GeneratorLoss(nn.Module):
def __init__(self, C):
""" Compute Generator Loss values depending on training
configuration """
super(GeneratorLoss, self).__init__()
assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
" [!] Cannot use HingeGANLoss and MSEGANLoss together."
self.use_stft_loss = C.use_stft_loss
self.use_subband_stft_loss = C.use_subband_stft_loss
self.use_mse_gan_loss = C.use_mse_gan_loss
self.use_hinge_gan_loss = C.use_hinge_gan_loss
self.use_feat_match_loss = C.use_feat_match_loss
self.stft_loss_weight = C.stft_loss_weight
self.subband_stft_loss_weight = C.subband_stft_loss_weight
self.mse_gan_loss_weight = C.mse_G_loss_weight
self.hinge_gan_loss_weight = C.hinge_G_loss_weight
self.feat_match_loss_weight = C.feat_match_loss_weight
if C.use_stft_loss:
self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
if C.use_subband_stft_loss:
self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params)
if C.use_mse_gan_loss:
self.mse_loss = MSEGLoss()
if C.use_hinge_gan_loss:
self.hinge_loss = HingeGLoss()
if C.use_feat_match_loss:
self.feat_match_loss = MelganFeatureLoss()
def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None):
gen_loss = 0
adv_loss = 0
return_dict = {}
# STFT Loss
if self.use_stft_loss:
stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat.squeeze(1), y.squeeze(1))
return_dict['G_stft_loss_mg'] = stft_loss_mg
return_dict['G_stft_loss_sc'] = stft_loss_sc
gen_loss += self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)
# subband STFT Loss
if self.use_subband_stft_loss:
subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
return_dict['G_subband_stft_loss_mg'] = subband_stft_loss_mg
return_dict['G_subband_stft_loss_sc'] = subband_stft_loss_sc
gen_loss += self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)
# multiscale MSE adversarial loss
if self.use_mse_gan_loss and scores_fake is not None:
mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
return_dict['G_mse_fake_loss'] = mse_fake_loss
adv_loss += self.mse_gan_loss_weight * mse_fake_loss
# multiscale Hinge adversarial loss
if self.use_hinge_gan_loss and not scores_fake is not None:
hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
return_dict['G_hinge_fake_loss'] = hinge_fake_loss
adv_loss += self.hinge_gan_loss_weight * hinge_fake_loss
# Feature Matching Loss
if self.use_feat_match_loss and not feats_fake:
feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
return_dict['G_feat_match_loss'] = feat_match_loss
adv_loss += self.feat_match_loss_weight * feat_match_loss
return_dict['G_loss'] = gen_loss + adv_loss
return_dict['G_gen_loss'] = gen_loss
return_dict['G_adv_loss'] = adv_loss
return return_dict
class DiscriminatorLoss(nn.Module):
""" Compute Discriminator Loss values depending on training
configuration """
def __init__(self, C):
super(DiscriminatorLoss, self).__init__()
assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
" [!] Cannot use HingeGANLoss and MSEGANLoss together."
self.use_mse_gan_loss = C.use_mse_gan_loss
self.use_hinge_gan_loss = C.use_hinge_gan_loss
if C.use_mse_gan_loss:
self.mse_loss = MSEDLoss()
if C.use_hinge_gan_loss:
self.hinge_loss = HingeDLoss()
def forward(self, scores_fake, scores_real):
loss = 0
return_dict = {}
if self.use_mse_gan_loss:
mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(
scores_fake=scores_fake,
scores_real=scores_real,
loss_func=self.mse_loss)
return_dict['D_mse_gan_loss'] = mse_D_loss
return_dict['D_mse_gan_real_loss'] = mse_D_real_loss
return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss
loss += mse_D_loss
if self.use_hinge_gan_loss:
hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(
scores_fake=scores_fake,
scores_real=scores_real,
loss_func=self.hinge_loss)
return_dict['D_hinge_gan_loss'] = hinge_D_loss
return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss
return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss
loss += hinge_D_loss
return_dict['D_loss'] = loss
return return_dict