File size: 481 Bytes
530a7d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch.nn.functional as F
import torch
import numpy as np
def mae(input, target):
    """Return the mean absolute error between ``input`` and ``target``.

    The element-wise difference ``input - target`` is computed first, its
    absolute value is taken, and the result is averaged over all elements.

    Args:
        input: Tensor of predicted values.
        target: Tensor of reference values; it is subtracted from ``input``.

    Returns:
        The mean of ``|input - target|`` as returned by ``torch.mean``.
    """
    # NOTE: ``input`` shadows the builtin of the same name, but it is part of
    # the public signature (callers may pass it by keyword), so it is kept.
    difference = input - target
    magnitude = torch.abs(difference)
    return torch.mean(magnitude)
def logmae_wav(model, output_dict, target):
    """Return the base-10 logarithm of the MAE on the ``'wav'`` output.

    Args:
        model: Not used by this function.
            NOTE(review): presumably kept so that losses share a common call
            signature with the training loop -- confirm against callers.
        output_dict: Mapping whose ``'wav'`` entry is compared against
            ``target`` (presumably the predicted waveform -- confirm).
        target: Reference tensor passed to ``mae`` together with
            ``output_dict['wav']``.

    Returns:
        ``log10`` of the mean absolute error, where the error is first
        clamped to the range ``[1e-8, inf]``.
    """
    wav_error = mae(output_dict['wav'], target)
    # Floor the error at 1e-8 so that a zero error maps to a finite log value
    # instead of -inf; the upper bound of +inf leaves large errors untouched.
    floored_error = torch.clamp(wav_error, 1e-8, np.inf)
    return torch.log10(floored_error)
def get_loss_func(loss_type):
    """Look up a loss function by its name.

    Args:
        loss_type: Name of the loss. Supported values are ``'logmae_wav'``
            and ``'mae'``.

    Returns:
        The function object registered for ``loss_type`` (``logmae_wav`` or
        ``mae``). The function itself is returned, not its result.

        NOTE(review): the two callables do not share a signature --
        ``logmae_wav`` takes ``(model, output_dict, target)`` while ``mae``
        takes ``(input, target)`` -- so callers must invoke them accordingly.

    Raises:
        ValueError: If ``loss_type`` is not one of the supported names.
            ``ValueError`` is a subclass of ``Exception``, so handlers written
            for the previous generic ``Exception`` keep working.
    """
    if loss_type == 'logmae_wav':
        return logmae_wav
    if loss_type == 'mae':
        return mae

    # Report the offending value and the accepted ones so a typo in a config
    # or command-line flag can be diagnosed straight from the traceback.
    raise ValueError(
        "Incorrect loss_type! Got {!r}; expected 'logmae_wav' or 'mae'."
        .format(loss_type)
    )
|