File size: 8,775 Bytes
9e275b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import os
import random
from multiprocessing import Manager
from multiprocessing import Process

import librosa
import numpy
import soundfile as sf
import torch
import torchaudio
from torch.utils.data import Dataset
from torchvision.transforms.v2 import GaussianBlur
from tqdm import tqdm

from Preprocessing.AudioPreprocessor import AudioPreprocessor


def random_pitch_shifter(x, sample_rate=24000):
    """
    Shift the pitch of a waveform by a randomly chosen number of semitone steps.

    Args:
        x: the waveform tensor that is passed to torchaudio's PitchShift transform
        sample_rate: sampling rate of x in Hz. Defaults to 24000, which is the
            default sampling rate of the dataset in this file.

    Returns:
        the output of torchaudio.transforms.PitchShift applied to x
    """
    n_steps = random.choice([-12, -9, -6, 3, 12])  # when using 12 steps per octave, these are the only ones that are pretty fast. I benchmarked it and the variance is many orders of magnitude.
    return torchaudio.transforms.PitchShift(sample_rate=sample_rate, n_steps=n_steps)(x)


def polarity_inverter(x):
    """
    Flip the sign of every value in x by multiplying with -1.
    """
    return -1 * x


class HiFiGANDataset(Dataset):
    """
    Dataset for vocoder training.

    All audios are loaded into an in-memory cache on construction. Each item is
    a randomly positioned wave segment of fixed length together with a mel
    spectrogram that is computed from the 16kHz version of that segment.
    """

    def __init__(self,
                 list_of_paths,
                 desired_samplingrate=24000,
                 samples_per_segment=12288,  # = (8192 * 3) / 2 , as I used 8192 for 16kHz previously
                 loading_processes=max((os.cpu_count() or 1) - 2, 1),  # os.cpu_count() can return None
                 use_random_corruption=False):
        """
        Build the wave cache and prepare the optional augmentations.

        Args:
            list_of_paths: paths of the audio files that should be loaded
            desired_samplingrate: sampling rate that every loaded wave is resampled to
            samples_per_segment: length in samples of the wave segments returned by __getitem__
            loading_processes: amount of processes used to fill the cache. With 1, everything is loaded in the current process.
            use_random_corruption: whether __getitem__ randomly augments the wave and distorts the input spectrogram
        """
        self.use_random_corruption = use_random_corruption
        self.samples_per_segment = samples_per_segment
        self.desired_samplingrate = desired_samplingrate
        self.melspec_ap = AudioPreprocessor(input_sr=self.desired_samplingrate,
                                            output_sr=16000,
                                            cut_silence=False)
        # hop length of spec loss should be same as the product of the upscale factors
        # samples per segment must be a multiple of hop length of spec loss
        if loading_processes == 1:
            self.waves = list()
            self.cache_builder_process(list_of_paths)
        else:
            resource_manager = Manager()
            self.waves = resource_manager.list()  # shared list, so that the worker processes can all append to the same cache
            # make processes
            path_splits = list()
            process_list = list()
            for i in range(loading_processes):
                # integer arithmetic on the boundaries makes the splits cover every path exactly once
                path_splits.append(list_of_paths[i * len(list_of_paths) // loading_processes:(i + 1) * len(
                    list_of_paths) // loading_processes])
            for path_split in path_splits:
                process_list.append(Process(target=self.cache_builder_process, args=(path_split,), daemon=True))
                process_list[-1].start()
            for process in process_list:
                process.join()
        self.blurrer = GaussianBlur(kernel_size=(5, 5), sigma=(0.5, 2.0))  # simulating the smoothness of a generated spectrogram
        # self.masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=16, iid_masks=True)  # up to 16 consecutive bands can be masked, each element in the batch gets a different mask. Taken out because it seems too extreme.
        # the identity functions in the following lists control how likely it is that nothing is changed when one entry is drawn at random
        self.spec_augs = [self.blurrer, lambda x: x, lambda x: x, lambda x: x, lambda x: x]
        self.wave_augs = [random_pitch_shifter, polarity_inverter, lambda x: x, lambda x: x, lambda x: x, lambda x: x]  # just some data augmentation
        self.wave_distortions = [CodecSimulator(), lambda x: x, lambda x: x, lambda x: x, lambda x: x]  # simulating the fact, that we train the TTS on codec-compressed waves
        print("{} eligible audios found".format(len(self.waves)))

    def cache_builder_process(self, path_split):
        """
        Read every audio in path_split, convert it to mono at the desired
        sampling rate and append it to the wave cache.

        Files that cause a RuntimeError are reported and skipped.
        """
        for path in tqdm(path_split):
            try:
                wave, sr = sf.read(path)
                if len(wave.shape) == 2:
                    # two dimensions means multiple channels, which are mixed down to one
                    wave = librosa.to_mono(numpy.transpose(wave))
                if sr != self.desired_samplingrate:
                    wave = librosa.resample(y=wave, orig_sr=sr, target_sr=self.desired_samplingrate)

                self.waves.append(wave)
            except RuntimeError:
                print(f"Problem with the following path: {path}")

    def __getitem__(self, index):
        """
        Take the cached audio at the given index and cut a random segment from it.
        All audio segments have to be cut to the same length,
        according to the NeurIPS reference implementation.

        If a RuntimeError occurs, the preceding element of the cache is used
        instead (wrapping around to the last element at index 0).

        return a pair of high-res audio and corresponding low-res spectrogram as if it was predicted by the TTS
        """
        try:
            wave = self._get_padded_wave(index)

            if self.use_random_corruption:
                # augmentations for the wave
                wave = random.choice(self.wave_augs)(wave.unsqueeze(0)).squeeze(0)  # it is intentional that this affects the target as well. This is not a distortion, but an augmentation.

            max_audio_start = len(wave) - self.samples_per_segment
            audio_start = random.randint(0, max_audio_start)
            segment = wave[audio_start: audio_start + self.samples_per_segment]

            resampled_segment = self.melspec_ap.resample(segment).float()  # 16kHz spectrogram as input, 24kHz wave as output, see Blizzard 2021 DelightfulTTS
            if self.use_random_corruption:
                # distortions for the wave, these only affect the input and not the target
                resampled_segment = random.choice(self.wave_distortions)(resampled_segment.unsqueeze(0)).squeeze(0)
            melspec = self.melspec_ap.audio_to_mel_spec_tensor(resampled_segment,
                                                               explicit_sampling_rate=16000,
                                                               normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
            if self.use_random_corruption:
                # augmentations for the spec
                melspec = random.choice(self.spec_augs)(melspec.unsqueeze(0)).squeeze(0)
            return segment.detach(), melspec.detach()
        except RuntimeError:
            print("encountered a runtime error, using fallback strategy")
            return self.__getitem__(self._fallback_index(index))

    def _get_padded_wave(self, index):
        """
        Fetch the cached wave at the given index as a tensor. Waves that are
        too short to cut a full segment from are made longer first.
        """
        wave = self.waves[index]
        while len(wave) < self.samples_per_segment + 50:  # + 50 is just to be extra sure
            # catch files that are too short to apply meaningful signal processing and make them longer
            wave = numpy.concatenate([wave, numpy.zeros(shape=1000), wave])
            # add some true silence in the mix, so the vocoder is exposed to that as well during training
        return torch.Tensor(wave)

    def _fallback_index(self, index):
        """
        Return the index of the preceding cache element, wrapping around to the last element.
        """
        return (index - 1) % len(self.waves)

    def __len__(self):
        """
        Return the amount of waves in the cache.
        """
        return len(self.waves)


class CodecSimulator(torch.nn.Module):
    """
    Sends a signal through mu-law encoding followed by mu-law decoding,
    both with 64 quantization channels.
    """

    def __init__(self):
        super().__init__()
        amount_of_channels = 64
        self.encoder = torchaudio.transforms.MuLawEncoding(quantization_channels=amount_of_channels)
        self.decoder = torchaudio.transforms.MuLawDecoding(quantization_channels=amount_of_channels)

    def forward(self, x):
        """
        Encode x and immediately decode it again.
        """
        quantized = self.encoder(x)
        reconstruction = self.decoder(quantized)
        return reconstruction


if __name__ == '__main__':
    # Visual sanity checks for the augmentations and distortions defined in this file.
    import matplotlib.pyplot as plt

    wav, sr = sf.read("../../audios/speaker_references/female_high_voice.wav")
    resampled_wave = torch.Tensor(librosa.resample(y=wav, orig_sr=sr, target_sr=24000))
    melspec_ap = AudioPreprocessor(input_sr=24000,
                                   output_sr=16000,
                                   cut_silence=False)

    spec = melspec_ap.audio_to_mel_spec_tensor(melspec_ap.resample(resampled_wave).float(),
                                               explicit_sampling_rate=16000,
                                               normalize=False).transpose(0, 1)[:-1].transpose(0, 1)

    cs = CodecSimulator()
    blurrer = GaussianBlur(kernel_size=(5, 5), sigma=(0.5, 2.0))
    masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=16, iid_masks=True)  # up to 16 consecutive bands can be masked

    # testing codec simulator
    out = cs(resampled_wave.unsqueeze(0)).squeeze(0)

    plt.plot(resampled_wave, alpha=0.5)
    plt.plot(out, alpha=0.5)
    plt.title("Codec Simulator")
    plt.show()

    # testing Gaussian blur
    blurred_spec = blurrer(spec.unsqueeze(0)).squeeze(0)
    plt.imshow(blurred_spec.cpu().numpy(), origin="lower", cmap='GnBu')
    plt.title("Blurred Spec")
    plt.show()

    # testing spectrogram masking
    for _ in range(5):
        masked_spec = masker(spec.unsqueeze(0)).squeeze(0)
        print(masked_spec)
        plt.imshow(masked_spec.cpu().numpy(), origin="lower", cmap='GnBu')
        plt.title("Masked Spec")
        plt.show()

    # testing pitch shift
    for _ in range(5):
        shifted_wave = random_pitch_shifter(resampled_wave.unsqueeze(0)).squeeze(0)
        shifted_spec = melspec_ap.audio_to_mel_spec_tensor(melspec_ap.resample(shifted_wave).float(),
                                                           explicit_sampling_rate=16000,
                                                           normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
        plt.imshow(shifted_spec.detach().cpu().numpy(), origin="lower", cmap='GnBu')
        plt.title("Pitch Shifted Spec")
        plt.show()