File size: 6,882 Bytes
2493d72 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import numpy
import numpy as np
import queue
import torch
import random
from torch.utils.data import Dataset
from tqdm import tqdm
class MyDataset(Dataset):
    """Endless, speaker-balanced dataset of fixed-length utterance segments.

    ``__getitem__`` ignores its index and only returns a randomly chosen
    speaker name. The audio itself is loaded in :meth:`collate_fn`, which
    turns a batch of speaker names into mel features, optionally re-using
    recently loaded speakers from an in-memory storage queue.
    """

    def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
                 storage_size=1, sample_from_storage_p=0.5, additive_noise=0,
                 num_utter_per_speaker=10, skip_speakers=False, verbose=False):
        """
        Args:
            ap (TTS.tts.utils.AudioProcessor): audio processor object; must
                provide ``sample_rate``, ``load_wav`` and ``melspectrogram``.
            meta_data (list): list of dataset instances; index 1 of each
                instance is the wav path and index 2 the speaker name.
            voice_len (float): voice segment length in seconds.
            num_speakers_in_batch (int): speakers per batch; together with
                ``storage_size`` it sets the capacity of the storage queue.
            storage_size (int): storage capacity in batches. Values <= 0
                disable the storage entirely.
            sample_from_storage_p (float): probability of serving a batch
                entry from the storage (only once the storage is full).
            additive_noise (float): standard deviation of the gaussian noise
                added to every waveform; 0 disables it.
            num_utter_per_speaker (int): utterances drawn per speaker.
            skip_speakers (bool): drop speakers that have fewer than
                ``num_utter_per_speaker`` utterances.
            verbose (bool): print diagnostic information.
        """
        super().__init__()
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        self.voice_len = voice_len
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_speakers_in_batch = num_speakers_in_batch
        self.num_utter_per_speaker = num_utter_per_speaker
        self.skip_speakers = skip_speakers
        self.ap = ap
        self.verbose = verbose
        self.__parse_items()
        # Each queue entry holds the loaded utterances of ONE speaker.
        self.storage = queue.Queue(maxsize=storage_size * num_speakers_in_batch)
        self.sample_from_storage_p = float(sample_from_storage_p)
        self.additive_noise = float(additive_noise)

        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Speakers per Batch: {num_speakers_in_batch}")
            print(f" | > Storage Size: {self.storage.maxsize} speakers, each with {num_utter_per_speaker} utters")
            print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
            print(f" | > Noise added : {self.additive_noise}")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num speakers: {len(self.speakers)}")

    def load_wav(self, filename):
        """Load ``filename`` through the audio processor at its sample rate."""
        return self.ap.load_wav(filename, sr=self.ap.sample_rate)

    def load_data(self, idx):
        """Load item ``idx`` and return its mel spectrogram with metadata.

        Returns:
            dict: ``mel`` (float32 spectrogram), ``item_idx`` (the wav path)
            and ``speaker_name``.
        """
        text, wav_file, speaker_name = self.items[idx]
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
        mel = self.ap.melspectrogram(wav).astype("float32")

        # `len` works for plain strings as well as arrays; `.size` does not
        # exist on `str` and made this check fail for textual metadata.
        assert len(text) > 0, wav_file
        assert wav.size > 0, wav_file

        return {
            "mel": mel,
            "item_idx": wav_file,
            "speaker_name": speaker_name,
        }

    def __parse_items(self):
        """Group wav paths by speaker and build the list of usable speakers."""
        self.speaker_to_utters = {}
        for item in self.items:
            self.speaker_to_utters.setdefault(item[2], []).append(item[1])

        if self.skip_speakers:
            self.speaker_to_utters = {
                speaker: utters
                for speaker, utters in self.speaker_to_utters.items()
                if len(utters) >= self.num_utter_per_speaker
            }

        self.speakers = list(self.speaker_to_utters)

    def __len__(self):
        # Sampling is random and index-independent, so the dataset is
        # reported as (practically) endless.
        return int(1e10)

    def __sample_speaker(self):
        """Pick a random speaker that still has utterances, plus M of them.

        Speakers whose utterance list is empty are dropped from
        ``self.speakers`` and another speaker is drawn.

        Raises:
            ValueError: if no speaker with remaining utterances is left.
        """
        while self.speakers:
            speaker = random.sample(self.speakers, 1)[0]
            pool = self.speaker_to_utters.get(speaker)
            if pool:
                break
            # All utterances of this speaker were pruned; never offer it again.
            self.speakers.remove(speaker)
        else:
            raise ValueError(
                "MyDataset: no speaker with remaining utterances left to sample from."
            )

        if self.num_utter_per_speaker > len(pool):
            # Not enough distinct utterances: sample with replacement.
            utters = random.choices(pool, k=self.num_utter_per_speaker)
        else:
            utters = random.sample(pool, self.num_utter_per_speaker)
        return speaker, utters

    def __sample_speaker_utterances(self, speaker):
        """
        Sample all M utterances for the given speaker.

        Utterances not longer than ``seq_len`` samples are removed from the
        speaker's pool. A speaker that runs out of utterances is removed from
        ``self.speakers`` and replaced by a freshly sampled one.

        Returns:
            tuple: ``(wavs, labels)`` - the loaded waveforms and, for each of
            them, the name of the speaker it was taken from.
        """
        wavs = []
        labels = []
        for _ in range(self.num_utter_per_speaker):
            while True:
                pool = self.speaker_to_utters.get(speaker)
                if not pool:
                    # The speaker may already have been removed earlier
                    # (e.g. it occurs twice in the same batch).
                    if speaker in self.speakers:
                        self.speakers.remove(speaker)
                    speaker, _unused = self.__sample_speaker()
                    continue
                utter = random.sample(pool, 1)[0]
                wav = self.load_wav(utter)
                if wav.shape[0] - self.seq_len > 0:
                    break
                # Too short to cut a full segment from: never try it again.
                pool.remove(utter)

            wavs.append(wav)
            labels.append(speaker)
        return wavs, labels

    def __getitem__(self, idx):
        """Return a random speaker name; ``idx`` is ignored."""
        speaker, _ = self.__sample_speaker()
        return speaker

    def collate_fn(self, batch):
        """Turn a batch of speaker names into stacked mel features.

        Args:
            batch (list): speaker names as produced by ``__getitem__``.

        Returns:
            tuple: ``(feats, labels)``. ``feats`` stacks one spectrogram per
            sampled utterance and swaps its last two axes; ``labels`` holds
            one list of speaker names per batch entry.
        """
        labels = []
        feats = []
        for speaker in batch:
            if random.random() < self.sample_from_storage_p and self.storage.full():
                # sample from storage (if full), ignoring the speaker
                wavs_, labels_ = random.choice(self.storage.queue)
            else:
                # don't sample from storage, but from HDD
                wavs_, labels_ = self.__sample_speaker_utterances(speaker)
                # A queue with maxsize <= 0 is unbounded and never "full";
                # storing into it would grow memory without limit.
                if self.storage.maxsize > 0:
                    # if storage is full, remove an item
                    if self.storage.full():
                        self.storage.get_nowait()
                    # put the newly loaded item into storage
                    self.storage.put_nowait((wavs_, labels_))

            # add random gaussian noise (on copies, stored wavs stay clean)
            if self.additive_noise > 0:
                noises_ = [np.random.normal(0, self.additive_noise, size=len(w)) for w in wavs_]
                wavs_ = [w + noise for w, noise in zip(wavs_, noises_)]

            # get a random subset of each of the wavs and convert to MFCC.
            offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
            mels_ = [
                self.ap.melspectrogram(wav[offset: offset + self.seq_len])
                for wav, offset in zip(wavs_, offsets_)
            ]
            feats_ = [torch.FloatTensor(mel) for mel in mels_]

            labels.append(labels_)
            feats.extend(feats_)
        feats = torch.stack(feats)
        return feats.transpose(1, 2), labels
|