import torch.nn.functional as F import torch import numpy as np def mae(input, target): return torch.mean(torch.abs(input - target)) def logmae_wav(model, output_dict, target): loss = torch.log10(torch.clamp(mae(output_dict['wav'], target), 1e-8, np.inf)) return loss def get_loss_func(loss_type): if loss_type == 'logmae_wav': return logmae_wav elif loss_type == 'mae': return mae else: raise Exception('Incorrect loss_type!')