namelessai committed
Commit 693f774 · verified · 1 Parent(s): 6becdff

Upload 2 files

Files changed (2)
  1. utilities/data/add_on.py +508 -0
  2. utilities/data/dataset.py +518 -0
utilities/data/add_on.py ADDED
@@ -0,0 +1,508 @@
import os
import torch
import numpy as np
import torchaudio
import matplotlib.pyplot as plt

CACHE = {
    "get_vits_phoneme_ids": {
        "PAD_LENGTH": 310,
        "_pad": "_",
        "_punctuation": ';:,.!?¡¿—…"«»“” ',
        "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
        "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
        "_special": "♪☎☒☝⚠",
    }
}

CACHE["get_vits_phoneme_ids"]["symbols"] = (
    [CACHE["get_vits_phoneme_ids"]["_pad"]]
    + list(CACHE["get_vits_phoneme_ids"]["_punctuation"])
    + list(CACHE["get_vits_phoneme_ids"]["_letters"])
    + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"])
    + list(CACHE["get_vits_phoneme_ids"]["_special"])
)
CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = {
    s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"])
}


def get_vits_phoneme_ids(config, dl_output, metadata):
    pad_token_id = 0
    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide VITS phonemes when using the add-on get_vits_phoneme_ids"
    clean_text = metadata["phonemes"]
    sequence = []

    for symbol in clean_text:
        symbol_id = _symbol_to_id[symbol]
        sequence += [symbol_id]

    # Interleave the phoneme ids with zeros (blank tokens)
    inserted_zero_sequence = [0] * (len(sequence) * 2)
    inserted_zero_sequence[1::2] = sequence
    inserted_zero_sequence = inserted_zero_sequence + [0]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(inserted_zero_sequence))}


def get_vits_phoneme_ids_no_padding(config, dl_output, metadata):
    pad_token_id = 0
    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide VITS phonemes when using the add-on get_vits_phoneme_ids_no_padding"
    clean_text = metadata["phonemes"] + "⚠"
    sequence = []

    for symbol in clean_text:
        if symbol not in _symbol_to_id.keys():
            print("%s is not in the vocabulary. %s" % (symbol, clean_text))
            symbol = "_"
        symbol_id = _symbol_to_id[symbol]
        sequence += [symbol_id]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))

    sequence = sequence[:pad_length]

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(sequence))}


def calculate_relative_bandwidth(config, dl_output, metadata):
    assert "stft" in dl_output.keys()

    # The last dimension of the stft feature is the frequency dimension
    freq_dimensions = dl_output["stft"].size(-1)

    freq_energy_dist = torch.sum(dl_output["stft"], dim=0)
    freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
    total_energy = freq_energy_dist[-1]

    percentile_5th = total_energy * 0.05
    percentile_95th = total_energy * 0.95

    lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
    higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))

    lower_idx = int((lower_idx / freq_dimensions) * 1000)
    higher_idx = int((higher_idx / freq_dimensions) * 1000)

    return {"freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx])}


def calculate_mel_spec_relative_bandwidth_as_extra_channel(config, dl_output, metadata):
    assert "log_mel_spec" in dl_output.keys()
    linear_mel_spec = torch.exp(torch.clip(dl_output["log_mel_spec"], max=10))

    # The last dimension of the mel-spectrogram feature is the frequency dimension
    freq_dimensions = linear_mel_spec.size(-1)
    freq_energy_dist = torch.sum(linear_mel_spec, dim=0)
    freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
    total_energy = freq_energy_dist[-1]

    percentile_5th = total_energy * 0.05
    percentile_95th = total_energy * 0.95

    lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
    higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))

    latent_t_size = config["model"]["params"]["latent_t_size"]
    latent_f_size = config["model"]["params"]["latent_f_size"]

    lower_idx = int(latent_f_size * float(lower_idx / freq_dimensions))
    higher_idx = int(latent_f_size * float(higher_idx / freq_dimensions))

    bandwidth_condition = torch.zeros((latent_t_size, latent_f_size))
    bandwidth_condition[:, lower_idx:higher_idx] += 1.0

    return {
        "mel_spec_bandwidth_cond_extra_channel": bandwidth_condition,
        "freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx]),
    }


def waveform_rs_48k(config, dl_output, metadata):
    waveform = dl_output["waveform"]  # [1, samples]
    sampling_rate = dl_output["sampling_rate"]

    if sampling_rate != 48000:
        waveform_48k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=48000
        )
    else:
        waveform_48k = waveform

    return {"waveform_48k": waveform_48k}


def extract_vits_phoneme_and_flant5_text(config, dl_output, metadata):
    assert (
        "phoneme" not in metadata.keys()
    ), "The speech metadata you are using seems to belong to FastSpeech. Please check dataset_root.json"

    if "phonemes" in metadata.keys():
        new_item = get_vits_phoneme_ids_no_padding(config, dl_output, metadata)
        new_item["text"] = ""  # We assume TTS data does not have a text description
    else:
        fake_metadata = {"phonemes": ""}  # Add an empty phoneme sequence
        new_item = get_vits_phoneme_ids_no_padding(config, dl_output, fake_metadata)

    return new_item


def extract_fs2_phoneme_and_flant5_text(config, dl_output, metadata):
    if "phoneme" in metadata.keys():
        new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata)
        new_item["text"] = ""
    else:
        fake_metadata = {"phoneme": []}
        new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, fake_metadata)
    return new_item


def extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata):
    PAD_LENGTH = 135

    phonemes_lookup_dict = {
        "K": 0,
        "IH2": 1,
        "NG": 2,
        "OW2": 3,
        "AH2": 4,
        "F": 5,
        "AE0": 6,
        "IY0": 7,
        "SH": 8,
        "G": 9,
        "W": 10,
        "UW1": 11,
        "AO2": 12,
        "AW2": 13,
        "UW0": 14,
        "EY2": 15,
        "UW2": 16,
        "AE2": 17,
        "IH0": 18,
        "P": 19,
        "D": 20,
        "ER1": 21,
        "AA1": 22,
        "EH0": 23,
        "UH1": 24,
        "N": 25,
        "V": 26,
        "AY1": 27,
        "EY1": 28,
        "UH2": 29,
        "EH1": 30,
        "L": 31,
        "AA2": 32,
        "R": 33,
        "OY1": 34,
        "Y": 35,
        "ER2": 36,
        "S": 37,
        "AE1": 38,
        "AH1": 39,
        "JH": 40,
        "ER0": 41,
        "EH2": 42,
        "IY2": 43,
        "OY2": 44,
        "AW1": 45,
        "IH1": 46,
        "IY1": 47,
        "OW0": 48,
        "AO0": 49,
        "AY0": 50,
        "EY0": 51,
        "AY2": 52,
        "UH0": 53,
        "M": 54,
        "TH": 55,
        "T": 56,
        "OY0": 57,
        "AW0": 58,
        "DH": 59,
        "Z": 60,
        "spn": 61,
        "AH0": 62,
        "sp": 63,
        "AO1": 64,
        "OW1": 65,
        "ZH": 66,
        "B": 67,
        "AA0": 68,
        "CH": 69,
        "HH": 70,
    }
    pad_token_id = len(phonemes_lookup_dict.keys())

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_fs2_phoneme_g2p_en_feature outputs phoneme ids, but no phoneme is specified in your dataset metadata"

    phonemes = [
        phonemes_lookup_dict[x]
        for x in metadata["phoneme"]
        if (x in phonemes_lookup_dict.keys())
    ]

    if (len(phonemes) / PAD_LENGTH) > 5:
        print(
            "Warning: the phoneme sequence is too long and will be heavily truncated! %s"
            % metadata
        )

    phonemes = phonemes[:PAD_LENGTH]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list))

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))}


def extract_phoneme_g2p_en_feature(config, dl_output, metadata):
    PAD_LENGTH = 250

    phonemes_lookup_dict = {
        " ": 0,
        "AA": 1,
        "AE": 2,
        "AH": 3,
        "AO": 4,
        "AW": 5,
        "AY": 6,
        "B": 7,
        "CH": 8,
        "D": 9,
        "DH": 10,
        "EH": 11,
        "ER": 12,
        "EY": 13,
        "F": 14,
        "G": 15,
        "HH": 16,
        "IH": 17,
        "IY": 18,
        "JH": 19,
        "K": 20,
        "L": 21,
        "M": 22,
        "N": 23,
        "NG": 24,
        "OW": 25,
        "OY": 26,
        "P": 27,
        "R": 28,
        "S": 29,
        "SH": 30,
        "T": 31,
        "TH": 32,
        "UH": 33,
        "UW": 34,
        "V": 35,
        "W": 36,
        "Y": 37,
        "Z": 38,
        "ZH": 39,
    }
    pad_token_id = len(phonemes_lookup_dict.keys())

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_phoneme_g2p_en_feature outputs phoneme ids, but no phoneme is specified in your dataset metadata"
    phonemes = [
        phonemes_lookup_dict[x]
        for x in metadata["phoneme"]
        if (x in phonemes_lookup_dict.keys())
    ]

    if (len(phonemes) / PAD_LENGTH) > 5:
        print(
            "Warning: the phoneme sequence is too long and will be heavily truncated! %s"
            % metadata
        )

    phonemes = phonemes[:PAD_LENGTH]

    def _pad_phonemes(phonemes_list):
        return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list))

    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))}


def extract_kaldi_fbank_feature(config, dl_output, metadata):
    norm_mean = -4.2677393
    norm_std = 4.5689974

    waveform = dl_output["waveform"]  # [1, samples]
    sampling_rate = dl_output["sampling_rate"]
    log_mel_spec_hifigan = dl_output["log_mel_spec"]

    if sampling_rate != 16000:
        waveform_16k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=16000
        )
    else:
        waveform_16k = waveform

    waveform_16k = waveform_16k - waveform_16k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform_16k,
        htk_compat=True,
        sample_frequency=16000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    TARGET_LEN = log_mel_spec_hifigan.size(0)

    # cut and pad
    n_frames = fbank.shape[0]
    p = TARGET_LEN - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[:TARGET_LEN, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    return {"ta_kaldi_fbank": fbank}  # [1024, 128]


def extract_kaldi_fbank_feature_32k(config, dl_output, metadata):
    norm_mean = -4.2677393
    norm_std = 4.5689974

    waveform = dl_output["waveform"]  # [1, samples]
    sampling_rate = dl_output["sampling_rate"]
    log_mel_spec_hifigan = dl_output["log_mel_spec"]

    if sampling_rate != 32000:
        waveform_32k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=32000
        )
    else:
        waveform_32k = waveform

    waveform_32k = waveform_32k - waveform_32k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform_32k,
        htk_compat=True,
        sample_frequency=32000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    TARGET_LEN = log_mel_spec_hifigan.size(0)

    # cut and pad
    n_frames = fbank.shape[0]
    p = TARGET_LEN - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[:TARGET_LEN, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    return {"ta_kaldi_fbank": fbank}  # [1024, 128]


# Use the beat and downbeat information as music conditions
def extract_drum_beat(config, dl_output, metadata):
    def visualization(conditional_signal, mel_spectrogram, filename):
        import soundfile as sf

        sf.write(
            os.path.basename(dl_output["fname"]),
            np.array(dl_output["waveform"])[0],
            dl_output["sampling_rate"],
        )
        plt.figure(figsize=(10, 10))

        plt.subplot(211)
        plt.imshow(np.array(conditional_signal).T, aspect="auto")
        plt.title("Conditional Signal")

        plt.subplot(212)
        plt.imshow(np.array(mel_spectrogram).T, aspect="auto")
        plt.title("Mel Spectrogram")

        plt.savefig(filename)
        plt.close()

    assert "sample_rate" in metadata and "beat" in metadata and "downbeat" in metadata

    sampling_rate = metadata["sample_rate"]
    duration = dl_output["duration"]
    # The dataloader segment length before torch resampling is performed
    original_segment_length_before_resample = int(sampling_rate * duration)

    random_start_sample = int(dl_output["random_start_sample_in_original_audio_file"])

    # The sample indices for beat and downbeat, relative to the segmented audio
    beat = [
        x - random_start_sample
        for x in metadata["beat"]
        if (
            x - random_start_sample >= 0
            and x - random_start_sample <= original_segment_length_before_resample
        )
    ]
    downbeat = [
        x - random_start_sample
        for x in metadata["downbeat"]
        if (
            x - random_start_sample >= 0
            and x - random_start_sample <= original_segment_length_before_resample
        )
    ]

    latent_shape = (
        config["model"]["params"]["latent_t_size"],
        config["model"]["params"]["latent_f_size"],
    )
    conditional_signal = torch.zeros(latent_shape)

    # beat: -0.5
    # downbeat: +1.0
    # 0: none; -0.5: beat; 1.0: downbeat; 0.5: downbeat+beat
    for each in beat:
        beat_index = int(
            (each / original_segment_length_before_resample) * latent_shape[0]
        )
        beat_index = min(beat_index, conditional_signal.size(0) - 1)

        conditional_signal[beat_index, :] -= 0.5

    for each in downbeat:
        beat_index = int(
            (each / original_segment_length_before_resample) * latent_shape[0]
        )
        beat_index = min(beat_index, conditional_signal.size(0) - 1)

        conditional_signal[beat_index, :] += 1.0

    # visualization(conditional_signal, dl_output["log_mel_spec"], filename=os.path.basename(dl_output["fname"]) + ".png")

    return {"cond_beat_downbeat": conditional_signal}
utilities/data/dataset.py ADDED
@@ -0,0 +1,518 @@
import os
import pandas as pd

import audiosr.utilities.audio as Audio
from audiosr.utilities.tools import load_json

import random
from torch.utils.data import Dataset
import torch.nn.functional
import torch
import numpy as np
import torchaudio


class AudioDataset(Dataset):
    def __init__(
        self,
        config=None,
        split="train",
        waveform_only=False,
        add_ons=[],
        dataset_json_path=None,
    ):
        """
        Dataset that manages audio recordings.
        :param config: dictionary containing the audio loading and preprocessing settings
        :param dataset_json_path: path to an external dataset metadata JSON file (optional)
        """
        self.config = config
        self.split = split
        self.pad_wav_start_sample = 0  # If None, choose a random start position
        self.trim_wav = False
        self.waveform_only = waveform_only
        self.add_ons = [eval(x) for x in add_ons]
        print("Add-ons:", self.add_ons)

        self.build_setting_parameters()

        # For an external dataset
        if dataset_json_path is not None:
            assert type(dataset_json_path) == str
            print("Load metadata from %s" % dataset_json_path)
            self.data = load_json(dataset_json_path)["data"]
            self.id2label, self.index_dict, self.num2label = {}, {}, {}
        else:
            self.metadata_root = load_json(self.config["metadata_root"])
            self.dataset_name = self.config["data"][self.split]
            assert split in self.config["data"].keys(), (
                "The dataset split %s you specified is not present in the config. You can choose from %s"
                % (split, self.config["data"].keys())
            )
            self.build_dataset()
            self.build_id_to_label()

        self.build_dsp()
        self.label_num = len(self.index_dict)
        print("Dataset initialization finished")

    def __getitem__(self, index):
        (
            fname,
            waveform,
            stft,
            log_mel_spec,
            label_vector,  # the one-hot representation of the audio class
            # the metadata of the sampled audio file and the mixup audio file (if it exists)
            (datum, mix_datum),
            random_start,
        ) = self.feature_extraction(index)
        text = self.get_sample_text_caption(datum, mix_datum, label_vector)

        data = {
            "text": text,  # list
            "fname": self.text_to_filename(text)
            if (len(fname) == 0)
            else fname,  # list
            # tensor, [batchsize, class_num]
            "label_vector": "" if (label_vector is None) else label_vector.float(),
            # tensor, [batchsize, 1, samples_num]
            "waveform": "" if (waveform is None) else waveform.float(),
            # tensor, [batchsize, t-steps, f-bins]
            "stft": "" if (stft is None) else stft.float(),
            # tensor, [batchsize, t-steps, mel-bins]
            "log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(),
            "duration": self.duration,
            "sampling_rate": self.sampling_rate,
            "random_start_sample_in_original_audio_file": random_start,
        }

        for add_on in self.add_ons:
            data.update(add_on(self.config, data, self.data[index]))

        if data["text"] is None:
            print("Warning: the dataset returned None for key 'text':", fname)
            data["text"] = ""

        return data

    def text_to_filename(self, text):
        return text.replace(" ", "_").replace("'", "_").replace('"', "_")

    def get_dataset_root_path(self, dataset):
        assert dataset in self.metadata_root.keys()
        return self.metadata_root[dataset]

    def get_dataset_metadata_path(self, dataset, key):
        # key: train, test, val, class_label_indices
        try:
            if dataset in self.metadata_root["metadata"]["path"].keys():
                return self.metadata_root["metadata"]["path"][dataset][key]
        except:
            raise ValueError(
                'Dataset %s does not have metadata "%s" specified' % (dataset, key)
            )
        # return None

    def __len__(self):
        return len(self.data)

    def feature_extraction(self, index):
        if index > len(self.data) - 1:
            print(
                "The index of the dataloader is out of range: %s/%s"
                % (index, len(self.data))
            )
            index = random.randint(0, len(self.data) - 1)

        # Read wave file and extract feature
        while True:
            try:
                label_indices = np.zeros(self.label_num, dtype=np.float32)
                datum = self.data[index]
                (
                    log_mel_spec,
                    stft,
                    mix_lambda,
                    waveform,
                    random_start,
                ) = self.read_audio_file(datum["wav"])
                mix_datum = None
                if self.label_num > 0 and "labels" in datum.keys():
                    for label_str in datum["labels"].split(","):
                        label_indices[int(self.index_dict[label_str])] = 1.0

                # If the key "labels" is not in the metadata, return an all-zero vector
                label_indices = torch.FloatTensor(label_indices)
                break
            except Exception as e:
                index = (index + 1) % len(self.data)
                print(
                    "Error encountered during audio feature extraction: ", e, datum["wav"]
                )
                continue

        # The filename of the wav file
        fname = datum["wav"]
        # t_step = log_mel_spec.size(0)
        # waveform = torch.FloatTensor(waveform[..., : int(self.hopsize * t_step)])
        waveform = torch.FloatTensor(waveform)

        return (
            fname,
            waveform,
            stft,
            log_mel_spec,
            label_indices,
            (datum, mix_datum),
            random_start,
        )

    # def augmentation(self, log_mel_spec):
    #     assert torch.min(log_mel_spec) < 0
    #     log_mel_spec = log_mel_spec.exp()

    #     log_mel_spec = torch.transpose(log_mel_spec, 0, 1)
    #     # this is just to satisfy new torchaudio version.
    #     log_mel_spec = log_mel_spec.unsqueeze(0)
    #     if self.freqm != 0:
    #         log_mel_spec = self.frequency_masking(log_mel_spec, self.freqm)
    #     if self.timem != 0:
    #         log_mel_spec = self.time_masking(
    #             log_mel_spec, self.timem)  # self.timem=0

    #     log_mel_spec = (log_mel_spec + 1e-7).log()
    #     # squeeze back
    #     log_mel_spec = log_mel_spec.squeeze(0)
    #     log_mel_spec = torch.transpose(log_mel_spec, 0, 1)
    #     return log_mel_spec

    def build_setting_parameters(self):
        # Read from the json config
        self.melbins = self.config["preprocessing"]["mel"]["n_mel_channels"]
        # self.freqm = self.config["preprocessing"]["mel"]["freqm"]
        # self.timem = self.config["preprocessing"]["mel"]["timem"]
        self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"]
        self.hopsize = self.config["preprocessing"]["stft"]["hop_length"]
        self.duration = self.config["preprocessing"]["audio"]["duration"]
        self.target_length = int(self.duration * self.sampling_rate / self.hopsize)

        self.mixup = self.config["augmentation"]["mixup"]

        # Calculate parameter derivations
        # self.waveform_sample_length = int(self.target_length * self.hopsize)

        # if (self.config["balance_sampling_weight"]):
        #     self.samples_weight = np.loadtxt(
        #         self.config["balance_sampling_weight"], delimiter=","
        #     )

        if "train" not in self.split:
            self.mixup = 0.0
            # self.freqm = 0
            # self.timem = 0

    def _relative_path_to_absolute_path(self, metadata, dataset_name):
        root_path = self.get_dataset_root_path(dataset_name)
        for i in range(len(metadata["data"])):
            assert "wav" in metadata["data"][i].keys(), metadata["data"][i]
            assert metadata["data"][i]["wav"][0] != "/", (
                "The dataset metadata should only contain relative paths to the audio files: "
                + str(metadata["data"][i]["wav"])
            )
            metadata["data"][i]["wav"] = os.path.join(
                root_path, metadata["data"][i]["wav"]
            )
        return metadata

    def build_dataset(self):
        self.data = []
        print("Build dataset split %s from %s" % (self.split, self.dataset_name))
        if type(self.dataset_name) is str:
            data_json = load_json(
                self.get_dataset_metadata_path(self.dataset_name, key=self.split)
            )
            data_json = self._relative_path_to_absolute_path(
                data_json, self.dataset_name
            )
            self.data = data_json["data"]
        elif type(self.dataset_name) is list:
            for dataset_name in self.dataset_name:
                data_json = load_json(
                    self.get_dataset_metadata_path(dataset_name, key=self.split)
                )
                data_json = self._relative_path_to_absolute_path(
                    data_json, dataset_name
                )
                self.data += data_json["data"]
        else:
            raise Exception("Invalid data format")
        print("Data size: {}".format(len(self.data)))

    def build_dsp(self):
        self.STFT = Audio.stft.TacotronSTFT(
            self.config["preprocessing"]["stft"]["filter_length"],
            self.config["preprocessing"]["stft"]["hop_length"],
            self.config["preprocessing"]["stft"]["win_length"],
            self.config["preprocessing"]["mel"]["n_mel_channels"],
            self.config["preprocessing"]["audio"]["sampling_rate"],
            self.config["preprocessing"]["mel"]["mel_fmin"],
            self.config["preprocessing"]["mel"]["mel_fmax"],
        )
        # self.stft_transform = torchaudio.transforms.Spectrogram(
        #     n_fft=1024, hop_length=160
        # )
        # self.melscale_transform = torchaudio.transforms.MelScale(
        #     sample_rate=16000, n_stft=1024 // 2 + 1, n_mels=64
        # )

    def build_id_to_label(self):
        id2label = {}
        id2num = {}
        num2label = {}
        class_label_indices_path = self.get_dataset_metadata_path(
            dataset=self.config["data"]["class_label_indices"],
            key="class_label_indices",
        )
        if class_label_indices_path is not None:
            df = pd.read_csv(class_label_indices_path)
            for _, row in df.iterrows():
                index, mid, display_name = row["index"], row["mid"], row["display_name"]
                id2label[mid] = display_name
                id2num[mid] = index
                num2label[index] = display_name
            self.id2label, self.index_dict, self.num2label = id2label, id2num, num2label
        else:
            self.id2label, self.index_dict, self.num2label = {}, {}, {}

    def resample(self, waveform, sr):
        waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
        # waveform = librosa.resample(waveform, sr, self.sampling_rate)
        return waveform

        # if sr == 16000:
        #     return waveform
        # if sr == 32000 and self.sampling_rate == 16000:
        #     waveform = waveform[::2]
        #     return waveform
        # if sr == 48000 and self.sampling_rate == 16000:
        #     waveform = waveform[::3]
        #     return waveform
        # else:
        #     raise ValueError(
        #         "We currently only support 16k audio generation. You need to resample your audio file to 16k, 32k, or 48k: %s, %s"
        #         % (sr, self.sampling_rate)
        #     )

    def normalize_wav(self, waveform):
        waveform = waveform - np.mean(waveform)
        waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
        return waveform * 0.5  # Manually limit the maximum amplitude to 0.5

    def random_segment_wav(self, waveform, target_length):
        waveform_length = waveform.shape[-1]
        assert waveform_length > 100, "Waveform is too short, %s" % waveform_length

        # Too short
        if (waveform_length - target_length) <= 0:
            return waveform, 0

        random_start = int(self.random_uniform(0, waveform_length - target_length))
        return waveform[:, random_start : random_start + target_length], random_start

    def pad_wav(self, waveform, target_length):
        waveform_length = waveform.shape[-1]
        assert waveform_length > 100, "Waveform is too short, %s" % waveform_length

        if waveform_length == target_length:
            return waveform

        # Pad
        temp_wav = np.zeros((1, target_length), dtype=np.float32)
        if self.pad_wav_start_sample is None:
            rand_start = int(self.random_uniform(0, target_length - waveform_length))
        else:
            rand_start = 0

        temp_wav[:, rand_start : rand_start + waveform_length] = waveform
        return temp_wav

    def trim_wav(self, waveform):
        if np.max(np.abs(waveform)) < 0.0001:
            return waveform

        def detect_leading_silence(waveform, threshold=0.0001):
            chunk_size = 1000
            waveform_length = waveform.shape[0]
            start = 0
            while start + chunk_size < waveform_length:
                if np.max(np.abs(waveform[start : start + chunk_size])) < threshold:
                    start += chunk_size
                else:
                    break
            return start

        def detect_ending_silence(waveform, threshold=0.0001):
            chunk_size = 1000
            waveform_length = waveform.shape[0]
            start = waveform_length
            while start - chunk_size > 0:
                if np.max(np.abs(waveform[start - chunk_size : start])) < threshold:
                    start -= chunk_size
                else:
                    break
            if start == waveform_length:
                return start
            else:
                return start + chunk_size

        start = detect_leading_silence(waveform)
        end = detect_ending_silence(waveform)

        return waveform[start:end]

    def read_wav_file(self, filename):
        # waveform, sr = librosa.load(filename, sr=None, mono=True)  # 4 times slower
        waveform, sr = torchaudio.load(filename)

        waveform, random_start = self.random_segment_wav(
            waveform, target_length=int(sr * self.duration)
        )

        waveform = self.resample(waveform, sr)
        # random_start = int(random_start * (self.sampling_rate / sr))

        waveform = waveform.numpy()[0, ...]

        waveform = self.normalize_wav(waveform)

        if self.trim_wav:
            waveform = self.trim_wav(waveform)

        waveform = waveform[None, ...]
        waveform = self.pad_wav(
            waveform, target_length=int(self.sampling_rate * self.duration)
        )
        return waveform, random_start

    def mix_two_waveforms(self, waveform1, waveform2):
        mix_lambda = np.random.beta(5, 5)
        mix_waveform = mix_lambda * waveform1 + (1 - mix_lambda) * waveform2
        return self.normalize_wav(mix_waveform), mix_lambda

    def read_audio_file(self, filename, filename2=None):
        if os.path.exists(filename):
            waveform, random_start = self.read_wav_file(filename)
        else:
            print(
                'Warning [dataset.py]: The wav path "',
                filename,
                '" was not found in the metadata. Using an empty waveform instead.',
            )
            target_length = int(self.sampling_rate * self.duration)
            waveform = torch.zeros((1, target_length))
            random_start = 0

        mix_lambda = 0.0
        # log_mel_spec, stft = self.wav_feature_extraction_torchaudio(waveform)  # this line is faster, but this implementation is not aligned with HiFi-GAN
        if not self.waveform_only:
            log_mel_spec, stft = self.wav_feature_extraction(waveform)
        else:
            # Load waveform data only
            # Use zero array to keep the format unified
            log_mel_spec, stft = None, None

        return log_mel_spec, stft, mix_lambda, waveform, random_start

    def get_sample_text_caption(self, datum, mix_datum, label_indices):
        text = self.label_indices_to_text(datum, label_indices)
        if mix_datum is not None:
            text += " " + self.label_indices_to_text(mix_datum, label_indices)
        return text

    # This one is significantly slower than "wav_feature_extraction_torchaudio" if num_worker > 1
    def wav_feature_extraction(self, waveform):
        waveform = waveform[0, ...]
        waveform = torch.FloatTensor(waveform)

        log_mel_spec, stft, energy = Audio.tools.get_mel_from_wav(waveform, self.STFT)

        log_mel_spec = torch.FloatTensor(log_mel_spec.T)
        stft = torch.FloatTensor(stft.T)

        log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
        return log_mel_spec, stft

    # @profile
    # def wav_feature_extraction_torchaudio(self, waveform):
    #     waveform = waveform[0, ...]
    #     waveform = torch.FloatTensor(waveform)

    #     stft = self.stft_transform(waveform)
    #     mel_spec = self.melscale_transform(stft)
    #     log_mel_spec = torch.log(mel_spec + 1e-7)

    #     log_mel_spec = torch.FloatTensor(log_mel_spec.T)
    #     stft = torch.FloatTensor(stft.T)

    #     log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
    #     return log_mel_spec, stft

    def pad_spec(self, log_mel_spec):
        n_frames = log_mel_spec.shape[0]
        p = self.target_length - n_frames
        # cut and pad
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            log_mel_spec = m(log_mel_spec)
        elif p < 0:
            log_mel_spec = log_mel_spec[0 : self.target_length, :]

        if log_mel_spec.size(-1) % 2 != 0:
            log_mel_spec = log_mel_spec[..., :-1]

        return log_mel_spec

    def _read_datum_caption(self, datum):
        caption_keys = [x for x in datum.keys() if ("caption" in x)]
        random_index = torch.randint(0, len(caption_keys), (1,))[0].item()
        return datum[caption_keys[random_index]]

    def _is_contain_caption(self, datum):
        caption_keys = [x for x in datum.keys() if ("caption" in x)]
        return len(caption_keys) > 0

    def label_indices_to_text(self, datum, label_indices):
        if self._is_contain_caption(datum):
            return self._read_datum_caption(datum)
        elif "label" in datum.keys():
            name_indices = torch.where(label_indices > 0.1)[0]
            # description_header = "This audio contains the sound of "
            description_header = ""
            labels = ""
            for id, each in enumerate(name_indices):
                if id == len(name_indices) - 1:
                    labels += "%s." % self.num2label[int(each)]
                else:
                    labels += "%s, " % self.num2label[int(each)]
            return description_header + labels
        else:
            return ""  # TODO: if neither label nor caption is provided, return an empty string

    def random_uniform(self, start, end):
        val = torch.rand(1).item()
        return start + (end - start) * val

    def frequency_masking(self, log_mel_spec, freqm):
        bs, freq, tsteps = log_mel_spec.size()
        mask_len = int(self.random_uniform(freqm // 8, freqm))
        mask_start = int(self.random_uniform(start=0, end=freq - mask_len))
        log_mel_spec[:, mask_start : mask_start + mask_len, :] *= 0.0
        return log_mel_spec

    def time_masking(self, log_mel_spec, timem):
        bs, freq, tsteps = log_mel_spec.size()
        mask_len = int(self.random_uniform(timem // 8, timem))
        mask_start = int(self.random_uniform(start=0, end=tsteps - mask_len))
        log_mel_spec[:, :, mask_start : mask_start + mask_len] *= 0.0
        return log_mel_spec
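
Usage sketch (illustrative, not part of the committed file): a rough way to construct AudioDataset through the dataset_json_path branch of __init__. The config keys mirror the ones read in build_setting_parameters and build_dsp; all values, paths, and the import path are placeholder assumptions.

from torch.utils.data import DataLoader
from audiosr.utilities.data.dataset import AudioDataset  # import path assumed

config = {
    "preprocessing": {
        "audio": {"sampling_rate": 48000, "duration": 10.24},
        "stft": {"filter_length": 2048, "hop_length": 480, "win_length": 2048},
        "mel": {"n_mel_channels": 256, "mel_fmin": 20, "mel_fmax": 24000},
    },
    "augmentation": {"mixup": 0.0},
}

# The JSON is assumed to look like {"data": [{"wav": "/abs/path/to/a.wav"}, ...]}
dataset = AudioDataset(
    config=config,
    split="test",
    waveform_only=True,  # skip STFT/mel extraction, return only the waveform
    add_ons=[],
    dataset_json_path="data/example_dataset.json",
)
loader = DataLoader(dataset, batch_size=2, num_workers=0)
batch = next(iter(loader))
print(batch["waveform"].shape)  # [2, 1, sampling_rate * duration]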