import os
import random
from multiprocessing import Manager
from multiprocessing import Process

import librosa
import numpy
import soundfile as sf
import torch
import torchaudio
from torch.utils.data import Dataset
from torchvision.transforms.v2 import GaussianBlur
from tqdm import tqdm

from Preprocessing.AudioPreprocessor import AudioPreprocessor
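
# torchaudio's PitchShift defaults to 12 bins per octave, so e.g. n_steps=12
# shifts the signal up by exactly one octave
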
def random_pitch_shifter(x):
    n_steps = random.choice([-12, -9, -6, 3, 12])  # when using 12 steps per octave, these are the only ones that are reasonably fast. I benchmarked it and the variance in runtime between steps is many orders of magnitude.
    return torchaudio.transforms.PitchShift(sample_rate=24000, n_steps=n_steps)(x)

def polarity_inverter(x):
    return x * -1

class HiFiGANDataset(Dataset):

    def __init__(self,
                 list_of_paths,
                 desired_samplingrate=24000,
                 samples_per_segment=12288,  # = (8192 * 3) / 2, since 8192 was used for 16kHz previously
                 loading_processes=max(os.cpu_count() - 2, 1),
                 use_random_corruption=False):
        self.use_random_corruption = use_random_corruption
        self.samples_per_segment = samples_per_segment
        self.desired_samplingrate = desired_samplingrate
        self.melspec_ap = AudioPreprocessor(input_sr=self.desired_samplingrate,
                                            output_sr=16000,
                                            cut_silence=False)
        # the hop length of the spectrogram loss should equal the product of the upscale factors,
        # and samples_per_segment must be a multiple of that hop length
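        # e.g. assuming the upscale factors multiply to 256 (an illustrative assumption,
        # not a value taken from this configuration), each segment yields 12288 / 256 = 48 frames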
        if loading_processes == 1:
            self.waves = list()
            self.cache_builder_process(list_of_paths)
        else:
            resource_manager = Manager()
            self.waves = resource_manager.list()
            # make processes
            path_splits = list()
            process_list = list()
            for i in range(loading_processes):
                path_splits.append(list_of_paths[i * len(list_of_paths) // loading_processes:
                                                 (i + 1) * len(list_of_paths) // loading_processes])
            for path_split in path_splits:
                process_list.append(Process(target=self.cache_builder_process, args=(path_split,), daemon=True))
                process_list[-1].start()
            for process in process_list:
                process.join()
        self.blurrer = GaussianBlur(kernel_size=(5, 5), sigma=(0.5, 2.0))  # simulating the smoothness of a generated spectrogram
        # self.masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=16, iid_masks=True)  # up to 16 consecutive bands can be masked, each element in the batch gets a different mask. Taken out because it seems too extreme.
        self.spec_augs = [self.blurrer, lambda x: x, lambda x: x, lambda x: x, lambda x: x]
        self.wave_augs = [random_pitch_shifter, polarity_inverter, lambda x: x, lambda x: x, lambda x: x, lambda x: x]  # just some data augmentation
        self.wave_distortions = [CodecSimulator(), lambda x: x, lambda x: x, lambda x: x, lambda x: x]  # simulating the fact that we train the TTS on codec-compressed waves
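        # the identity lambdas pad each list so that random.choice applies a real
        # transform only part of the time (e.g. the blur fires with probability 1/5)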
        print("{} eligible audios found".format(len(self.waves)))

    def cache_builder_process(self, path_split):
        for path in tqdm(path_split):
            try:
                wave, sr = sf.read(path)
                if len(wave.shape) == 2:
                    wave = librosa.to_mono(numpy.transpose(wave))
                if sr != self.desired_samplingrate:
                    wave = librosa.resample(y=wave, orig_sr=sr, target_sr=self.desired_samplingrate)
                self.waves.append(wave)
            except RuntimeError:
                print(f"Problem with the following path: {path}")

    def __getitem__(self, index):
        """
        Load the audio from the cache and prepare it.
        All audio segments have to be cut to the same length,
        following the NeurIPS reference implementation.
        Returns a pair of the high-resolution audio and the corresponding
        low-resolution spectrogram, as if it had been predicted by the TTS.
        """
        try:
            wave = self.waves[index]
            while len(wave) < self.samples_per_segment + 50:  # + 50 is just to be extra sure
                # catch files that are too short to apply meaningful signal processing and make them longer
                wave = numpy.concatenate([wave, numpy.zeros(shape=1000), wave])
                # add some true silence into the mix, so the vocoder is exposed to that as well during training
            wave = torch.Tensor(wave)
            if self.use_random_corruption:
                # augmentations for the wave
                wave = random.choice(self.wave_augs)(wave.unsqueeze(0)).squeeze(0)  # it is intentional that this affects the target as well. This is not a distortion, but an augmentation.
            max_audio_start = len(wave) - self.samples_per_segment
            audio_start = random.randint(0, max_audio_start)
            segment = wave[audio_start: audio_start + self.samples_per_segment]
            resampled_segment = self.melspec_ap.resample(segment).float()  # 16kHz spectrogram as input, 24kHz wave as output, see Blizzard 2021 DelightfulTTS
            if self.use_random_corruption:
                # distortions for the wave, applied only to the vocoder input, not to the target
                resampled_segment = random.choice(self.wave_distortions)(resampled_segment.unsqueeze(0)).squeeze(0)
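            # the final frame is cut off below so that the number of spectrogram frames
            # lines up exactly with samples_per_segment divided by the hop length
            # (likely because a centered STFT produces one extra frame)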
            melspec = self.melspec_ap.audio_to_mel_spec_tensor(resampled_segment,
                                                               explicit_sampling_rate=16000,
                                                               normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
            if self.use_random_corruption:
                # augmentations for the spec
                melspec = random.choice(self.spec_augs)(melspec.unsqueeze(0)).squeeze(0)
            return segment.detach(), melspec.detach()
        except RuntimeError:
            print("encountered a runtime error, using fallback strategy")
            if index == 0:
                index = len(self.waves) - 1
            return self.__getitem__(index - 1)

    def __len__(self):
        return len(self.waves)

class CodecSimulator(torch.nn.Module):
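    """
    Round-trips a wave through mu-law encoding and decoding (64 quantization
    channels) to roughly mimic the artifacts of codec-compressed training data.
    """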

    def __init__(self):
        super().__init__()
        self.encoder = torchaudio.transforms.MuLawEncoding(quantization_channels=64)
        self.decoder = torchaudio.transforms.MuLawDecoding(quantization_channels=64)

    def forward(self, x):
        return self.decoder(self.encoder(x))
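
# Minimal usage sketch for the dataset (the paths and DataLoader settings below are
# illustrative assumptions, not values from the actual training configuration):
#
#     from torch.utils.data import DataLoader
#     dataset = HiFiGANDataset(list_of_paths=["/data/clip_0.wav", "/data/clip_1.wav"])
#     loader = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=True)
#     for wave, melspec in loader:
#         ...  # wave: [batch, 12288] at 24kHz, melspec: [batch, n_mels, frames] at 16kHz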

if __name__ == '__main__':
    import matplotlib.pyplot as plt

    wav, sr = sf.read("../../audios/speaker_references/female_high_voice.wav")
    resampled_wave = torch.Tensor(librosa.resample(y=wav, orig_sr=sr, target_sr=24000))
    melspec_ap = AudioPreprocessor(input_sr=24000,
                                   output_sr=16000,
                                   cut_silence=False)
    spec = melspec_ap.audio_to_mel_spec_tensor(melspec_ap.resample(resampled_wave).float(),
                                               explicit_sampling_rate=16000,
                                               normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
    cs = CodecSimulator()
    blurrer = GaussianBlur(kernel_size=(5, 5), sigma=(0.5, 2.0))
    masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=16, iid_masks=True)  # up to 16 consecutive bands can be masked
    # testing the codec simulator
    out = cs(resampled_wave.unsqueeze(0)).squeeze(0)
    plt.plot(resampled_wave, alpha=0.5)
    plt.plot(out, alpha=0.5)
    plt.title("Codec Simulator")
    plt.show()

    # testing Gaussian blur
    blurred_spec = blurrer(spec.unsqueeze(0)).squeeze(0)
    plt.imshow(blurred_spec.cpu().numpy(), origin="lower", cmap='GnBu')
    plt.title("Blurred Spec")
    plt.show()

    # testing spectrogram masking
    for _ in range(5):
        masked_spec = masker(spec.unsqueeze(0)).squeeze(0)
        print(masked_spec)
        plt.imshow(masked_spec.cpu().numpy(), origin="lower", cmap='GnBu')
        plt.title("Masked Spec")
        plt.show()

    # testing pitch shift
    for _ in range(5):
        shifted_wave = random_pitch_shifter(resampled_wave.unsqueeze(0)).squeeze(0)
        shifted_spec = melspec_ap.audio_to_mel_spec_tensor(melspec_ap.resample(shifted_wave).float(),
                                                           explicit_sampling_rate=16000,
                                                           normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
        plt.imshow(shifted_spec.detach().cpu().numpy(), origin="lower", cmap='GnBu')
        plt.title("Pitch Shifted Spec")
        plt.show()