import queue
import random

import numpy as np
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
                 storage_size=1, sample_from_storage_p=0.5, additive_noise=0,
                 num_utter_per_speaker=10, skip_speakers=False, verbose=False):
        """
        Args:
            ap (TTS.tts.utils.AudioProcessor): audio processor object.
            meta_data (list): list of dataset instances.
            voice_len (float): voice segment length in seconds.
            verbose (bool): print diagnostic information.
        """
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        self.voice_len = voice_len
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_speakers_in_batch = num_speakers_in_batch
        self.num_utter_per_speaker = num_utter_per_speaker
        self.skip_speakers = skip_speakers
        self.ap = ap
        self.verbose = verbose
        self.__parse_items()
        self.storage = queue.Queue(maxsize=storage_size * num_speakers_in_batch)
        self.sample_from_storage_p = float(sample_from_storage_p)
        self.additive_noise = float(additive_noise)
        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Speakers per Batch: {num_speakers_in_batch}")
            print(f" | > Storage Size: {self.storage.maxsize} speakers, each with {num_utter_per_speaker} utters")
            print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
            print(f" | > Noise added : {self.additive_noise}")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num speakers: {len(self.speakers)}")

    def load_wav(self, filename):
        audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
        return audio

    def load_data(self, idx):
        text, wav_file, speaker_name = self.items[idx]
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
        mel = self.ap.melspectrogram(wav).astype("float32")
        # sample seq_len
        assert len(text) > 0, self.items[idx][1]
        assert wav.size > 0, self.items[idx][1]
        sample = {
            "mel": mel,
            "item_idx": self.items[idx][1],
            "speaker_name": speaker_name,
        }
        return sample

    def __parse_items(self):
        """Map each speaker id to its list of utterance paths and, if requested,
        drop speakers with fewer than `num_utter_per_speaker` utterances."""
        self.speaker_to_utters = {}
        for i in self.items:
            path_ = i[1]
            speaker_ = i[2]
            if speaker_ in self.speaker_to_utters.keys():
                self.speaker_to_utters[speaker_].append(path_)
            else:
                self.speaker_to_utters[speaker_] = [path_]

        if self.skip_speakers:
            self.speaker_to_utters = {
                k: v for (k, v) in self.speaker_to_utters.items()
                if len(v) >= self.num_utter_per_speaker
            }

        self.speakers = list(self.speaker_to_utters.keys())

    # def __parse_items(self):
    #     """
    #     Find unique speaker ids and create a dict mapping utterances from speaker id
    #     """
    #     speakers = list({item[-1] for item in self.items})
    #     self.speaker_to_utters = {}
    #     self.speakers = []
    #     for speaker in speakers:
    #         speaker_utters = [item[1] for item in self.items if item[2] == speaker]
    #         if len(speaker_utters) < self.num_utter_per_speaker and self.skip_speakers:
    #             print(
    #                 f" [!] Skipped speaker {speaker}. Not enough utterances {self.num_utter_per_speaker} vs {len(speaker_utters)}."
    #             )
    #         else:
    #             self.speakers.append(speaker)
    #             self.speaker_to_utters[speaker] = speaker_utters

    def __len__(self):
        return int(1e10)

    def __sample_speaker(self):
        speaker = random.sample(self.speakers, 1)[0]
        if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
            utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
        else:
            utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker)
        return speaker, utters

    def __sample_speaker_utterances(self, speaker):
        """
        Sample all M utterances for the given speaker.
        """
        wavs = []
        labels = []
        for _ in range(self.num_utter_per_speaker):
            # TODO: dummy but works
            while True:
                if len(self.speaker_to_utters[speaker]) > 0:
                    utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
                else:
                    # this speaker has no usable utterances left; draw a new speaker
                    self.speakers.remove(speaker)
                    speaker, _ = self.__sample_speaker()
                    continue
                wav = self.load_wav(utter)
                if wav.shape[0] - self.seq_len > 0:
                    break
                # utterance is shorter than seq_len; drop it and retry
                self.speaker_to_utters[speaker].remove(utter)
            wavs.append(wav)
            labels.append(speaker)
        return wavs, labels

    def __getitem__(self, idx):
        speaker, _ = self.__sample_speaker()
        return speaker

    def collate_fn(self, batch):
        labels = []
        feats = []
        for speaker in batch:
            if random.random() < self.sample_from_storage_p and self.storage.full():
                # sample from storage (if full), ignoring the speaker
                wavs_, labels_ = random.choice(self.storage.queue)
            else:
                # don't sample from storage, but from disk
                wavs_, labels_ = self.__sample_speaker_utterances(speaker)
                # if storage is full, remove an item
                if self.storage.full():
                    _ = self.storage.get_nowait()
                # put the newly loaded item into storage
                self.storage.put_nowait((wavs_, labels_))

            # add random gaussian noise
            if self.additive_noise > 0:
                noises_ = [np.random.normal(0, self.additive_noise, size=len(w)) for w in wavs_]
                wavs_ = [wavs_[i] + noises_[i] for i in range(len(wavs_))]

            # take a random fixed-length segment of each wav and convert it to a mel spectrogram
            offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
            mels_ = [self.ap.melspectrogram(wavs_[i][offsets_[i]: offsets_[i] + self.seq_len]) for i in range(len(wavs_))]
            feats_ = [torch.FloatTensor(mel) for mel in mels_]

            labels.append(labels_)
            feats.extend(feats_)
        feats = torch.stack(feats)
        return feats.transpose(1, 2), labels
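

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module). It assumes
# an audio processor `ap` exposing `sample_rate`, `load_wav()` and
# `melspectrogram()` as used above, and that `meta_data` items look like
# (text, wav_path, speaker_name); both names are placeholders for your setup.
#
#   from torch.utils.data import DataLoader
#
#   dataset = MyDataset(ap, meta_data, voice_len=1.6,
#                       num_speakers_in_batch=64, num_utter_per_speaker=10)
#   loader = DataLoader(dataset,
#                       batch_size=dataset.num_speakers_in_batch,
#                       collate_fn=dataset.collate_fn,
#                       num_workers=0)
#   feats, labels = next(iter(loader))
#   # feats: [num_speakers_in_batch * num_utter_per_speaker, T, num_mels]
#   # labels: one list of speaker names per sampled speaker
# ---------------------------------------------------------------------------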