import queue
import random

import numpy as np
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
                 storage_size=1, sample_from_storage_p=0.5, additive_noise=0,
                 num_utter_per_speaker=10, skip_speakers=False, verbose=False):
        """
        Args:
            ap (TTS.tts.utils.AudioProcessor): audio processor object.
            meta_data (list): list of dataset instances.
            seq_len (int): voice segment length in seconds.
            verbose (bool): print diagnostic information.
        """
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        self.voice_len = voice_len
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_speakers_in_batch = num_speakers_in_batch
        self.num_utter_per_speaker = num_utter_per_speaker
        self.skip_speakers = skip_speakers
        self.ap = ap
        self.verbose = verbose
        self.__parse_items()
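        # in-memory cache of recently loaded (wavs, labels) tuples; batches reuse it
        # with probability `sample_from_storage_p` instead of re-reading wavs from disk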
        self.storage = queue.Queue(maxsize=storage_size*num_speakers_in_batch)
        self.sample_from_storage_p = float(sample_from_storage_p)
        self.additive_noise = float(additive_noise)
        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Speakers per Batch: {num_speakers_in_batch}")
            print(f" | > Storage Size: {self.storage.maxsize} speakers, each with {num_utter_per_speaker} utters")
            print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
            print(f" | > Noise added : {self.additive_noise}")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num speakers: {len(self.speakers)}")

    def load_wav(self, filename):
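        """Load a wav file with the audio processor at the dataset sample rate."""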
        audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
        return audio

    def load_data(self, idx):
        text, wav_file, speaker_name = self.items[idx]
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
        mel = self.ap.melspectrogram(wav).astype("float32")
        assert len(text) > 0, self.items[idx][1]
        assert wav.size > 0, self.items[idx][1]

        sample = {
            "mel": mel,
            "item_idx": self.items[idx][1],
            "speaker_name": speaker_name,
        }
        return sample

    def __parse_items(self):
        """Find unique speaker ids and map each speaker to its utterance paths."""
        self.speaker_to_utters = {}
        for item in self.items:
            path_ = item[1]
            speaker_ = item[2]
            if speaker_ in self.speaker_to_utters:
                self.speaker_to_utters[speaker_].append(path_)
            else:
                self.speaker_to_utters[speaker_] = [path_]

        if self.skip_speakers:
            self.speaker_to_utters = {
                k: v for (k, v) in self.speaker_to_utters.items()
                if len(v) >= self.num_utter_per_speaker
            }

        self.speakers = list(self.speaker_to_utters.keys())

    def __len__(self):
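        # speakers are drawn at random every step, so report an effectively infinite length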
        return int(1e10)

    def __sample_speaker(self):
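        """Pick a random speaker and sample `num_utter_per_speaker` utterance paths,
        sampling with replacement when the speaker has fewer utterances than requested."""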
        speaker = random.sample(self.speakers, 1)[0]
        if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
            utters = random.choices(
                self.speaker_to_utters[speaker], k=self.num_utter_per_speaker
            )
        else:
            utters = random.sample(
                self.speaker_to_utters[speaker], self.num_utter_per_speaker
            )
        return speaker, utters

    def __sample_speaker_utterances(self, speaker):
        """
        Sample all M utterances for the given speaker.
        """
        wavs = []
        labels = []
        for _ in range(self.num_utter_per_speaker):
            # keep sampling until an utterance longer than `seq_len` is found; speakers
            # that run out of usable utterances are dropped and a new speaker is drawn
            while True:
                if len(self.speaker_to_utters[speaker]) > 0:
                    utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
                else:
                    self.speakers.remove(speaker)
                    speaker, _ = self.__sample_speaker()
                    continue
                wav = self.load_wav(utter)
                if wav.shape[0] - self.seq_len > 0:
                    break
                # utterance too short for a full segment; discard it and try another
                self.speaker_to_utters[speaker].remove(utter)

            wavs.append(wav)
            labels.append(speaker)
        return wavs, labels

    def __getitem__(self, idx):
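        """Return a speaker name only; the audio itself is loaded in `collate_fn`."""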
        speaker, _ = self.__sample_speaker()
        return speaker

    def collate_fn(self, batch):
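        """Load `num_utter_per_speaker` utterances for every speaker in `batch` and
        return mel features shaped [num_speakers * num_utter, T, num_mels] together
        with the matching speaker labels."""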
        labels = []
        feats = []
        for speaker in batch:
            if random.random() < self.sample_from_storage_p and self.storage.full():
                # sample from storage (if full), ignoring the speaker
                wavs_, labels_ = random.choice(self.storage.queue)
            else:
                # don't sample from storage, but from HDD
                wavs_, labels_ = self.__sample_speaker_utterances(speaker)
                # if storage is full, remove an item
                if self.storage.full():
                    _ = self.storage.get_nowait()
                # put the newly loaded item into storage
                self.storage.put_nowait((wavs_, labels_))

            # add random gaussian noise
            if self.additive_noise > 0:
                noises_ = [np.random.normal(0, self.additive_noise, size=len(w)) for w in wavs_]
                wavs_ = [wavs_[i] + noises_[i] for i in range(len(wavs_))]

            # take a random seq_len window from each wav and convert it to a mel spectrogram
            offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
            mels_ = [self.ap.melspectrogram(wavs_[i][offsets_[i]: offsets_[i] + self.seq_len]) for i in range(len(wavs_))]
            feats_ = [torch.FloatTensor(mel) for mel in mels_]

            labels.append(labels_)
            feats.extend(feats_)
        feats = torch.stack(feats)
        return feats.transpose(1, 2), labels
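

# A minimal usage sketch, assuming an AudioProcessor instance `ap` and `meta_data`
# entries of the form (text, wav_path, speaker_name) as consumed by `load_data`:
#
#   dataset = MyDataset(ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
#                       num_utter_per_speaker=10, skip_speakers=True, verbose=True)
#   loader = torch.utils.data.DataLoader(dataset,
#                                        batch_size=dataset.num_speakers_in_batch,
#                                        collate_fn=dataset.collate_fn)
#   feats, labels = next(iter(loader))  # feats: [64 * 10, T, num_mels]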