import os
import random
from multiprocessing import Manager
from multiprocessing import Process

import librosa
import numpy
import soundfile as sf
import torch
import torchaudio
from torch.utils.data import Dataset
from torchvision.transforms.v2 import GaussianBlur
from tqdm import tqdm

from Preprocessing.AudioPreprocessor import AudioPreprocessor


def random_pitch_shifter(x):
    # When using 12 steps per octave, these step sizes are the only ones that are reasonably fast.
    # I benchmarked it, and the runtime varies by several orders of magnitude between step sizes.
    n_steps = random.choice([-12, -9, -6, 3, 12])
    return torchaudio.transforms.PitchShift(sample_rate=24000, n_steps=n_steps)(x)


def polarity_inverter(x):
    return x * -1


class HiFiGANDataset(Dataset):

    def __init__(self,
                 list_of_paths,
                 desired_samplingrate=24000,
                 samples_per_segment=12288,  # = (8192 * 3) / 2, since 8192 was used for 16kHz previously
                 loading_processes=max(os.cpu_count() - 2, 1),
                 use_random_corruption=False):
        self.use_random_corruption = use_random_corruption
        self.samples_per_segment = samples_per_segment
        self.desired_samplingrate = desired_samplingrate
        self.melspec_ap = AudioPreprocessor(input_sr=self.desired_samplingrate,
                                            output_sr=16000,
                                            cut_silence=False)
        # the hop length of the spectrogram loss must equal the product of the upscale factors,
        # and samples_per_segment must be a multiple of that hop length
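        # Minimal sanity check for the constraint above, assuming a spectrogram hop length of
        # 256 frames at 16kHz. NOTE: the 256 is an assumption for illustration only; it is not
        # read from AudioPreprocessor. With the defaults: 12288 samples at 24kHz correspond to
        # 8192 samples at 16kHz, and 8192 % 256 == 0.
        assumed_spec_hop_length = 256
        assert (samples_per_segment * 16000 // desired_samplingrate) % assumed_spec_hop_length == 0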
        if loading_processes == 1:
            self.waves = list()
            self.cache_builder_process(list_of_paths)
        else:
            resource_manager = Manager()
            self.waves = resource_manager.list()
            # split the paths into one chunk per process and build the cache in parallel
            path_splits = list()
            process_list = list()
            for i in range(loading_processes):
                path_splits.append(list_of_paths[i * len(list_of_paths) // loading_processes:
                                                 (i + 1) * len(list_of_paths) // loading_processes])
            for path_split in path_splits:
                process_list.append(Process(target=self.cache_builder_process, args=(path_split,), daemon=True))
                process_list[-1].start()
            for process in process_list:
                process.join()
        self.blurrer = GaussianBlur(kernel_size=(5, 5), sigma=(0.5, 2.0))  # simulates the smoothness of a generated spectrogram
        # self.masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=16, iid_masks=True)  # up to 16 consecutive bands can be masked, and each element in the batch gets a different mask. Taken out because it seems too extreme.
        self.spec_augs = [self.blurrer, lambda x: x, lambda x: x, lambda x: x, lambda x: x]
        self.wave_augs = [random_pitch_shifter, polarity_inverter, lambda x: x, lambda x: x, lambda x: x, lambda x: x]  # just some data augmentation
        self.wave_distortions = [CodecSimulator(), lambda x: x, lambda x: x, lambda x: x, lambda x: x]  # simulates the fact that the TTS is trained on codec-compressed waves
        print("{} eligible audios found".format(len(self.waves)))

    def cache_builder_process(self, path_split):
        """
        Load, mono-mix and resample every audio in the given path split and append it to the shared cache.
        """
        for path in tqdm(path_split):
            try:
                wave, sr = sf.read(path)
                if len(wave.shape) == 2:
                    wave = librosa.to_mono(numpy.transpose(wave))
                if sr != self.desired_samplingrate:
                    wave = librosa.resample(y=wave, orig_sr=sr, target_sr=self.desired_samplingrate)
                self.waves.append(wave)
            except RuntimeError:
                print(f"Problem with the following path: {path}")

    def __getitem__(self, index):
        """
        Load the audio from the cache and clean it.

        All audio segments have to be cut to the same length,
        according to the NeurIPS reference implementation.

        Returns a pair of a high-resolution audio segment and the corresponding low-resolution
        spectrogram, as if the spectrogram had been predicted by the TTS.
        """
        try:
            wave = self.waves[index]
            while len(wave) < self.samples_per_segment + 50:  # + 50 is just to be extra sure
                # catch files that are too short for meaningful signal processing and lengthen them,
                # adding some true silence into the mix, so the vocoder is exposed to that as well during training
                wave = numpy.concatenate([wave, numpy.zeros(shape=1000), wave])
            wave = torch.Tensor(wave)
            if self.use_random_corruption:
                # augmentations for the wave; it is intentional that this affects the target as well,
                # because this is an augmentation, not a distortion
                wave = random.choice(self.wave_augs)(wave.unsqueeze(0)).squeeze(0)
            max_audio_start = len(wave) - self.samples_per_segment
            audio_start = random.randint(0, max_audio_start)
            segment = wave[audio_start: audio_start + self.samples_per_segment]
            resampled_segment = self.melspec_ap.resample(segment).float()  # 16kHz spectrogram as input, 24kHz wave as output, see Blizzard 2021 DelightfulTTS
            if self.use_random_corruption:
                # distortions for the wave the spectrogram is computed from; the target segment stays clean
                resampled_segment = random.choice(self.wave_distortions)(resampled_segment.unsqueeze(0)).squeeze(0)
            melspec = self.melspec_ap.audio_to_mel_spec_tensor(resampled_segment,
                                                               explicit_sampling_rate=16000,
                                                               normalize=False).transpose(0, 1)[:-1].transpose(0, 1)  # drop the final frame to match the segment length
            if self.use_random_corruption:
                # augmentations for the spec
                melspec = random.choice(self.spec_augs)(melspec.unsqueeze(0)).squeeze(0)
            return segment.detach(), melspec.detach()
        except RuntimeError:
            # if anything goes wrong during processing, fall back to a different index
            print("encountered a runtime error, using fallback strategy")
            if index == 0:
                index = len(self.waves) - 1
            return self.__getitem__(index - 1)

    def __len__(self):
        return len(self.waves)


class CodecSimulator(torch.nn.Module):
    """
    Simulates codec compression artifacts through a µ-law encode/decode round trip with 64 quantization channels.
    """

    def __init__(self):
        super().__init__()
        self.encoder = torchaudio.transforms.MuLawEncoding(quantization_channels=64)
        self.decoder = torchaudio.transforms.MuLawDecoding(quantization_channels=64)

    def forward(self, x):
        return self.decoder(self.encoder(x))


if __name__ == '__main__':
    import matplotlib.pyplot as plt

    wav, sr = sf.read("../../audios/speaker_references/female_high_voice.wav")
    resampled_wave = torch.Tensor(librosa.resample(y=wav, orig_sr=sr, target_sr=24000))
    melspec_ap = AudioPreprocessor(input_sr=24000,
                                   output_sr=16000,
                                   cut_silence=False)
    spec = melspec_ap.audio_to_mel_spec_tensor(melspec_ap.resample(resampled_wave).float(),
                                               explicit_sampling_rate=16000,
                                               normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
    cs = CodecSimulator()
    blurrer = GaussianBlur(kernel_size=(5, 5), sigma=(0.5, 2.0))
    masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=16, iid_masks=True)  # up to 16 consecutive bands can be masked
    # testing codec simulator
    out = cs(resampled_wave.unsqueeze(0)).squeeze(0)
    plt.plot(resampled_wave, alpha=0.5)
    plt.plot(out, alpha=0.5)
    plt.title("Codec Simulator")
    plt.show()

    # testing Gaussian blur
    blurred_spec = blurrer(spec.unsqueeze(0)).squeeze(0)
    plt.imshow(blurred_spec.cpu().numpy(), origin="lower", cmap='GnBu')
    plt.title("Blurred Spec")
    plt.show()

    # testing spectrogram masking
    for _ in range(5):
        masked_spec = masker(spec.unsqueeze(0)).squeeze(0)
        print(masked_spec)
        plt.imshow(masked_spec.cpu().numpy(), origin="lower", cmap='GnBu')
        plt.title("Masked Spec")
        plt.show()

    # testing pitch shift
    for _ in range(5):
        shifted_wave = random_pitch_shifter(resampled_wave.unsqueeze(0)).squeeze(0)
        shifted_spec = melspec_ap.audio_to_mel_spec_tensor(melspec_ap.resample(shifted_wave).float(),
                                                           explicit_sampling_rate=16000,
                                                           normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
        plt.imshow(shifted_spec.detach().cpu().numpy(), origin="lower", cmap='GnBu')
        plt.title("Pitch Shifted Spec")
        plt.show()
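
    # A minimal end-to-end usage sketch: wrapping the dataset in a standard torch DataLoader.
    # NOTE: the single reference wav, batch_size=1 and loading_processes=1 are assumptions purely
    # for illustration; real training would pass a full corpus of paths here.
    from torch.utils.data import DataLoader

    dataset = HiFiGANDataset(list_of_paths=["../../audios/speaker_references/female_high_voice.wav"],
                             loading_processes=1,
                             use_random_corruption=True)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    for wave_segment, melspec in loader:
        # every batch holds fixed-size pairs of a high-res wave segment and its low-res spectrogram
        print(wave_segment.shape, melspec.shape)
        break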