Flux9665's picture
initial commit
6faeba1
raw
history blame
5 kB
# Copyright 2021 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
# Adapted by Florian Lux 2021
import librosa
import torch
import torch.nn.functional as F
class MelSpectrogram(torch.nn.Module):
    """
    Log-scaled Mel-spectrogram extractor operating on batched waveforms.

    The Mel filterbank is built once with ``librosa.filters.mel`` and stored
    as the buffer ``melmat`` (shape ``(#freqs, #mels)``), so it follows the
    module across devices. The analysis window is created on every forward
    call with the dtype and device of the input.

    Args:
        fs (int): Sampling rate, passed to librosa as ``sr``.
        fft_size (int): FFT size.
        hop_size (int): Hop size between frames.
        win_length (int, optional): Window length. Falls back to ``fft_size``
            when None.
        window (str, optional): Window name; ``torch.<window>_window`` must
            exist. None disables the explicit window.
        num_mels (int): Number of Mel bands.
        fmin (float, optional): Lowest filterbank frequency. None means 0.
        fmax (float, optional): Highest filterbank frequency. None means
            ``fs / 2``.
        center (bool): Forwarded to ``torch.stft``.
        normalized (bool): Forwarded to ``torch.stft``.
        onesided (bool): Forwarded to ``torch.stft``.
        eps (float): Lower bound applied to the power spectrum and to the Mel
            energies before the logarithm.
        log_base (float, optional): 10.0, 2.0, or None for the natural
            logarithm.

    Raises:
        ValueError: If the window name is unknown to torch or ``log_base`` is
            not one of the supported values.
    """

    def __init__(self,
                 fs=24000,
                 fft_size=1536,
                 hop_size=384,
                 win_length=None,
                 window="hann",
                 num_mels=100,
                 fmin=60,
                 fmax=None,
                 center=True,
                 normalized=False,
                 onesided=True,
                 eps=1e-10,
                 log_base=10.0, ):
        super().__init__()
        self.fft_size = fft_size
        if win_length is None:
            self.win_length = fft_size
        else:
            self.win_length = win_length
        self.hop_size = hop_size
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        if window is not None and not hasattr(torch, f"{window}_window"):
            raise ValueError(f"{window} window is not implemented")
        self.window = window
        self.eps = eps

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        melmat = librosa.filters.mel(sr=fs,
                                     n_fft=fft_size,
                                     n_mels=num_mels,
                                     fmin=fmin,
                                     fmax=fmax, )
        # librosa returns (#mels, #freqs); store the transpose so that
        # (frames, freqs) @ melmat yields (frames, mels).
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

        self.stft_params = {
            "n_fft"     : self.fft_size,
            "win_length": self.win_length,
            "hop_length": self.hop_size,
            "center"    : self.center,
            "normalized": self.normalized,
            "onesided"  : self.onesided,
            # return_complex=False is deprecated in torch.stft and slated to
            # become an error, so the complex output is requested instead.
            "return_complex": True,
        }

        self.log_base = log_base
        if self.log_base is None:
            self.log = torch.log
        elif self.log_base == 2.0:
            self.log = torch.log2
        elif self.log_base == 10.0:
            self.log = torch.log10
        else:
            raise ValueError(f"log_base: {log_base} is not supported.")

    def forward(self, x):
        """
        Calculate Mel-spectrogram.

        Args:
            x (Tensor): Input waveform tensor (B, T) or (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram (B, #mels, #frames). For a 3-dim input
                (B, C, T) the leading dimension of the result is B * C.
        """
        if x.dim() == 3:
            # (B, C, T) -> (B*C, T)
            x = x.reshape(-1, x.size(2))

        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(self.win_length, dtype=x.dtype, device=x.device)
        else:
            window = None

        # Force the complex output even if this instance carries an older
        # stft_params dict (e.g. a module unpickled from an earlier version).
        stft_kwargs = dict(self.stft_params, return_complex=True)
        x_stft = torch.stft(x, window=window, **stft_kwargs)
        # (B, #freqs, #frames) -> (B, #frames, #freqs)
        x_stft = x_stft.transpose(1, 2)
        x_power = x_stft.real ** 2 + x_stft.imag ** 2
        # Clamp before the square root so the amplitude never drops below sqrt(eps).
        x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps))

        x_mel = torch.matmul(x_amp, self.melmat)
        x_mel = torch.clamp(x_mel, min=self.eps)

        # (B, #frames, #mels) -> (B, #mels, #frames)
        return self.log(x_mel).transpose(1, 2)
class MelSpectrogramLoss(torch.nn.Module):
    """
    L1 loss between the Mel-spectrograms of a generated and a reference signal.

    Every constructor argument is handed unchanged to the ``MelSpectrogram``
    module stored in ``self.mel_spectrogram``, which turns both signals into
    spectrograms before they are compared.
    """

    def __init__(self,
                 fs=24000,
                 fft_size=1024,
                 hop_size=256,
                 win_length=None,
                 window="hann",
                 num_mels=128,
                 fmin=20,
                 fmax=None,
                 center=True,
                 normalized=False,
                 onesided=True,
                 eps=1e-10,
                 log_base=10.0, ):
        super().__init__()
        # Gather the extractor settings first, then build the extractor in one go.
        extractor_config = {
            "fs": fs,
            "fft_size": fft_size,
            "hop_size": hop_size,
            "win_length": win_length,
            "window": window,
            "num_mels": num_mels,
            "fmin": fmin,
            "fmax": fmax,
            "center": center,
            "normalized": normalized,
            "onesided": onesided,
            "eps": eps,
            "log_base": log_base,
        }
        self.mel_spectrogram = MelSpectrogram(**extractor_config)

    def forward(self, y_hat, y):
        """
        Compute the Mel-spectrogram loss.

        Args:
            y_hat (Tensor): Generated signal (B, 1, T).
            y (Tensor): Ground-truth signal (B, 1, T).

        Returns:
            Tensor: L1 loss (``F.l1_loss`` with its default reduction)
                between the two Mel-spectrograms.
        """
        generated_mel = self.mel_spectrogram(y_hat)
        reference_mel = self.mel_spectrogram(y)
        return F.l1_loss(generated_mel, reference_mel)