import torch
import numpy as np
from scipy.io.wavfile import write
import torchaudio
from audiosr.utilities.audio.audio_processing import griffin_lim


def pad_wav(waveform, segment_length):
    # Pad (or truncate) a (1, T) waveform to exactly `segment_length` samples.
    waveform_length = waveform.shape[-1]
    assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
    if segment_length is None or waveform_length == segment_length:
        return waveform
    elif waveform_length > segment_length:
        # Truncate along the time (last) axis.
        return waveform[:, :segment_length]
    elif waveform_length < segment_length:
        # Zero-pad the tail up to segment_length.
        temp_wav = np.zeros((1, segment_length))
        temp_wav[:, :waveform_length] = waveform
        return temp_wav


def normalize_wav(waveform):
    # Remove the DC offset, then scale so the peak amplitude is 0.5.
    waveform = waveform - np.mean(waveform)
    waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
    return waveform * 0.5


def read_wav_file(filename, segment_length):
    # waveform, sr = librosa.load(filename, sr=None, mono=True)  # 4 times slower
    waveform, sr = torchaudio.load(filename)  # Faster!!!
    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
    waveform = waveform.numpy()[0, ...]
    waveform = normalize_wav(waveform)
    waveform = waveform[None, ...]
    waveform = pad_wav(waveform, segment_length)
    waveform = waveform / np.max(np.abs(waveform))
    waveform = 0.5 * waveform
    return waveform
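
# Note: the waveform is resampled to 16 kHz before padding, so `segment_length`
# is a sample count at 16 kHz (for example, a 10.24 s segment would be
# 163840 samples; the exact duration is chosen by the caller, not by this module).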


def get_mel_from_wav(audio, _stft):
    audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
    audio = torch.autograd.Variable(audio, requires_grad=False)
    melspec, magnitudes, phases, energy = _stft.mel_spectrogram(audio)
    melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
    magnitudes = torch.squeeze(magnitudes, 0).numpy().astype(np.float32)
    energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
    return melspec, magnitudes, energy
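
# Note: `_stft` is assumed to be an STFT helper in the style of the audiosr /
# AudioLDM TacotronSTFT module: it must expose `mel_spectrogram()` here, and
# `spectral_de_normalize()`, `mel_basis`, `_stft_fn`, and `sampling_rate` for
# `inv_mel_spec` below. This file only calls those attributes; it does not
# construct the object. A hypothetical call could look like:
#
#   mel, magnitudes, energy = get_mel_from_wav(wav[0], stft)  # mel: (n_mels, n_frames)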


def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
    mel = torch.stack([mel])
    mel_decompress = _stft.spectral_de_normalize(mel)
    mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
    spec_from_mel_scaling = 1000
    spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
    spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
    spec_from_mel = spec_from_mel * spec_from_mel_scaling
    audio = griffin_lim(
        torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters
    )
    audio = audio.squeeze()
    audio = audio.cpu().numpy()
    audio_path = out_filename
    write(audio_path, _stft.sampling_rate, audio)
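

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumptions: a local file named "example.wav"
    # exists and a 10-second target segment is wanted; both values are
    # illustrative and not defined elsewhere in this module.
    wav = read_wav_file("example.wav", segment_length=160000)
    # wav is a (1, 160000) float array at 16 kHz with peak amplitude 0.5.
    write("example_16k_padded.wav", 16000, wav[0].astype(np.float32))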