# Sebas / losses.py
# Uploaded by Mudrock ("Upload 18 files", commit 530a7d1); original size 481 bytes.
import torch.nn.functional as F
import torch
import numpy as np
def mae(input, target):
    """Return the mean absolute error between ``input`` and ``target``.

    The two tensors are subtracted element-wise (with torch broadcasting),
    the absolute differences are averaged over every element, and the
    result is returned as a 0-dim tensor.
    """
    difference = input - target
    absolute_difference = torch.abs(difference)
    return torch.mean(absolute_difference)
def logmae_wav(model, output_dict, target):
    """Return log10 of the mean absolute error on ``output_dict['wav']``.

    Args:
        model: Not used by this function. (Presumably kept so every loss
            function shares one call signature -- confirm with the caller.)
        output_dict: Mapping whose ``'wav'`` entry is the prediction
            (presumably a waveform tensor -- TODO confirm).
        target: Reference tensor compared against ``output_dict['wav']``.

    Returns:
        ``log10(max(mae, 1e-8))`` as a tensor.
    """
    wav_error = mae(output_dict['wav'], target)
    # Floor the error at 1e-8 before taking the log, so a perfect match
    # yields log10(1e-8) = -8 instead of -inf.
    floored_error = torch.clamp(wav_error, min=1e-8, max=np.inf)
    return torch.log10(floored_error)
def get_loss_func(loss_type):
    """Return the loss function registered under ``loss_type``.

    Args:
        loss_type: Either ``'logmae_wav'`` or ``'mae'``.

    Returns:
        ``logmae_wav`` for ``'logmae_wav'``; ``mae`` for ``'mae'``.
        NOTE(review): the two callables have different signatures --
        ``logmae_wav(model, output_dict, target)`` vs ``mae(input, target)`` --
        so they cannot be called interchangeably; confirm how the caller
        invokes the returned function.

    Raises:
        ValueError: If ``loss_type`` is not one of the names above. This is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """
    if loss_type == 'logmae_wav':
        return logmae_wav
    if loss_type == 'mae':
        return mae
    # Report the offending value so misconfigurations are easy to diagnose.
    raise ValueError(
        f"Incorrect loss_type! Got {loss_type!r}; expected 'logmae_wav' or 'mae'."
    )