namelessai
committed on
Upload 2 files
- utilities/data/add_on.py +508 -0
- utilities/data/dataset.py +518 -0
utilities/data/add_on.py
ADDED
@@ -0,0 +1,508 @@
import os
import torch
import numpy as np
import torchaudio
import matplotlib.pyplot as plt

CACHE = {
    "get_vits_phoneme_ids": {
        "PAD_LENGTH": 310,
        "_pad": "_",
        "_punctuation": ';:,.!?¡¿—…"«»“” ',
        "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
        "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
        "_special": "♪☎☒☝⚠",
    }
}

CACHE["get_vits_phoneme_ids"]["symbols"] = (
    [CACHE["get_vits_phoneme_ids"]["_pad"]]
    + list(CACHE["get_vits_phoneme_ids"]["_punctuation"])
    + list(CACHE["get_vits_phoneme_ids"]["_letters"])
    + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"])
    + list(CACHE["get_vits_phoneme_ids"]["_special"])
)
CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = {
    s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"])
}


def get_vits_phoneme_ids(config, dl_output, metadata):
    pad_token_id = 0
    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide vits phonemes on using addon get_vits_phoneme_ids"
    clean_text = metadata["phonemes"]
    sequence = []

    for symbol in clean_text:
        symbol_id = _symbol_to_id[symbol]
        sequence += [symbol_id]

    inserted_zero_sequence = [0] * (len(sequence) * 2)
    inserted_zero_sequence[1::2] = sequence
    inserted_zero_sequence = inserted_zero_sequence + [0]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(inserted_zero_sequence))}


def get_vits_phoneme_ids_no_padding(config, dl_output, metadata):
    pad_token_id = 0
    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide vits phonemes on using addon get_vits_phoneme_ids"
    clean_text = metadata["phonemes"] + "⚠"
    sequence = []

    for symbol in clean_text:
        if symbol not in _symbol_to_id.keys():
            print("%s is not in the vocabulary. %s" % (symbol, clean_text))
            symbol = "_"
        symbol_id = _symbol_to_id[symbol]
        sequence += [symbol_id]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))

    sequence = sequence[:pad_length]

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(sequence))}


def calculate_relative_bandwidth(config, dl_output, metadata):
    assert "stft" in dl_output.keys()

    # The last dimension of the stft feature is the frequency dimension
    freq_dimensions = dl_output["stft"].size(-1)

    freq_energy_dist = torch.sum(dl_output["stft"], dim=0)
    freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
    total_energy = freq_energy_dist[-1]

    percentile_5th = total_energy * 0.05
    percentile_95th = total_energy * 0.95

    lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
    higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))

    lower_idx = int((lower_idx / freq_dimensions) * 1000)
    higher_idx = int((higher_idx / freq_dimensions) * 1000)

    return {"freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx])}


def calculate_mel_spec_relative_bandwidth_as_extra_channel(config, dl_output, metadata):
    assert "stft" in dl_output.keys()
    linear_mel_spec = torch.exp(torch.clip(dl_output["log_mel_spec"], max=10))

    # The last dimension of the stft feature is the frequency dimension
    freq_dimensions = linear_mel_spec.size(-1)
    freq_energy_dist = torch.sum(linear_mel_spec, dim=0)
    freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
    total_energy = freq_energy_dist[-1]

    percentile_5th = total_energy * 0.05
    percentile_95th = total_energy * 0.95

    lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
    higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))

    latent_t_size = config["model"]["params"]["latent_t_size"]
    latent_f_size = config["model"]["params"]["latent_f_size"]

    lower_idx = int(latent_f_size * float(lower_idx / freq_dimensions))
    higher_idx = int(latent_f_size * float(higher_idx / freq_dimensions))

    bandwidth_condition = torch.zeros((latent_t_size, latent_f_size))
    bandwidth_condition[:, lower_idx:higher_idx] += 1.0

    return {
        "mel_spec_bandwidth_cond_extra_channel": bandwidth_condition,
        "freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx]),
    }


def waveform_rs_48k(config, dl_output, metadata):
    waveform = dl_output["waveform"]  # [1, samples]
    sampling_rate = dl_output["sampling_rate"]

    if sampling_rate != 48000:
        waveform_48k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=48000
        )
    else:
        waveform_48k = waveform

    return {"waveform_48k": waveform_48k}


def extract_vits_phoneme_and_flant5_text(config, dl_output, metadata):
    assert (
        "phoneme" not in metadata.keys()
    ), "The metadata of speech you use seems belong to fastspeech. Please check dataset_root.json"

    if "phonemes" in metadata.keys():
        new_item = get_vits_phoneme_ids_no_padding(config, dl_output, metadata)
        new_item["text"] = ""  # We assume TTS data does not have text description
    else:
        fake_metadata = {"phonemes": ""}  # Add empty phoneme sequence
        new_item = get_vits_phoneme_ids_no_padding(config, dl_output, fake_metadata)

    return new_item


def extract_fs2_phoneme_and_flant5_text(config, dl_output, metadata):
    if "phoneme" in metadata.keys():
        new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata)
        new_item["text"] = ""
    else:
        fake_metadata = {"phoneme": []}
        new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, fake_metadata)
    return new_item


def extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata):
    PAD_LENGTH = 135

    phonemes_lookup_dict = {
        "K": 0, "IH2": 1, "NG": 2, "OW2": 3, "AH2": 4, "F": 5, "AE0": 6, "IY0": 7,
        "SH": 8, "G": 9, "W": 10, "UW1": 11, "AO2": 12, "AW2": 13, "UW0": 14, "EY2": 15,
        "UW2": 16, "AE2": 17, "IH0": 18, "P": 19, "D": 20, "ER1": 21, "AA1": 22, "EH0": 23,
        "UH1": 24, "N": 25, "V": 26, "AY1": 27, "EY1": 28, "UH2": 29, "EH1": 30, "L": 31,
        "AA2": 32, "R": 33, "OY1": 34, "Y": 35, "ER2": 36, "S": 37, "AE1": 38, "AH1": 39,
        "JH": 40, "ER0": 41, "EH2": 42, "IY2": 43, "OY2": 44, "AW1": 45, "IH1": 46, "IY1": 47,
        "OW0": 48, "AO0": 49, "AY0": 50, "EY0": 51, "AY2": 52, "UH0": 53, "M": 54, "TH": 55,
        "T": 56, "OY0": 57, "AW0": 58, "DH": 59, "Z": 60, "spn": 61, "AH0": 62, "sp": 63,
        "AO1": 64, "OW1": 65, "ZH": 66, "B": 67, "AA0": 68, "CH": 69, "HH": 70,
    }
    pad_token_id = len(phonemes_lookup_dict.keys())

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset"

    phonemes = [
        phonemes_lookup_dict[x]
        for x in metadata["phoneme"]
        if (x in phonemes_lookup_dict.keys())
    ]

    if (len(phonemes) / PAD_LENGTH) > 5:
        print(
            "Warning: Phonemes length is too long and is truncated too much! %s"
            % metadata
        )

    phonemes = phonemes[:PAD_LENGTH]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list))

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))}


def extract_phoneme_g2p_en_feature(config, dl_output, metadata):
    PAD_LENGTH = 250

    phonemes_lookup_dict = {
        " ": 0, "AA": 1, "AE": 2, "AH": 3, "AO": 4, "AW": 5, "AY": 6, "B": 7,
        "CH": 8, "D": 9, "DH": 10, "EH": 11, "ER": 12, "EY": 13, "F": 14, "G": 15,
        "HH": 16, "IH": 17, "IY": 18, "JH": 19, "K": 20, "L": 21, "M": 22, "N": 23,
        "NG": 24, "OW": 25, "OY": 26, "P": 27, "R": 28, "S": 29, "SH": 30, "T": 31,
        "TH": 32, "UH": 33, "UW": 34, "V": 35, "W": 36, "Y": 37, "Z": 38, "ZH": 39,
    }
    pad_token_id = len(phonemes_lookup_dict.keys())

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset"
    phonemes = [
        phonemes_lookup_dict[x]
        for x in metadata["phoneme"]
        if (x in phonemes_lookup_dict.keys())
    ]

    if (len(phonemes) / PAD_LENGTH) > 5:
        print(
            "Warning: Phonemes length is too long and is truncated too much! %s"
            % metadata
        )

    phonemes = phonemes[:PAD_LENGTH]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list))

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))}


def extract_kaldi_fbank_feature(config, dl_output, metadata):
    norm_mean = -4.2677393
    norm_std = 4.5689974

    waveform = dl_output["waveform"]  # [1, samples]
    sampling_rate = dl_output["sampling_rate"]
    log_mel_spec_hifigan = dl_output["log_mel_spec"]

    if sampling_rate != 16000:
        waveform_16k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=16000
        )
    else:
        waveform_16k = waveform

    waveform_16k = waveform_16k - waveform_16k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform_16k,
        htk_compat=True,
        sample_frequency=16000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    TARGET_LEN = log_mel_spec_hifigan.size(0)

    # cut and pad
    n_frames = fbank.shape[0]
    p = TARGET_LEN - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[:TARGET_LEN, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    return {"ta_kaldi_fbank": fbank}  # [1024, 128]


def extract_kaldi_fbank_feature_32k(config, dl_output, metadata):
    norm_mean = -4.2677393
    norm_std = 4.5689974

    waveform = dl_output["waveform"]  # [1, samples]
    sampling_rate = dl_output["sampling_rate"]
    log_mel_spec_hifigan = dl_output["log_mel_spec"]

    if sampling_rate != 32000:
        waveform_32k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=32000
        )
    else:
        waveform_32k = waveform

    waveform_32k = waveform_32k - waveform_32k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform_32k,
        htk_compat=True,
        sample_frequency=32000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    TARGET_LEN = log_mel_spec_hifigan.size(0)

    # cut and pad
    n_frames = fbank.shape[0]
    p = TARGET_LEN - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[:TARGET_LEN, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    return {"ta_kaldi_fbank": fbank}  # [1024, 128]


# Use the beat and downbeat information as music conditions
def extract_drum_beat(config, dl_output, metadata):
    def visualization(conditional_signal, mel_spectrogram, filename):
        import soundfile as sf

        sf.write(
            os.path.basename(dl_output["fname"]),
            np.array(dl_output["waveform"])[0],
            dl_output["sampling_rate"],
        )
        plt.figure(figsize=(10, 10))

        plt.subplot(211)
        plt.imshow(np.array(conditional_signal).T, aspect="auto")
        plt.title("Conditional Signal")

        plt.subplot(212)
        plt.imshow(np.array(mel_spectrogram).T, aspect="auto")
        plt.title("Mel Spectrogram")

        plt.savefig(filename)
        plt.close()

    assert "sample_rate" in metadata and "beat" in metadata and "downbeat" in metadata

    sampling_rate = metadata["sample_rate"]
    duration = dl_output["duration"]
    # The dataloader segment length before performing torch resampling
    original_segment_length_before_resample = int(sampling_rate * duration)

    random_start_sample = int(dl_output["random_start_sample_in_original_audio_file"])

    # The sample idx for beat and downbeat, relatively to the segmented audio
    beat = [
        x - random_start_sample
        for x in metadata["beat"]
        if (
            x - random_start_sample >= 0
            and x - random_start_sample <= original_segment_length_before_resample
        )
    ]
    downbeat = [
        x - random_start_sample
        for x in metadata["downbeat"]
        if (
            x - random_start_sample >= 0
            and x - random_start_sample <= original_segment_length_before_resample
        )
    ]

    latent_shape = (
        config["model"]["params"]["latent_t_size"],
        config["model"]["params"]["latent_f_size"],
    )
    conditional_signal = torch.zeros(latent_shape)

    # beat: -0.5
    # downbeat: +1.0
    # 0: none; -0.5: beat; 1.0: downbeat; 0.5: downbeat+beat
    for each in beat:
        beat_index = int(
            (each / original_segment_length_before_resample) * latent_shape[0]
        )
        beat_index = min(beat_index, conditional_signal.size(0) - 1)

        conditional_signal[beat_index, :] -= 0.5

    for each in downbeat:
        beat_index = int(
            (each / original_segment_length_before_resample) * latent_shape[0]
        )
        beat_index = min(beat_index, conditional_signal.size(0) - 1)

        conditional_signal[beat_index, :] += 1.0

    # visualization(conditional_signal, dl_output["log_mel_spec"], filename = os.path.basename(dl_output["fname"])+".png")

    return {"cond_beat_downbeat": conditional_signal}
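Every helper in this file shares the same shape, fn(config, dl_output, metadata) -> dict: it reads fields from the dataloader item (dl_output) and/or the sample metadata and returns extra keys to merge into that item. A minimal sketch of calling one add-on directly is shown below; the import path mirrors the upload location and the toy dl_output values are illustrative, not taken from the commit.

import torch
from utilities.data.add_on import waveform_rs_48k

# Two seconds of silence at 16 kHz, shaped like the "waveform" entry the dataset emits.
dl_output = {"waveform": torch.zeros(1, 32000), "sampling_rate": 16000}

# waveform_rs_48k ignores config and metadata and only resamples the waveform to 48 kHz.
extra = waveform_rs_48k(config=None, dl_output=dl_output, metadata={})
print(extra["waveform_48k"].shape)  # torch.Size([1, 96000])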
utilities/data/dataset.py
ADDED
@@ -0,0 +1,518 @@
import os
import pandas as pd

import audiosr.utilities.audio as Audio
from audiosr.utilities.tools import load_json

import random
from torch.utils.data import Dataset
import torch.nn.functional
import torch
import numpy as np
import torchaudio


class AudioDataset(Dataset):
    def __init__(
        self,
        config=None,
        split="train",
        waveform_only=False,
        add_ons=[],
        dataset_json_path=None,
    ):
        """
        Dataset that manages audio recordings
        :param config: Dictionary containing the audio loading and preprocessing settings
        :param dataset_json_path: Path to a JSON metadata file for an external dataset
        """
        self.config = config
        self.split = split
        self.pad_wav_start_sample = 0  # If none, random choose
        self.trim_wav = False
        self.waveform_only = waveform_only
        self.add_ons = [eval(x) for x in add_ons]
        print("Add-ons:", self.add_ons)

        self.build_setting_parameters()

        # For an external dataset
        if dataset_json_path is not None:
            assert type(dataset_json_path) == str
            print("Load metadata from %s" % dataset_json_path)
            self.data = load_json(dataset_json_path)["data"]
            self.id2label, self.index_dict, self.num2label = {}, {}, {}
        else:
            self.metadata_root = load_json(self.config["metadata_root"])
            self.dataset_name = self.config["data"][self.split]
            assert split in self.config["data"].keys(), (
                "The dataset split %s you specified is not present in the config. You can choose from %s"
                % (split, self.config["data"].keys())
            )
            self.build_dataset()
            self.build_id_to_label()

        self.build_dsp()
        self.label_num = len(self.index_dict)
        print("Dataset initialization finished")

    def __getitem__(self, index):
        (
            fname,
            waveform,
            stft,
            log_mel_spec,
            label_vector,  # the one-hot representation of the audio class
            # the metadata of the sampled audio file and the mixup audio file (if exist)
            (datum, mix_datum),
            random_start,
        ) = self.feature_extraction(index)
        text = self.get_sample_text_caption(datum, mix_datum, label_vector)

        data = {
            "text": text,  # list
            "fname": self.text_to_filename(text)
            if (len(fname) == 0)
            else fname,  # list
            # tensor, [batchsize, class_num]
            "label_vector": "" if (label_vector is None) else label_vector.float(),
            # tensor, [batchsize, 1, samples_num]
            "waveform": "" if (waveform is None) else waveform.float(),
            # tensor, [batchsize, t-steps, f-bins]
            "stft": "" if (stft is None) else stft.float(),
            # tensor, [batchsize, t-steps, mel-bins]
            "log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(),
            "duration": self.duration,
            "sampling_rate": self.sampling_rate,
            "random_start_sample_in_original_audio_file": random_start,
        }

        for add_on in self.add_ons:
            data.update(add_on(self.config, data, self.data[index]))

        if data["text"] is None:
            print("Warning: The model return None on key text", fname)
            data["text"] = ""

        return data

    def text_to_filename(self, text):
        return text.replace(" ", "_").replace("'", "_").replace('"', "_")

    def get_dataset_root_path(self, dataset):
        assert dataset in self.metadata_root.keys()
        return self.metadata_root[dataset]

    def get_dataset_metadata_path(self, dataset, key):
        # key: train, test, val, class_label_indices
        try:
            if dataset in self.metadata_root["metadata"]["path"].keys():
                return self.metadata_root["metadata"]["path"][dataset][key]
        except:
            raise ValueError(
                'Dataset %s does not have metadata "%s" specified' % (dataset, key)
            )
        # return None

    def __len__(self):
        return len(self.data)

    def feature_extraction(self, index):
        if index > len(self.data) - 1:
            print(
                "The index of the dataloader is out of range: %s/%s"
                % (index, len(self.data))
            )
            index = random.randint(0, len(self.data) - 1)

        # Read wave file and extract feature
        while True:
            try:
                label_indices = np.zeros(self.label_num, dtype=np.float32)
                datum = self.data[index]
                (
                    log_mel_spec,
                    stft,
                    mix_lambda,
                    waveform,
                    random_start,
                ) = self.read_audio_file(datum["wav"])
                mix_datum = None
                if self.label_num > 0 and "labels" in datum.keys():
                    for label_str in datum["labels"].split(","):
                        label_indices[int(self.index_dict[label_str])] = 1.0

                # If the key "label" is not in the metadata, return all zero vector
                label_indices = torch.FloatTensor(label_indices)
                break
            except Exception as e:
                index = (index + 1) % len(self.data)
                print(
                    "Error encountered during audio feature extraction: ", e, datum["wav"]
                )
                continue

        # The filename of the wav file
        fname = datum["wav"]
        # t_step = log_mel_spec.size(0)
        # waveform = torch.FloatTensor(waveform[..., : int(self.hopsize * t_step)])
        waveform = torch.FloatTensor(waveform)

        return (
            fname,
            waveform,
            stft,
            log_mel_spec,
            label_indices,
            (datum, mix_datum),
            random_start,
        )

    # def augmentation(self, log_mel_spec):
    #     assert torch.min(log_mel_spec) < 0
    #     log_mel_spec = log_mel_spec.exp()
    #
    #     log_mel_spec = torch.transpose(log_mel_spec, 0, 1)
    #     # this is just to satisfy new torchaudio version.
    #     log_mel_spec = log_mel_spec.unsqueeze(0)
    #     if self.freqm != 0:
    #         log_mel_spec = self.frequency_masking(log_mel_spec, self.freqm)
    #     if self.timem != 0:
    #         log_mel_spec = self.time_masking(
    #             log_mel_spec, self.timem)  # self.timem=0
    #
    #     log_mel_spec = (log_mel_spec + 1e-7).log()
    #     # squeeze back
    #     log_mel_spec = log_mel_spec.squeeze(0)
    #     log_mel_spec = torch.transpose(log_mel_spec, 0, 1)
    #     return log_mel_spec

    def build_setting_parameters(self):
        # Read from the json config
        self.melbins = self.config["preprocessing"]["mel"]["n_mel_channels"]
        # self.freqm = self.config["preprocessing"]["mel"]["freqm"]
        # self.timem = self.config["preprocessing"]["mel"]["timem"]
        self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"]
        self.hopsize = self.config["preprocessing"]["stft"]["hop_length"]
        self.duration = self.config["preprocessing"]["audio"]["duration"]
        self.target_length = int(self.duration * self.sampling_rate / self.hopsize)

        self.mixup = self.config["augmentation"]["mixup"]

        # Calculate parameter derivations
        # self.waveform_sample_length = int(self.target_length * self.hopsize)

        # if (self.config["balance_sampling_weight"]):
        #     self.samples_weight = np.loadtxt(
        #         self.config["balance_sampling_weight"], delimiter=","
        #     )

        if "train" not in self.split:
            self.mixup = 0.0
            # self.freqm = 0
            # self.timem = 0

    def _relative_path_to_absolute_path(self, metadata, dataset_name):
        root_path = self.get_dataset_root_path(dataset_name)
        for i in range(len(metadata["data"])):
            assert "wav" in metadata["data"][i].keys(), metadata["data"][i]
            assert metadata["data"][i]["wav"][0] != "/", (
                "The dataset metadata should only contain relative path to the audio file: "
                + str(metadata["data"][i]["wav"])
            )
            metadata["data"][i]["wav"] = os.path.join(
                root_path, metadata["data"][i]["wav"]
            )
        return metadata

    def build_dataset(self):
        self.data = []
        print("Build dataset split %s from %s" % (self.split, self.dataset_name))
        if type(self.dataset_name) is str:
            data_json = load_json(
                self.get_dataset_metadata_path(self.dataset_name, key=self.split)
            )
            data_json = self._relative_path_to_absolute_path(
                data_json, self.dataset_name
            )
            self.data = data_json["data"]
        elif type(self.dataset_name) is list:
            for dataset_name in self.dataset_name:
                data_json = load_json(
                    self.get_dataset_metadata_path(dataset_name, key=self.split)
                )
                data_json = self._relative_path_to_absolute_path(
                    data_json, dataset_name
                )
                self.data += data_json["data"]
        else:
            raise Exception("Invalid data format")
        print("Data size: {}".format(len(self.data)))

    def build_dsp(self):
        self.STFT = Audio.stft.TacotronSTFT(
            self.config["preprocessing"]["stft"]["filter_length"],
            self.config["preprocessing"]["stft"]["hop_length"],
            self.config["preprocessing"]["stft"]["win_length"],
            self.config["preprocessing"]["mel"]["n_mel_channels"],
            self.config["preprocessing"]["audio"]["sampling_rate"],
            self.config["preprocessing"]["mel"]["mel_fmin"],
            self.config["preprocessing"]["mel"]["mel_fmax"],
        )
        # self.stft_transform = torchaudio.transforms.Spectrogram(
        #     n_fft=1024, hop_length=160
        # )
        # self.melscale_transform = torchaudio.transforms.MelScale(
        #     sample_rate=16000, n_stft=1024 // 2 + 1, n_mels=64
        # )

    def build_id_to_label(self):
        id2label = {}
        id2num = {}
        num2label = {}
        class_label_indices_path = self.get_dataset_metadata_path(
            dataset=self.config["data"]["class_label_indices"],
            key="class_label_indices",
        )
        if class_label_indices_path is not None:
            df = pd.read_csv(class_label_indices_path)
            for _, row in df.iterrows():
                index, mid, display_name = row["index"], row["mid"], row["display_name"]
                id2label[mid] = display_name
                id2num[mid] = index
                num2label[index] = display_name
            self.id2label, self.index_dict, self.num2label = id2label, id2num, num2label
        else:
            self.id2label, self.index_dict, self.num2label = {}, {}, {}

    def resample(self, waveform, sr):
        waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
        # waveform = librosa.resample(waveform, sr, self.sampling_rate)
        return waveform

        # if sr == 16000:
        #     return waveform
        # if sr == 32000 and self.sampling_rate == 16000:
        #     waveform = waveform[::2]
        #     return waveform
        # if sr == 48000 and self.sampling_rate == 16000:
        #     waveform = waveform[::3]
        #     return waveform
        # else:
        #     raise ValueError(
        #         "We currently only support 16k audio generation. You need to resample you audio file to 16k, 32k, or 48k: %s, %s"
        #         % (sr, self.sampling_rate)
        #     )

    def normalize_wav(self, waveform):
        waveform = waveform - np.mean(waveform)
        waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
        return waveform * 0.5  # Manually limit the maximum amplitude into 0.5

    def random_segment_wav(self, waveform, target_length):
        waveform_length = waveform.shape[-1]
        assert waveform_length > 100, "Waveform is too short, %s" % waveform_length

        # Too short
        if (waveform_length - target_length) <= 0:
            return waveform, 0

        random_start = int(self.random_uniform(0, waveform_length - target_length))
        return waveform[:, random_start : random_start + target_length], random_start

    def pad_wav(self, waveform, target_length):
        waveform_length = waveform.shape[-1]
        assert waveform_length > 100, "Waveform is too short, %s" % waveform_length

        if waveform_length == target_length:
            return waveform

        # Pad
        temp_wav = np.zeros((1, target_length), dtype=np.float32)
        if self.pad_wav_start_sample is None:
            rand_start = int(self.random_uniform(0, target_length - waveform_length))
        else:
            rand_start = 0

        temp_wav[:, rand_start : rand_start + waveform_length] = waveform
        return temp_wav

    def trim_wav(self, waveform):
        if np.max(np.abs(waveform)) < 0.0001:
            return waveform

        def detect_leading_silence(waveform, threshold=0.0001):
            chunk_size = 1000
            waveform_length = waveform.shape[0]
            start = 0
            while start + chunk_size < waveform_length:
                if np.max(np.abs(waveform[start : start + chunk_size])) < threshold:
                    start += chunk_size
                else:
                    break
            return start

        def detect_ending_silence(waveform, threshold=0.0001):
            chunk_size = 1000
            waveform_length = waveform.shape[0]
            start = waveform_length
            while start - chunk_size > 0:
                if np.max(np.abs(waveform[start - chunk_size : start])) < threshold:
                    start -= chunk_size
                else:
                    break
            if start == waveform_length:
                return start
            else:
                return start + chunk_size

        start = detect_leading_silence(waveform)
        end = detect_ending_silence(waveform)

        return waveform[start:end]

    def read_wav_file(self, filename):
        # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
        waveform, sr = torchaudio.load(filename)

        waveform, random_start = self.random_segment_wav(
            waveform, target_length=int(sr * self.duration)
        )

        waveform = self.resample(waveform, sr)
        # random_start = int(random_start * (self.sampling_rate / sr))

        waveform = waveform.numpy()[0, ...]

        waveform = self.normalize_wav(waveform)

        if self.trim_wav:
            waveform = self.trim_wav(waveform)

        waveform = waveform[None, ...]
        waveform = self.pad_wav(
            waveform, target_length=int(self.sampling_rate * self.duration)
        )
        return waveform, random_start

    def mix_two_waveforms(self, waveform1, waveform2):
        mix_lambda = np.random.beta(5, 5)
        mix_waveform = mix_lambda * waveform1 + (1 - mix_lambda) * waveform2
        return self.normalize_wav(mix_waveform), mix_lambda

    def read_audio_file(self, filename, filename2=None):
        if os.path.exists(filename):
            waveform, random_start = self.read_wav_file(filename)
        else:
            print(
                'Warning [dataset.py]: The wav path "',
                filename,
                '" is not found in the metadata. Use empty waveform instead.',
            )
            target_length = int(self.sampling_rate * self.duration)
            waveform = torch.zeros((1, target_length))
            random_start = 0

        mix_lambda = 0.0
        # log_mel_spec, stft = self.wav_feature_extraction_torchaudio(waveform) # this line is faster, but this implementation is not aligned with HiFi-GAN
        if not self.waveform_only:
            log_mel_spec, stft = self.wav_feature_extraction(waveform)
        else:
            # Load waveform data only
            # Use zero array to keep the format unified
            log_mel_spec, stft = None, None

        return log_mel_spec, stft, mix_lambda, waveform, random_start

    def get_sample_text_caption(self, datum, mix_datum, label_indices):
        text = self.label_indices_to_text(datum, label_indices)
        if mix_datum is not None:
            text += " " + self.label_indices_to_text(mix_datum, label_indices)
        return text

    # This one is significantly slower than "wav_feature_extraction_torchaudio" if num_worker > 1
    def wav_feature_extraction(self, waveform):
        waveform = waveform[0, ...]
        waveform = torch.FloatTensor(waveform)

        log_mel_spec, stft, energy = Audio.tools.get_mel_from_wav(waveform, self.STFT)

        log_mel_spec = torch.FloatTensor(log_mel_spec.T)
        stft = torch.FloatTensor(stft.T)

        log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
        return log_mel_spec, stft

    # @profile
    # def wav_feature_extraction_torchaudio(self, waveform):
    #     waveform = waveform[0, ...]
    #     waveform = torch.FloatTensor(waveform)
    #
    #     stft = self.stft_transform(waveform)
    #     mel_spec = self.melscale_transform(stft)
    #     log_mel_spec = torch.log(mel_spec + 1e-7)
    #
    #     log_mel_spec = torch.FloatTensor(log_mel_spec.T)
    #     stft = torch.FloatTensor(stft.T)
    #
    #     log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
    #     return log_mel_spec, stft

    def pad_spec(self, log_mel_spec):
        n_frames = log_mel_spec.shape[0]
        p = self.target_length - n_frames
        # cut and pad
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            log_mel_spec = m(log_mel_spec)
        elif p < 0:
            log_mel_spec = log_mel_spec[0 : self.target_length, :]

        if log_mel_spec.size(-1) % 2 != 0:
            log_mel_spec = log_mel_spec[..., :-1]

        return log_mel_spec

    def _read_datum_caption(self, datum):
        caption_keys = [x for x in datum.keys() if ("caption" in x)]
        random_index = torch.randint(0, len(caption_keys), (1,))[0].item()
        return datum[caption_keys[random_index]]

    def _is_contain_caption(self, datum):
        caption_keys = [x for x in datum.keys() if ("caption" in x)]
        return len(caption_keys) > 0

    def label_indices_to_text(self, datum, label_indices):
        if self._is_contain_caption(datum):
            return self._read_datum_caption(datum)
        elif "label" in datum.keys():
            name_indices = torch.where(label_indices > 0.1)[0]
            # description_header = "This audio contains the sound of "
            description_header = ""
            labels = ""
            for id, each in enumerate(name_indices):
                if id == len(name_indices) - 1:
                    labels += "%s." % self.num2label[int(each)]
                else:
                    labels += "%s, " % self.num2label[int(each)]
            return description_header + labels
        else:
            return ""  # TODO, if both label and caption are not provided, return empty string

    def random_uniform(self, start, end):
        val = torch.rand(1).item()
        return start + (end - start) * val

    def frequency_masking(self, log_mel_spec, freqm):
        bs, freq, tsteps = log_mel_spec.size()
        mask_len = int(self.random_uniform(freqm // 8, freqm))
        mask_start = int(self.random_uniform(start=0, end=freq - mask_len))
        log_mel_spec[:, mask_start : mask_start + mask_len, :] *= 0.0
        return log_mel_spec

    def time_masking(self, log_mel_spec, timem):
        bs, freq, tsteps = log_mel_spec.size()
        mask_len = int(self.random_uniform(timem // 8, timem))
        mask_start = int(self.random_uniform(start=0, end=tsteps - mask_len))
        log_mel_spec[:, :, mask_start : mask_start + mask_len] *= 0.0
        return log_mel_spec
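A minimal usage sketch for AudioDataset follows. It is only a sketch under stated assumptions: the import path mirrors the upload location, the config values are placeholders filled in solely for the keys that build_setting_parameters() and build_dsp() read, and "my_dataset.json" is a hypothetical metadata file of the form {"data": [{"wav": "/abs/path/to/file.wav", ...}, ...]} used through the external dataset_json_path branch.

from torch.utils.data import DataLoader
from utilities.data.dataset import AudioDataset

# Placeholder preprocessing settings; substitute the project's real values.
config = {
    "preprocessing": {
        "audio": {"sampling_rate": 48000, "duration": 10.24},
        "stft": {"filter_length": 2048, "hop_length": 480, "win_length": 2048},
        "mel": {"n_mel_channels": 256, "mel_fmin": 20, "mel_fmax": 24000},
    },
    "augmentation": {"mixup": 0.0},
}

dataset = AudioDataset(
    config=config,
    split="test",
    waveform_only=False,
    add_ons=[],  # entries are strings that __init__ resolves with eval()
    dataset_json_path="my_dataset.json",  # hypothetical external metadata file
)
loader = DataLoader(dataset, batch_size=2, num_workers=0)
batch = next(iter(loader))
print(batch["log_mel_spec"].shape, batch["waveform"].shape)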