# antoniomae1234's picture
# changes in flenema
# 2493d72 verified
import numpy
import numpy as np
import queue
import torch
import random
from torch.utils.data import Dataset
from tqdm import tqdm
class MyDataset(Dataset):
    """Dataset that serves fixed-length mel-spectrogram segments grouped by speaker.

    ``__getitem__`` only draws a random speaker name; the audio itself is
    loaded inside :meth:`collate_fn`, which gathers ``num_utter_per_speaker``
    segments for every speaker of the batch. Recently loaded speakers are kept
    in an in-memory queue (the "storage") so they can be re-used instead of
    being read from disk again.
    """

    def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
                 storage_size=1, sample_from_storage_p=0.5, additive_noise=0,
                 num_utter_per_speaker=10, skip_speakers=False, verbose=False):
        """
        Args:
            ap: audio processor object; must expose ``sample_rate``,
                ``load_wav(filename, sr=...)`` and ``melspectrogram(wav)``.
            meta_data (list): list of dataset instances. Each instance is
                indexed as ``(text, wav_file, speaker_name)``.
            voice_len (float): voice segment length in seconds.
            num_speakers_in_batch (int): speakers per batch; together with
                ``storage_size`` it sets the storage capacity.
            storage_size (int): storage capacity in batches. ``0`` disables
                the storage.
            sample_from_storage_p (float): probability of taking a speaker
                from the (full) storage instead of loading it from disk.
            additive_noise (float): standard deviation of the gaussian noise
                added to the waveforms; ``0`` disables it.
            num_utter_per_speaker (int): utterances sampled per speaker.
            skip_speakers (bool): drop speakers with fewer than
                ``num_utter_per_speaker`` utterances.
            verbose (bool): print diagnostic information.
        """
        super().__init__()
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        self.voice_len = voice_len
        # Segment length expressed in samples.
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_speakers_in_batch = num_speakers_in_batch
        self.num_utter_per_speaker = num_utter_per_speaker
        self.skip_speakers = skip_speakers
        self.ap = ap
        self.verbose = verbose
        self.__parse_items()
        self.storage = queue.Queue(maxsize=storage_size * num_speakers_in_batch)
        self.sample_from_storage_p = float(sample_from_storage_p)
        self.additive_noise = float(additive_noise)
        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Speakers per Batch: {num_speakers_in_batch}")
            print(f" | > Storage Size: {self.storage.maxsize} speakers, each with {num_utter_per_speaker} utters")
            print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
            print(f" | > Noise added : {self.additive_noise}")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num speakers: {len(self.speakers)}")

    def load_wav(self, filename):
        """Load ``filename`` through the audio processor at its sample rate."""
        audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
        return audio

    def load_data(self, idx):
        """Load item ``idx`` and return its mel spectrogram with metadata.

        Returns:
            dict: ``{"mel": float32 array, "item_idx": wav path,
            "speaker_name": speaker}``.

        Raises:
            AssertionError: if the text or the loaded waveform is empty; the
                message is the offending wav path.
        """
        text, wav_file, speaker_name = self.items[idx]
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)

        # Validate before the (comparatively expensive) mel computation.
        # ``len`` is used for the text because it may be a plain ``str``,
        # which has no ``.size`` attribute.
        assert len(text) > 0, wav_file
        assert wav.size > 0, wav_file

        mel = self.ap.melspectrogram(wav).astype("float32")
        sample = {
            "mel": mel,
            "item_idx": wav_file,
            "speaker_name": speaker_name,
        }
        return sample

    def __parse_items(self):
        """Build ``speaker_to_utters`` (speaker -> wav paths) and ``speakers``."""
        speaker_to_utters = {}
        for item in self.items:
            speaker_to_utters.setdefault(item[2], []).append(item[1])

        if self.skip_speakers:
            # Keep only speakers that can fill a batch slot on their own.
            speaker_to_utters = {
                speaker: utters
                for speaker, utters in speaker_to_utters.items()
                if len(utters) >= self.num_utter_per_speaker
            }

        self.speaker_to_utters = speaker_to_utters
        self.speakers = list(speaker_to_utters)

    def __len__(self):
        """Return a very large constant: items are sampled randomly, not indexed."""
        return int(1e10)

    def __sample_speaker(self):
        """Draw a random speaker and ``num_utter_per_speaker`` of its utterances.

        Speakers whose utterance list has been emptied are dropped from
        ``self.speakers`` and another speaker is drawn.

        Returns:
            tuple: ``(speaker, utters)``; ``utters`` are sampled with
            replacement only when the speaker has too few utterances.

        Raises:
            ValueError: if no speaker with utterances is left.
        """
        while self.speakers:
            speaker = random.sample(self.speakers, 1)[0]
            pool = self.speaker_to_utters[speaker]
            if not pool:
                # Every utterance of this speaker was discarded as too short.
                self.speakers.remove(speaker)
                continue
            if self.num_utter_per_speaker > len(pool):
                utters = random.choices(pool, k=self.num_utter_per_speaker)
            else:
                utters = random.sample(pool, self.num_utter_per_speaker)
            return speaker, utters
        raise ValueError("No speaker with usable utterances is left to sample from.")

    def __sample_speaker_utterances(self, speaker):
        """
        Sample all M utterances for the given speaker.

        Utterances that are not longer than ``seq_len`` samples are removed
        from the speaker's list for good. When a speaker runs out of
        utterances it is removed and a random replacement speaker is used for
        the remaining draws, so the returned labels may differ from
        ``speaker``.

        Returns:
            tuple: ``(wavs, labels)`` lists of length ``num_utter_per_speaker``.
        """
        wavs = []
        labels = []
        for _ in range(self.num_utter_per_speaker):
            # TODO:dummy but works
            while True:
                pool = self.speaker_to_utters[speaker]
                if not pool:
                    # The speaker may already have been removed, e.g. when it
                    # appears more than once in the same batch.
                    if speaker in self.speakers:
                        self.speakers.remove(speaker)
                    speaker, _unused = self.__sample_speaker()
                    continue
                utter = random.sample(pool, 1)[0]
                wav = self.load_wav(utter)
                if wav.shape[0] - self.seq_len > 0:
                    break
                # Too short to cut a full segment from: never try it again.
                pool.remove(utter)
            wavs.append(wav)
            labels.append(speaker)
        return wavs, labels

    def __getitem__(self, idx):
        """Return a randomly drawn speaker name; ``idx`` is ignored."""
        speaker, _ = self.__sample_speaker()
        return speaker

    def collate_fn(self, batch):
        """Turn a batch of speaker names into mel segments and labels.

        Args:
            batch: iterable of speaker names, as produced by ``__getitem__``.

        Returns:
            tuple: ``(feats, labels)`` where ``feats`` stacks every segment's
            mel spectrogram with its last two axes swapped, and ``labels`` is
            a list holding one list of speaker names per batch entry.
        """
        labels = []
        feats = []
        for speaker in batch:
            if random.random() < self.sample_from_storage_p and self.storage.full():
                # sample from storage (if full), ignoring the speaker
                wavs_, labels_ = random.choice(self.storage.queue)
            else:
                # don't sample from storage, but from HDD
                wavs_, labels_ = self.__sample_speaker_utterances(speaker)
                # A queue with maxsize <= 0 is unbounded and never reports
                # full(), so nothing would ever be read back or evicted;
                # storing into it would only leak memory.
                if self.storage.maxsize > 0:
                    # if storage is full, remove an item
                    if self.storage.full():
                        self.storage.get_nowait()
                    # put the newly loaded item into storage
                    self.storage.put_nowait((wavs_, labels_))

            # add random gaussian noise (builds new arrays, so waveforms held
            # in the storage stay clean)
            if self.additive_noise > 0:
                wavs_ = [
                    wav + np.random.normal(0, self.additive_noise, size=len(wav))
                    for wav in wavs_
                ]

            # get a random seq_len-long window of each wav and convert it to
            # a mel spectrogram.
            offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
            mels_ = [
                self.ap.melspectrogram(wav[offset: offset + self.seq_len])
                for wav, offset in zip(wavs_, offsets_)
            ]
            feats_ = [torch.FloatTensor(mel) for mel in mels_]

            labels.append(labels_)
            feats.extend(feats_)
        feats = torch.stack(feats)
        return feats.transpose(1, 2), labels