import queue
import random

import numpy as np
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
                 storage_size=1, sample_from_storage_p=0.5, additive_noise=0,
                 num_utter_per_speaker=10, skip_speakers=False, verbose=False):
        """
        Args:
            ap (TTS.tts.utils.AudioProcessor): audio processor object.
            meta_data (list): list of dataset instances.
            voice_len (float): voice segment length in seconds.
            num_speakers_in_batch (int): number of speakers sampled per batch.
            storage_size (int): capacity of the utterance storage, in batches.
            sample_from_storage_p (float): probability of reusing utterances
                from storage instead of loading them from disk.
            additive_noise (float): stddev of Gaussian noise added to each wav;
                0 disables it.
            num_utter_per_speaker (int): number of utterances sampled per speaker.
            skip_speakers (bool): skip speakers with fewer than
                num_utter_per_speaker utterances.
            verbose (bool): print diagnostic information.
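
        Example (a sketch; assumes `ap` is the project's AudioProcessor and
        each meta_data row is a (text, wav_path, speaker_name) tuple)::

            dataset = MyDataset(ap, meta_data, voice_len=1.6,
                                num_speakers_in_batch=64,
                                num_utter_per_speaker=10)
            speaker_name = dataset[0]  # each item is a randomly sampled speaker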
        """
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        self.voice_len = voice_len
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_speakers_in_batch = num_speakers_in_batch
        self.num_utter_per_speaker = num_utter_per_speaker
        self.skip_speakers = skip_speakers
        self.ap = ap
        self.verbose = verbose
        self.__parse_items()
        # FIFO storage that caches loaded utterance groups for cheap re-sampling
        self.storage = queue.Queue(maxsize=storage_size * num_speakers_in_batch)
        self.sample_from_storage_p = float(sample_from_storage_p)
        self.additive_noise = float(additive_noise)
        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Speakers per batch: {num_speakers_in_batch}")
            print(f" | > Storage size: {self.storage.maxsize} speakers, each with {num_utter_per_speaker} utterances")
            print(f" | > Sample from storage probability: {self.sample_from_storage_p}")
            print(f" | > Noise added: {self.additive_noise}")
            print(f" | > Number of instances: {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num speakers: {len(self.speakers)}")

    def load_wav(self, filename):
        return self.ap.load_wav(filename, sr=self.ap.sample_rate)

    def load_data(self, idx):
        text, wav_file, speaker_name = self.items[idx]
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
        mel = self.ap.melspectrogram(wav).astype("float32")

        # guard against empty transcripts and empty audio files
        assert len(text) > 0, self.items[idx][1]
        assert wav.size > 0, self.items[idx][1]

        sample = {
            "mel": mel,
            "item_idx": self.items[idx][1],
            "speaker_name": speaker_name,
        }
        return sample

    def __parse_items(self):
        self.speaker_to_utters = {}
        for item in self.items:
            path_ = item[1]
            speaker_ = item[2]
            if speaker_ in self.speaker_to_utters:
                self.speaker_to_utters[speaker_].append(path_)
            else:
                self.speaker_to_utters[speaker_] = [path_]

        if self.skip_speakers:
            # drop speakers that cannot supply num_utter_per_speaker unique utterances
            self.speaker_to_utters = {
                k: v for (k, v) in self.speaker_to_utters.items()
                if len(v) >= self.num_utter_per_speaker
            }

        self.speakers = list(self.speaker_to_utters.keys())

    def __len__(self):
        # effectively infinite; each item is a freshly sampled speaker, not a fixed index
        return int(1e10)

    def __sample_speaker(self):
        speaker = random.sample(self.speakers, 1)[0]
        if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
            # not enough unique utterances left; sample with replacement
            utters = random.choices(
                self.speaker_to_utters[speaker], k=self.num_utter_per_speaker
            )
        else:
            utters = random.sample(
                self.speaker_to_utters[speaker], self.num_utter_per_speaker
            )
        return speaker, utters

    def __sample_speaker_utterances(self, speaker):
        """
        Sample num_utter_per_speaker utterances for the given speaker.
        """
        wavs = []
        labels = []
        for _ in range(self.num_utter_per_speaker):
            # keep drawing until we find an utterance long enough for a full segment
            while True:
                if len(self.speaker_to_utters[speaker]) > 0:
                    utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
                else:
                    # this speaker has no usable utterances left; switch to another one
                    self.speakers.remove(speaker)
                    speaker, _ = self.__sample_speaker()
                    continue
                wav = self.load_wav(utter)
                if wav.shape[0] - self.seq_len > 0:
                    break
                # utterance is shorter than seq_len; drop it from the pool
                self.speaker_to_utters[speaker].remove(utter)

            wavs.append(wav)
            labels.append(speaker)
        return wavs, labels

    def __getitem__(self, idx):
        # idx is ignored; every item is a randomly drawn speaker name
        speaker, _ = self.__sample_speaker()
        return speaker

    def collate_fn(self, batch):
        labels = []
        feats = []
        for speaker in batch:
            if random.random() < self.sample_from_storage_p and self.storage.full():
                # reuse a cached utterance group instead of hitting the disk
                wavs_, labels_ = random.choice(self.storage.queue)
            else:
                # load fresh utterances for this speaker and cache them
                wavs_, labels_ = self.__sample_speaker_utterances(speaker)
                if self.storage.full():
                    _ = self.storage.get_nowait()
                self.storage.put_nowait((wavs_, labels_))

            # optionally corrupt the wavs with additive Gaussian noise
            if self.additive_noise > 0:
                noises_ = [np.random.normal(0, self.additive_noise, size=len(w)) for w in wavs_]
                wavs_ = [wavs_[i] + noises_[i] for i in range(len(wavs_))]

            # crop a random fixed-length segment from each wav and compute its mel
            offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
            mels_ = [self.ap.melspectrogram(wavs_[i][offsets_[i]: offsets_[i] + self.seq_len]) for i in range(len(wavs_))]
            feats_ = [torch.FloatTensor(mel) for mel in mels_]

            labels.append(labels_)
            feats.extend(feats_)
        feats = torch.stack(feats)
        return feats.transpose(1, 2), labels
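

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): wiring the dataset
    # into a torch DataLoader. `AudioProcessor` and `load_meta_data` below are
    # hypothetical stand-ins for however the surrounding project builds `ap` and
    # `meta_data`; each meta_data row is assumed to be a
    # (text, wav_path, speaker_name) tuple.
    from torch.utils.data import DataLoader

    ap = AudioProcessor()         # hypothetical: exposes sample_rate, load_wav(), melspectrogram()
    meta_data = load_meta_data()  # hypothetical: returns the dataset rows

    dataset = MyDataset(ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
                        num_utter_per_speaker=10, skip_speakers=True)
    loader = DataLoader(dataset, batch_size=dataset.num_speakers_in_batch,
                        collate_fn=dataset.collate_fn)

    feats, labels = next(iter(loader))
    # feats shape: [num_speakers_in_batch * num_utter_per_speaker, frames, num_mels]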