# EnglishToucan / Architectures / Vocoder / HiFiGAN_Dataset.py
# Flux9665's picture
# initial commit
# 6faeba1
# raw
# history blame
# 8.06 kB
import os
import random
from multiprocessing import Manager
from multiprocessing import Process
import librosa
import numpy
import soundfile as sf
import torch
import torchaudio
from torch.utils.data import Dataset
from tqdm import tqdm
from Preprocessing.AudioPreprocessor import AudioPreprocessor
def random_pitch_shifter(x, sample_rate=24000):
    """
    Shift the pitch of a wave by a randomly chosen number of semitone steps.

    Args:
        x: the wave to shift; it is passed straight to torchaudio's PitchShift transform
        sample_rate: sampling rate of x in Hz. Defaults to 24000, which is the value
            that was previously hard-coded here.

    Returns:
        whatever the PitchShift transform returns for x
    """
    # when using 12 steps per octave, these are the only ones that are pretty fast.
    # I benchmarked it and the variance is many orders of magnitude.
    n_steps = random.choice([-12, -9, -6, 3, 12])
    return torchaudio.transforms.PitchShift(sample_rate=sample_rate, n_steps=n_steps)(x)
def polarity_inverter(x):
    """
    Flip the sign of every value in x by multiplying it with -1.
    """
    return -1 * x
class HiFiGANDataset(Dataset):
    """
    Dataset that keeps all waves in memory and serves random fixed-length segments.

    Each item is a pair of a wave segment at the desired sampling rate and a
    spectrogram computed from a 16kHz version of that same segment.
    """

    def __init__(self,
                 list_of_paths,
                 desired_samplingrate=24000,
                 samples_per_segment=12288,  # = (8192 * 3) / 2 , as I used 8192 for 16kHz previously
                 loading_processes=max((os.cpu_count() or 1) - 2, 1),  # cpu_count() may return None
                 use_random_corruption=False):
        """
        Args:
            list_of_paths: paths to the audio files that should be cached
            desired_samplingrate: every wave is resampled to this rate before it is cached
            samples_per_segment: length of the wave segments returned by __getitem__
            loading_processes: number of processes used to fill the cache. Values of 1
                or less load everything in the current process.
            use_random_corruption: whether to apply random augmentations to the wave and
                random distortions to the input of the spectrogram extraction
        """
        self.use_random_corruption = use_random_corruption
        self.samples_per_segment = samples_per_segment
        self.desired_samplingrate = desired_samplingrate
        self.melspec_ap = AudioPreprocessor(input_sr=self.desired_samplingrate,
                                            output_sr=16000,
                                            cut_silence=False)
        # hop length of spec loss should be same as the product of the upscale factors
        # samples per segment must be a multiple of hop length of spec loss
        if loading_processes <= 1:
            # also covers 0 and negative values, which would otherwise start no loader at all
            self.waves = list()
            self.cache_builder_process(list_of_paths)
        else:
            resource_manager = Manager()
            self.waves = resource_manager.list()
            # make processes, each one gets a contiguous share of the paths
            num_paths = len(list_of_paths)
            path_splits = [list_of_paths[i * num_paths // loading_processes:(i + 1) * num_paths // loading_processes]
                           for i in range(loading_processes)]
            process_list = list()
            for path_split in path_splits:
                process_list.append(Process(target=self.cache_builder_process, args=(path_split,), daemon=True))
                process_list[-1].start()
            for process in process_list:
                process.join()
        # self.masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=16, iid_masks=True)  # up to 16 consecutive bands can be masked, each element in the batch gets a different mask. Taken out because it seems too extreme.
        # just some data augmentation, the identity entries control how often nothing is changed
        # NOTE: random_pitch_shifter is called with the wave only, so it works with its own
        # 24kHz default and does not follow desired_samplingrate.
        self.wave_augs = [random_pitch_shifter, polarity_inverter, lambda x: x, lambda x: x, lambda x: x, lambda x: x]
        # simulating the fact, that we train the TTS on codec-compressed waves
        self.wave_distortions = [CodecSimulator(), lambda x: x, lambda x: x, lambda x: x, lambda x: x]
        print("{} eligible audios found".format(len(self.waves)))

    def cache_builder_process(self, path_split):
        """
        Read every file in path_split, convert it to mono at the desired sampling
        rate and append it to the cache. Files that raise a RuntimeError are
        reported and skipped.
        """
        for path in tqdm(path_split):
            try:
                wave, sr = sf.read(path)
                if len(wave.shape) == 2:
                    wave = librosa.to_mono(numpy.transpose(wave))
                if sr != self.desired_samplingrate:
                    wave = librosa.resample(y=wave, orig_sr=sr, target_sr=self.desired_samplingrate)
                self.waves.append(wave)
            except RuntimeError:
                print(f"Problem with the following path: {path}")

    def __getitem__(self, index):
        """
        load the audio from the cache and clean it.
        All audio segments have to be cut to the same length,
        according to the NeurIPS reference implementation.

        return a pair of high-res audio and corresponding low-res spectrogram as if it was predicted by the TTS

        If building the pair raises a RuntimeError, the preceding entries of the
        cache are tried one after another (wrapping around at the start), each at
        most once. An index outside of the cache still raises an IndexError.

        Raises:
            RuntimeError: if no entry of the cache could be turned into a pair
        """
        try:
            return self._build_training_pair(index)
        except RuntimeError:
            print("encountered a runtime error, using fallback strategy")
        num_waves = len(self.waves)
        # bounded scan instead of open-ended recursion, so a cache in which every
        # entry fails ends with a clear error instead of recursing forever
        for step in range(1, num_waves):
            try:
                return self._build_training_pair((index - step) % num_waves)
            except RuntimeError:
                print("encountered a runtime error, using fallback strategy")
        raise RuntimeError("none of the cached audios could be turned into a training example")

    def _build_training_pair(self, index):
        """
        Cut a random segment out of the wave at index and compute its spectrogram.
        """
        wave = self.waves[index]
        while len(wave) < self.samples_per_segment + 50:  # + 50 is just to be extra sure
            # catch files that are too short to apply meaningful signal processing and make them longer
            # add some true silence in the mix, so the vocoder is exposed to that as well during training
            wave = numpy.concatenate([wave, numpy.zeros(shape=1000), wave])
        wave = torch.Tensor(wave)

        if self.use_random_corruption:
            # augmentations for the wave
            # it is intentional that this affects the target as well. This is not a distortion, but an augmentation.
            wave = random.choice(self.wave_augs)(wave.unsqueeze(0)).squeeze(0)

        max_audio_start = len(wave) - self.samples_per_segment
        audio_start = random.randint(0, max_audio_start)
        segment = wave[audio_start: audio_start + self.samples_per_segment]

        # 16kHz spectrogram as input, 24kHz wave as output, see Blizzard 2021 DelightfulTTS
        resampled_segment = self.melspec_ap.resample(segment).float()
        if self.use_random_corruption:
            # distortions for the input only, the target segment stays untouched
            resampled_segment = random.choice(self.wave_distortions)(resampled_segment.unsqueeze(0)).squeeze(0)
        # the double transpose drops the last entry along the second dimension
        # (presumably the final frame, to be confirmed against AudioPreprocessor)
        melspec = self.melspec_ap.audio_to_mel_spec_tensor(resampled_segment,
                                                           explicit_sampling_rate=16000,
                                                           normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
        return segment.detach(), melspec.detach()

    def __len__(self):
        """
        Number of cached waves.
        """
        return len(self.waves)
class CodecSimulator(torch.nn.Module):
    """
    Passes a signal through a mu-law encoding with 64 quantization channels
    and decodes it right away again.
    """

    def __init__(self):
        super().__init__()
        channels = 64
        self.encoder = torchaudio.transforms.MuLawEncoding(quantization_channels=channels)
        self.decoder = torchaudio.transforms.MuLawDecoding(quantization_channels=channels)

    def forward(self, x):
        """
        Encode x and return the decoded result.
        """
        quantized = self.encoder(x)
        return self.decoder(quantized)
if __name__ == '__main__':
    # Visual sanity checks for the codec simulator, spectrogram masking and the pitch shifter.
    import matplotlib.pyplot as plt

    wav, sr = sf.read("../../audios/speaker_references/female_high_voice.wav")
    if len(wav.shape) == 2:
        # same mono conversion as in the dataset, so that resampling works on a single channel
        wav = librosa.to_mono(numpy.transpose(wav))
    resampled_wave = torch.Tensor(librosa.resample(y=wav, orig_sr=sr, target_sr=24000))

    melspec_ap = AudioPreprocessor(input_sr=24000,
                                   output_sr=16000,
                                   cut_silence=False)

    spec = melspec_ap.audio_to_mel_spec_tensor(melspec_ap.resample(resampled_wave).float(),
                                               explicit_sampling_rate=16000,
                                               normalize=False).transpose(0, 1)[:-1].transpose(0, 1)

    cs = CodecSimulator()
    masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=16, iid_masks=True)  # up to 16 consecutive bands can be masked

    # testing codec simulator
    out = cs(resampled_wave.unsqueeze(0)).squeeze(0)
    plt.plot(resampled_wave, alpha=0.5)
    plt.plot(out, alpha=0.5)
    plt.title("Codec Simulator")
    plt.show()

    # testing spectrogram masking
    for _ in range(5):
        masked_spec = masker(spec.unsqueeze(0)).squeeze(0)
        print(masked_spec)
        plt.imshow(masked_spec.cpu().numpy(), origin="lower", cmap='GnBu')
        plt.title("Masked Spec")
        plt.show()

    # testing pitch shift
    for _ in range(5):
        shifted_wave = random_pitch_shifter(resampled_wave.unsqueeze(0)).squeeze(0)
        shifted_spec = melspec_ap.audio_to_mel_spec_tensor(melspec_ap.resample(shifted_wave).float(),
                                                           explicit_sampling_rate=16000,
                                                           normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
        plt.imshow(shifted_spec.detach().cpu().numpy(), origin="lower", cmap='GnBu')
        plt.title("Pitch Shifted Spec")
        plt.show()