Spaces:
Running
on
T4
Running
on
T4
File size: 5,003 Bytes
9e275b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
# Copyright 2021 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
# Adapted by Florian Lux 2021
import librosa
import torch
import torch.nn.functional as F
class MelSpectrogram(torch.nn.Module):
    """
    Differentiable log-Mel-spectrogram extractor.

    The waveform is transformed with ``torch.stft``, the magnitude spectrum is
    projected onto a Mel filterbank built once with ``librosa.filters.mel``
    and the logarithm of the (clamped) result is returned.
    """

    def __init__(self,
                 fs=24000,
                 fft_size=1536,
                 hop_size=384,
                 win_length=None,
                 window="hann",
                 num_mels=100,
                 fmin=60,
                 fmax=None,
                 center=True,
                 normalized=False,
                 onesided=True,
                 eps=1e-10,
                 log_base=10.0, ):
        """
        Initialize the Mel-spectrogram extractor.

        Args:
            fs (int): Sampling rate, used to build the Mel filterbank.
            fft_size (int): FFT size (``n_fft`` of ``torch.stft``).
            hop_size (int): Hop size between frames.
            win_length (int or None): Window length, ``fft_size`` if None.
            window (str or None): Name ``w`` of a ``torch.<w>_window``
                function. If None, no window is passed to ``torch.stft``.
            num_mels (int): Number of Mel filters.
            fmin (float or None): Lowest filterbank frequency, 0 if None.
            fmax (float or None): Highest filterbank frequency,
                ``fs / 2`` if None.
            center (bool): Forwarded to ``torch.stft``.
            normalized (bool): Forwarded to ``torch.stft``.
            onesided (bool): Forwarded to ``torch.stft``.
            eps (float): Lower bound applied to the power spectrum and to the
                Mel energies before the logarithm.
            log_base (float or None): 2.0, 10.0 or None (natural logarithm).

        Raises:
            ValueError: If ``window`` does not name a torch window function
                or ``log_base`` is not one of the supported values.
        """
        super().__init__()
        self.fft_size = fft_size
        self.win_length = fft_size if win_length is None else win_length
        self.hop_size = hop_size
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        if window is not None and not hasattr(torch, f"{window}_window"):
            raise ValueError(f"{window} window is not implemented")
        self.window = window
        self.eps = eps

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        melmat = librosa.filters.mel(sr=fs,
                                     n_fft=fft_size,
                                     n_mels=num_mels,
                                     fmin=fmin,
                                     fmax=fmax, )
        # Stored transposed so that (#frames, #freqs) magnitudes can be
        # right-multiplied by it in forward().
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

        self.stft_params = {
            "n_fft"         : self.fft_size,
            "win_length"    : self.win_length,
            "hop_length"    : self.hop_size,
            "center"        : self.center,
            "normalized"    : self.normalized,
            "onesided"      : self.onesided,
            # return_complex=False is deprecated in torch.stft, so the complex
            # result is requested and split into real/imag parts in forward().
            "return_complex": True,
        }

        self.log_base = log_base
        if self.log_base is None:
            self.log = torch.log
        elif self.log_base == 2.0:
            self.log = torch.log2
        elif self.log_base == 10.0:
            self.log = torch.log10
        else:
            raise ValueError(f"log_base: {log_base} is not supported.")

    def forward(self, x):
        """
        Calculate Mel-spectrogram.

        Args:
            x (Tensor): Input waveform tensor (B, T) or (B, 1, T).
                An unbatched waveform (T,) is accepted as well.

        Returns:
            Tensor: Mel-spectrogram (B, #mels, #frames),
                or (#mels, #frames) for an unbatched input.
        """
        if x.dim() == 3:
            # (B, C, T) -> (B*C, T)
            x = x.reshape(-1, x.size(2))

        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            # Built per call so that it always follows the input's dtype and device.
            window = window_func(self.win_length, dtype=x.dtype, device=x.device)
        else:
            window = None

        # complex tensor of shape (B, #freqs, #frames)
        x_stft = torch.stft(x, window=window, **self.stft_params)
        # (B, #freqs, #frames) -> (B, #frames, #freqs)
        # Negative axes keep this valid when there is no batch dimension.
        x_stft = x_stft.transpose(-1, -2)
        x_power = x_stft.real ** 2 + x_stft.imag ** 2
        # Clamping before the square root keeps its gradient finite at zero.
        x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps))

        # The filterbank is stored as float32; follow the dtype of the input.
        x_mel = torch.matmul(x_amp, self.melmat.to(dtype=x_amp.dtype))
        x_mel = torch.clamp(x_mel, min=self.eps)

        # (B, #frames, #mels) -> (B, #mels, #frames)
        return self.log(x_mel).transpose(-1, -2)
class MelSpectrogramLoss(torch.nn.Module):
    """
    L1 distance between the Mel-spectrograms of two waveforms.

    Both inputs are passed through the same ``MelSpectrogram`` module and the
    resulting features are compared with ``F.l1_loss``.
    """

    def __init__(self,
                 fs=24000,
                 fft_size=1024,
                 hop_size=256,
                 win_length=None,
                 window="hann",
                 num_mels=128,
                 fmin=20,
                 fmax=None,
                 center=True,
                 normalized=False,
                 onesided=True,
                 eps=1e-10,
                 log_base=10.0, ):
        """
        Initialize the Mel-spectrogram loss.

        Every argument is handed over unchanged, by keyword, to the
        ``MelSpectrogram`` feature extractor used by this loss.
        """
        super().__init__()
        extractor_config = {
            "fs"        : fs,
            "fft_size"  : fft_size,
            "hop_size"  : hop_size,
            "win_length": win_length,
            "window"    : window,
            "num_mels"  : num_mels,
            "fmin"      : fmin,
            "fmax"      : fmax,
            "center"    : center,
            "normalized": normalized,
            "onesided"  : onesided,
            "eps"       : eps,
            "log_base"  : log_base,
        }
        self.mel_spectrogram = MelSpectrogram(**extractor_config)

    def forward(self, y_hat, y):
        """
        Calculate Mel-spectrogram loss.

        Args:
            y_hat (Tensor): Generated signal tensor (B, 1, T).
            y (Tensor): Groundtruth signal tensor (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram loss value.
        """
        generated_mel = self.mel_spectrogram(y_hat)
        reference_mel = self.mel_spectrogram(y)
        return F.l1_loss(generated_mel, reference_mel)
|