File size: 3,871 Bytes
528df8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
import logging
# Module-level logger, named after this module; spectrogram_torch uses it for
# debug messages about out-of-range samples.
logger = logging.getLogger(__name__)
# NOTE(review): 32768.0 == 2**15, presumably the full-scale magnitude of 16-bit
# PCM audio used to normalize waveforms. It is not referenced in this part of
# the file - confirm against callers.
MAX_WAV_VALUE = 32768.0
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress a tensor after flooring its values at ``clip_val``.

    PARAMS
    ------
    x: tensor to compress
    C: compression factor, multiplied in before the logarithm
    clip_val: lower bound applied to ``x`` before scaling

    RETURNS
    -------
    ``log(max(x, clip_val) * C)``, computed element-wise
    """
    floored = torch.clamp(x, min=clip_val)
    scaled = floored * C
    return torch.log(scaled)
def dynamic_range_decompression_torch(x, C=1):
    """Undo the log compression: exponentiate, then divide by the factor.

    PARAMS
    ------
    x: log-compressed tensor
    C: compression factor used to compress

    RETURNS
    -------
    ``exp(x) / C``, computed element-wise
    """
    expanded = torch.exp(x)
    return expanded / C
def spectral_normalize_torch(magnitudes):
    """Log-compress ``magnitudes`` using the default compression settings.

    Delegates to ``dynamic_range_compression_torch`` with its default
    ``C`` and ``clip_val``.
    """
    compressed = dynamic_range_compression_torch(magnitudes)
    return compressed
def spectral_de_normalize_torch(magnitudes):
    """Invert ``spectral_normalize_torch`` using the default settings.

    Delegates to ``dynamic_range_decompression_torch`` with its default ``C``.
    """
    restored = dynamic_range_decompression_torch(magnitudes)
    return restored
# Reusable banks: module-level caches so that the mel filterbank and the Hann
# window are built once and then reused on later calls. Each dict is keyed by
# a string that includes the tensor dtype and device, so tensors created for
# one dtype/device are never handed to another.
mel_basis = {}  # filled by spec_to_mel_torch
hann_window = {}  # filled by spectrogram_torch
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    """Convert waveform into Linear-frequency Linear-amplitude spectrogram.

    Args:
        y :: (B, T) - Audio waveforms
        n_fft - FFT size passed to ``torch.stft``
        sampling_rate - Unused here; kept so the signature stays compatible
        hop_size - Hop length between STFT frames
        win_size - Length of the Hann window
        center - Forwarded to ``torch.stft`` (the input is already padded
            manually below, so the default is False)
    Returns:
        :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
    """
    # Validation - only reports (at debug level) samples outside [-1.07, 1.07];
    # the waveform itself is never modified or rejected.
    lowest = torch.min(y)
    if lowest < -1.07:
        logger.debug("min value is %s", str(lowest))
    highest = torch.max(y)
    if highest > 1.07:
        logger.debug("max value is %s", str(highest))

    # Window - Cache if needed (one window per win_size / dtype / device)
    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )

    # Padding - reflect-pad both ends by (n_fft - hop_size) / 2 samples.
    # A channel axis is added temporarily around the pad call and removed
    # again right after.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
    y = y.squeeze(1)

    # Complex Spectrogram :: (B, T) -> (B, Freq, Frame) complex
    # return_complex=False is deprecated in torch.stft, so request complex
    # output and convert it back to the real-valued layout explicitly.
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    # (B, Freq, Frame) complex -> (B, Freq, Frame, RealComplex=2)
    spec = torch.view_as_real(spec)

    # Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame)
    # The 1e-6 term is added under the square root, so the result is never
    # exactly zero.
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    """Convert a linear-frequency spectrogram into a Mel-frequency Log-amplitude one.

    Args:
        spec :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
        n_fft - FFT size, forwarded to the mel filterbank builder
        num_mels - Number of mel bands
        sampling_rate - Sampling rate, forwarded to the mel filterbank builder
        fmin - Lowest frequency of the filterbank
        fmax - Highest frequency of the filterbank
    Returns:
        melspec :: (B, Freq=num_mels, Frame) - Mel-frequency Log-amplitude spectrogram
    """
    # MelBasis - Cache if needed
    # The key covers every argument that defines the filterbank, plus dtype and
    # device. Keying on fmax alone would return a stale basis when the same
    # fmax is reused with a different n_fft, num_mels, sampling_rate or fmin.
    global mel_basis
    cache_key = "_".join(
        str(part)
        for part in (
            sampling_rate,
            n_fft,
            num_mels,
            fmin,
            fmax,
            spec.dtype,
            spec.device,
        )
    )
    if cache_key not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis[cache_key] = torch.from_numpy(mel).to(
            dtype=spec.dtype, device=spec.device
        )

    # Mel-frequency Log-amplitude spectrogram :: (B, Freq=num_mels, Frame)
    melspec = torch.matmul(mel_basis[cache_key], spec)
    melspec = spectral_normalize_torch(melspec)
    return melspec
def mel_spectrogram_torch(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    """Compute a Mel-frequency Log-amplitude spectrogram straight from waveforms.

    Chains ``spectrogram_torch`` (waveform -> linear spectrogram) and
    ``spec_to_mel_torch`` (linear spectrogram -> log-mel spectrogram).

    Args:
        y :: (B, T) - Waveforms
        n_fft, sampling_rate, hop_size, win_size, center - Forwarded to
            ``spectrogram_torch``
        n_fft, num_mels, sampling_rate, fmin, fmax - Forwarded to
            ``spec_to_mel_torch``
    Returns:
        melspec :: (B, Freq, Frame) - Mel-frequency Log-amplitude spectrogram
    """
    # Step 1 :: (B, T) -> (B, Freq, Frame), linear frequency / linear amplitude
    linear = spectrogram_torch(
        y=y,
        n_fft=n_fft,
        sampling_rate=sampling_rate,
        hop_size=hop_size,
        win_size=win_size,
        center=center,
    )
    # Step 2 :: (B, Freq, Frame) -> (B, Freq=num_mels, Frame), mel / log amplitude
    return spec_to_mel_torch(
        spec=linear,
        n_fft=n_fft,
        num_mels=num_mels,
        sampling_rate=sampling_rate,
        fmin=fmin,
        fmax=fmax,
    )
|