amupd committed on
Commit
8b33290
1 Parent(s): fe9c876

initial commit

Files changed (44)
  1. .gitignore +3 -0
  2. README.md +5 -5
  3. SpeechT5 +1 -0
  4. app.py +466 -0
  5. artst/__init__.py +1 -0
  6. artst/criterions/__init__.py +10 -0
  7. artst/criterions/artst_criterion.py +443 -0
  8. artst/criterions/speech_pretrain_criterion.py +265 -0
  9. artst/criterions/speech_to_text_loss.py +473 -0
  10. artst/criterions/text_pretrain_criterion.py +142 -0
  11. artst/criterions/text_to_speech_loss.py +425 -0
  12. artst/data/__init__.py +0 -0
  13. artst/data/multitask_dataset.py +263 -0
  14. artst/data/speech_dataset.py +475 -0
  15. artst/data/speech_to_class_dataset.py +260 -0
  16. artst/data/speech_to_speech_dataset.py +280 -0
  17. artst/data/speech_to_text_dataset.py +298 -0
  18. artst/data/text_dataset.py +474 -0
  19. artst/data/text_to_speech_dataset.py +344 -0
  20. artst/models/__init__.py +2 -0
  21. artst/models/artst.py +1448 -0
  22. artst/models/modules/__init__.py +0 -0
  23. artst/models/modules/decoder.py +323 -0
  24. artst/models/modules/encoder.py +380 -0
  25. artst/models/modules/multihead_attention.py +525 -0
  26. artst/models/modules/speaker_decoder_postnet.py +196 -0
  27. artst/models/modules/speech_decoder_postnet.py +75 -0
  28. artst/models/modules/speech_decoder_prenet.py +109 -0
  29. artst/models/modules/speech_encoder_postnet.py +123 -0
  30. artst/models/modules/speech_encoder_prenet.py +373 -0
  31. artst/models/modules/text_decoder_postnet.py +92 -0
  32. artst/models/modules/text_decoder_prenet.py +128 -0
  33. artst/models/modules/text_encoder_prenet.py +44 -0
  34. artst/models/modules/transformer_layer.py +410 -0
  35. artst/models/t5_transformer_lm.py +23 -0
  36. artst/sequence_generator.py +1080 -0
  37. artst/tasks/__init__.py +0 -0
  38. artst/tasks/artst.py +711 -0
  39. ckpts/mgb2_asr.pt +3 -0
  40. pre-requirements.txt +34 -0
  41. samples/sample_audio.wav +0 -0
  42. utils/arabic.model +3 -0
  43. utils/audios.tsv +2 -0
  44. utils/dict.txt +92 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ __pycache__
+ */__pycache__
+ */*/__pycache__
README.md CHANGED
@@ -1,13 +1,13 @@
  ---
- title: ArtstASR
+ title: ArtstTTS
- emoji: 📉
+ emoji: 💭
  colorFrom: gray
- colorTo: pink
+ colorTo: blue
  sdk: gradio
- sdk_version: 4.8.0
+ sdk_version: 4.7.1
  app_file: app.py
  pinned: false
- license: mit
+ python_version: 3.8.2
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
SpeechT5 ADDED
@@ -0,0 +1 @@
+ Subproject commit 8b5ade783571e63450aaa5507444150dcb08fa94
app.py ADDED
@@ -0,0 +1,466 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Translate pre-processed data with a trained model.
8
+ """
9
+
10
+ import ast
11
+ import logging
12
+ import argparse
13
+ import math
14
+ import os
15
+ import sys
16
+ from argparse import Namespace
17
+ from itertools import chain
18
+
19
+ import numpy as np
20
+ import torch
21
+ from omegaconf import DictConfig
22
+
23
+ from fairseq import checkpoint_utils, options, scoring, tasks, utils
24
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
25
+ from fairseq.logging import progress_bar
26
+ from fairseq.logging.meters import StopwatchMeter, TimeMeter
27
+
28
+ import os
29
+ import torch
30
+ import gradio as gr
31
+ import numpy as np
32
+ import os.path as op
33
+ import pyarabic.araby as araby
34
+ import subprocess
35
+
36
+ import soundfile as sf
37
+
38
+
39
+ from artst.tasks.artst import ArTSTTask
40
+ from artst.models.artst import ArTSTTransformerModel
41
+ from fairseq.tasks.hubert_pretraining import LabelEncoder
42
+
43
+ from fairseq import checkpoint_utils, options, scoring, tasks, utils
44
+
45
+ from loguru import logger
46
+ from fairseq.logging.meters import StopwatchMeter, TimeMeter
47
+
48
+
49
+ def postprocess(wav, cur_sample_rate):
50
+ if wav.dim() == 2:
51
+ wav = wav.mean(-1)
52
+ assert wav.dim() == 1, wav.dim()
53
+
54
+ if cur_sample_rate != 16000:
55
+ raise Exception(f"sr {cur_sample_rate} != {16000}")
56
+ return wav
57
+
58
+
59
+ def main(cfg: DictConfig, audio_path):
60
+ print('config')
61
+ print(cfg)
62
+
63
+ if isinstance(cfg, Namespace):
64
+ cfg = convert_namespace_to_omegaconf(cfg)
65
+
66
+ assert cfg.common_eval.path is not None, "--path required for generation!"
67
+ assert (
68
+ not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
69
+ ), "--sampling requires --nbest to be equal to --beam"
70
+ assert (
71
+ cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
72
+ ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"
73
+
74
+ if cfg.common_eval.results_path is not None:
75
+ os.makedirs(cfg.common_eval.results_path, exist_ok=True)
76
+ output_path = os.path.join(
77
+ cfg.common_eval.results_path,
78
+ "generate-{}.txt".format(cfg.dataset.gen_subset),
79
+ )
80
+ with open(output_path, "w", buffering=1, encoding="utf-8") as h:
81
+ return _main(cfg, h, audio_path)
82
+ else:
83
+ return _main(cfg, sys.stdout, audio_path)
84
+
85
+
86
+ def get_symbols_to_strip_from_output(generator):
87
+ if hasattr(generator, "symbols_to_strip_from_output"):
88
+ return generator.symbols_to_strip_from_output
89
+ else:
90
+ return {generator.eos}
91
+
92
+
93
+ def _main(cfg: DictConfig, output_file, audio_path):
94
+ logging.basicConfig(
95
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
96
+ datefmt="%Y-%m-%d %H:%M:%S",
97
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
98
+ stream=output_file,
99
+ )
100
+ logger = logging.getLogger("fairseq_cli.generate")
101
+
102
+ utils.import_user_module(cfg.common)
103
+
104
+ if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
105
+ cfg.dataset.max_tokens = 12000
106
+ logger.info(cfg)
107
+
108
+ # Fix seed for stochastic decoding
109
+ if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
110
+ np.random.seed(cfg.common.seed)
111
+ utils.set_torch_seed(cfg.common.seed)
112
+
113
+ use_cuda = torch.cuda.is_available() and not cfg.common.cpu
114
+
115
+ # Load dataset splits
116
+ task = tasks.setup_task(cfg.task)
117
+
118
+ # Set dictionaries
119
+ try:
120
+ src_dict = getattr(task, "source_dictionary", None)
121
+ except NotImplementedError:
122
+ src_dict = None
123
+ tgt_dict = task.target_dictionary
124
+
125
+ overrides = ast.literal_eval(cfg.common_eval.model_overrides)
126
+
127
+ # Load ensemble
128
+ logger.info("loading model(s) from {}".format(cfg.common_eval.path))
129
+ models, saved_cfg = checkpoint_utils.load_model_ensemble(
130
+ utils.split_paths(cfg.common_eval.path),
131
+ arg_overrides=overrides,
132
+ task=task,
133
+ suffix=cfg.checkpoint.checkpoint_suffix,
134
+ strict=(cfg.checkpoint.checkpoint_shard_count == 1),
135
+ num_shards=cfg.checkpoint.checkpoint_shard_count,
136
+ )
137
+
138
+ # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
139
+ # task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
140
+
141
+ if cfg.generation.lm_path is not None:
142
+ overrides["data"] = cfg.task.data
143
+
144
+ try:
145
+ lms, _ = checkpoint_utils.load_model_ensemble(
146
+ [cfg.generation.lm_path], arg_overrides=overrides, task=None
147
+ )
148
+ except:
149
+ logger.warning(
150
+ f"Failed to load language model! Please make sure that the language model dict is the same "
151
+ f"as target dict and is located in the data dir ({cfg.task.data})"
152
+ )
153
+ raise
154
+
155
+ assert len(lms) == 1
156
+ else:
157
+ lms = [None]
158
+
159
+ # Optimize ensemble for generation
160
+ for model in chain(models, lms):
161
+ if model is None:
162
+ continue
163
+ if cfg.common.fp16:
164
+ model.half()
165
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
166
+ model.cuda()
167
+ model.prepare_for_inference_(cfg)
168
+
169
+ # Load alignment dictionary for unknown word replacement
170
+ # (None if no unknown word replacement, empty if no path to align dictionary)
171
+ align_dict = utils.load_align_dict(cfg.generation.replace_unk)
172
+
173
+ # Initialize generator
174
+ gen_timer = StopwatchMeter()
175
+
176
+ extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
177
+ generator = task.build_generator(
178
+ models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
179
+ )
180
+
181
+ # Handle tokenization and BPE
182
+ tokenizer = task.build_tokenizer(cfg.tokenizer)
183
+ bpe = task.build_bpe(cfg.bpe)
184
+
185
+ def decode_fn(x):
186
+ if bpe is not None:
187
+ x = bpe.decode(x)
188
+ if tokenizer is not None:
189
+ x = tokenizer.decode(x)
190
+ return x
191
+
192
+ scorer = scoring.build_scorer(cfg.scoring, tgt_dict)
193
+
194
+ num_sentences = 0
195
+ has_target = True
196
+ wps_meter = TimeMeter()
197
+
198
+
199
+ wav, cur_sample_rate = sf.read(audio_path)
200
+ wav = torch.from_numpy(wav).float()
201
+ wav = postprocess(wav, cur_sample_rate)
202
+ sample = {'index': 0, 'net_input': {'source': torch.tensor(wav).unsqueeze(dim=0), 'padding_mask':
203
+ torch.BoolTensor(wav.shape).fill_(False).unsqueeze(dim=0)}, 'id': [0], 'target': [[None], ]}
204
+
205
+ prefix_tokens = None
206
+ if cfg.generation.prefix_size > 0:
207
+ prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]
208
+
209
+ constraints = None
210
+ if "constraints" in sample:
211
+ constraints = sample["constraints"]
212
+
213
+ gen_timer.start()
214
+ hypos = task.inference_step(
215
+ generator,
216
+ models,
217
+ sample,
218
+ prefix_tokens=prefix_tokens,
219
+ constraints=constraints,
220
+ )
221
+ num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
222
+ gen_timer.stop(num_generated_tokens)
223
+
224
+ for i, sample_id in enumerate(sample["id"]):
225
+ has_target = False
226
+
227
+ # Remove padding
228
+ if "src_tokens" in sample["net_input"]:
229
+ src_tokens = utils.strip_pad(
230
+ sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
231
+ )
232
+ else:
233
+ src_tokens = None
234
+
235
+ target_tokens = None
236
+ if has_target:
237
+ target_tokens = (
238
+ utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
239
+ )
240
+
241
+ # Either retrieve the original sentences or regenerate them from tokens.
242
+ if align_dict is not None:
243
+ src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
244
+ sample_id
245
+ )
246
+ target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
247
+ sample_id
248
+ )
249
+ else:
250
+ if src_dict is not None:
251
+ src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
252
+ else:
253
+ src_str = ""
254
+ if has_target:
255
+ target_str = tgt_dict.string(
256
+ target_tokens,
257
+ cfg.common_eval.post_process,
258
+ escape_unk=True,
259
+ extra_symbols_to_ignore=get_symbols_to_strip_from_output(
260
+ generator
261
+ ),
262
+ )
263
+
264
+ src_str = decode_fn(src_str)
265
+ if has_target:
266
+ target_str = decode_fn(target_str)
267
+
268
+ if not cfg.common_eval.quiet:
269
+ if src_dict is not None:
270
+ print("S-{}\t{}".format(sample_id, src_str), file=output_file)
271
+ if has_target:
272
+ print("T-{}\t{}".format(sample_id, target_str), file=output_file)
273
+
274
+ # Process top predictions
275
+ for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
276
+ hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
277
+ hypo_tokens=hypo["tokens"].int().cpu(),
278
+ src_str=src_str,
279
+ alignment=hypo["alignment"],
280
+ align_dict=align_dict,
281
+ tgt_dict=tgt_dict,
282
+ remove_bpe=cfg.common_eval.post_process,
283
+ extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
284
+ )
285
+ detok_hypo_str = decode_fn(hypo_str)
286
+ if not cfg.common_eval.quiet:
287
+ score = hypo["score"] / math.log(2) # convert to base 2
288
+ # original hypothesis (after tokenization and BPE)
289
+ print(
290
+ "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
291
+ file=output_file,
292
+ )
293
+ # detokenized hypothesis
294
+ print(
295
+ "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
296
+ file=output_file,
297
+ )
298
+ print(
299
+ "P-{}\t{}".format(
300
+ sample_id,
301
+ " ".join(
302
+ map(
303
+ lambda x: "{:.4f}".format(x),
304
+ # convert from base e to base 2
305
+ hypo["positional_scores"]
306
+ .div_(math.log(2))
307
+ .tolist(),
308
+ )
309
+ ),
310
+ ),
311
+ file=output_file,
312
+ )
313
+
314
+ if cfg.generation.print_alignment == "hard":
315
+ print(
316
+ "A-{}\t{}".format(
317
+ sample_id,
318
+ " ".join(
319
+ [
320
+ "{}-{}".format(src_idx, tgt_idx)
321
+ for src_idx, tgt_idx in alignment
322
+ ]
323
+ ),
324
+ ),
325
+ file=output_file,
326
+ )
327
+ if cfg.generation.print_alignment == "soft":
328
+ print(
329
+ "A-{}\t{}".format(
330
+ sample_id,
331
+ " ".join(
332
+ [",".join(src_probs) for src_probs in alignment]
333
+ ),
334
+ ),
335
+ file=output_file,
336
+ )
337
+
338
+ if cfg.generation.print_step:
339
+ print(
340
+ "I-{}\t{}".format(sample_id, hypo["steps"]),
341
+ file=output_file,
342
+ )
343
+
344
+ if cfg.generation.retain_iter_history:
345
+ for step, h in enumerate(hypo["history"]):
346
+ _, h_str, _ = utils.post_process_prediction(
347
+ hypo_tokens=h["tokens"].int().cpu(),
348
+ src_str=src_str,
349
+ alignment=None,
350
+ align_dict=None,
351
+ tgt_dict=tgt_dict,
352
+ remove_bpe=None,
353
+ )
354
+ print(
355
+ "E-{}_{}\t{}".format(sample_id, step, h_str),
356
+ file=output_file,
357
+ )
358
+
359
+ # Score only the top hypothesis
360
+ if has_target and j == 0:
361
+ if (
362
+ align_dict is not None
363
+ or cfg.common_eval.post_process is not None
364
+ ):
365
+ # Convert back to tokens for evaluation with unk replacement and/or without BPE
366
+ target_tokens = tgt_dict.encode_line(
367
+ target_str, add_if_not_exist=True
368
+ )
369
+ hypo_tokens = tgt_dict.encode_line(
370
+ detok_hypo_str, add_if_not_exist=True
371
+ )
372
+ if hasattr(scorer, "add_string"):
373
+ scorer.add_string(target_str, detok_hypo_str)
374
+ else:
375
+ scorer.add(target_tokens, hypo_tokens)
376
+
377
+ wps_meter.update(num_generated_tokens)
378
+ # progress.log({"wps": round(wps_meter.avg)})
379
+
380
+ logger.info("NOTE: hypothesis and token scores are output in base 2")
381
+ if has_target:
382
+ if cfg.bpe and not cfg.generation.sacrebleu:
383
+ if cfg.common_eval.post_process:
384
+ logger.warning(
385
+ "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
386
+ )
387
+ else:
388
+ logger.warning(
389
+ "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization"
390
+ )
391
+ # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
392
+ print(
393
+ "Generate {} with beam={}: {}".format(
394
+ cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
395
+ ),
396
+ file=output_file,
397
+ )
398
+ return detok_hypo_str
399
+
400
+ def inference(audio_path):
401
+ # parser = options.get_generation_parser()
402
+ # TODO: replace this workaround with refactoring of `AudioPretraining`
403
+ parser = argparse.ArgumentParser(description='ArTST ASR inference options.')
404
+ parser.add_argument(
405
+ "--arch",
406
+ "-a",
407
+ metavar="ARCH",
408
+ default="wav2vec2",
409
+ help="Model architecture. For constructing tasks that rely on "
410
+ "model args (e.g. `AudioPretraining`)",
411
+ )
412
+ parser.add_argument('--data', type=str, default='./utils', metavar='data')
413
+ parser.add_argument('--bpe-tokenizer', type=str, default='./utils/arabic.model')
414
+ parser.add_argument('--user-dir', type=str, default='./SpeechT5/SpeechT5/speecht5')
415
+ parser.add_argument('--task', type=str, default='artst')
416
+ parser.add_argument('--t5-task', type=str, default='s2t')
417
+ parser.add_argument('--path', type=str, default='./ckpts/mgb2_asr.pt')
418
+ parser.add_argument('--ctc-weight', type=float, default=0.25)
419
+ parser.add_argument('--max-tokens', type=int, default=350000)
420
+ parser.add_argument('--beam', type=int, default=5)
421
+ parser.add_argument('--scoring', type=str, default='wer')
422
+ parser.add_argument('--max-len-a', type=float, default=0)
423
+ parser.add_argument('--max-len-b', type=int, default=1000)
424
+ parser.add_argument('--sample-rate', type=int, default=16000)
425
+ parser.add_argument('--batch-size', type=int, default=1)
426
+ # parser.add_argument('--num-workers', type=int, default=4)
427
+ parser.add_argument('--seed', type=int, default=4)
428
+ parser.add_argument('--normalize', type=bool, default=True)
429
+
430
+ args = parser.parse_args()
431
+ return main(args, audio_path=audio_path)
432
+
433
+
434
+ text_box = gr.Textbox(label="Arabic Text")
435
+ input_audio = gr.Audio(label="Upload Audio", type="filepath", sources="upload")
436
+ title="ArTST: Arabic Speech Recognition"
437
+ description="ArTST: Arabic text and speech transformer based on the T5 transformer. This space demonstrates the ASR checkpoint finetuned on \
438
+ the MGB-2 dataset. The model is pre-trained on the MGB-2 dataset."
439
+
440
+ examples=["samples/sample_audio.wav"]
441
+
442
+ article = """
443
+ <div style='margin:20px auto;'>
444
+ <p>References: <a href="https://arxiv.org/abs/2310.16621">ArTST paper</a> |
445
+ <a href="https://github.com/mbzuai-nlp/ArTST">GitHub</a> |
446
+ <a href="https://huggingface.co/MBZUAI/ArTST">Weights and Tokenizer</a></p>
447
+ <pre>
448
+ @misc{toyin2023artst,
449
+ title={ArTST: Arabic Text and Speech Transformer},
450
+ author={Hawau Olamide Toyin and Amirbek Djanibekov and Ajinkya Kulkarni and Hanan Aldarmaki},
451
+ year={2023},
452
+ eprint={2310.16621},
453
+ archivePrefix={arXiv},
454
+ primaryClass={cs.CL}
455
+ }
456
+ </pre>
457
+ <p>Speaker embeddings were generated from <a href="http://www.festvox.org/cmu_arctic/">CMU ARCTIC</a>.</p>
458
+ <p>ArTST is based on <a href="https://arxiv.org/abs/2110.07205">SpeechT5 architecture</a>.</p>
459
+ </div>
460
+ """
461
+
462
+ demo = gr.Interface(inference, \
463
+ inputs=input_audio, outputs=text_box, title=title, description=description, examples=examples, article=article)
464
+
465
+ if __name__ == "__main__":
466
+ demo.launch(share=True)
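
Since `postprocess` above rejects anything that is not mono 16 kHz audio, the simplest end-to-end check is the bundled sample. Below is a minimal smoke-test sketch (assuming the checkpoint `ckpts/mgb2_asr.pt`, the `utils/` files, and the dependencies from `pre-requirements.txt` are in place) that exercises `inference()` without starting the Gradio UI:

```python
# Sketch only: run the ASR pipeline defined in app.py directly on the bundled sample.
# inference() parses sys.argv with argparse defaults, so run this with no extra CLI args.
from app import inference

if __name__ == "__main__":
    # samples/sample_audio.wav is the WAV shipped in this commit (expected to be 16 kHz mono).
    transcript = inference("samples/sample_audio.wav")
    print(transcript)
```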
artst/__init__.py ADDED
@@ -0,0 +1 @@
+ from . import data, tasks, criterions, models # noqa
artst/criterions/__init__.py ADDED
@@ -0,0 +1,10 @@
+ import importlib
+ import os
+
+
+ for file in os.listdir(os.path.dirname(__file__)):
+     if file.endswith(".py") and not file.startswith("_"):
+         criterion_name = file[: file.find(".py")]
+         importlib.import_module(
+             "artst.criterions." + criterion_name
+         )
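
The loop above exists so that `import artst.criterions` (triggered from `artst/__init__.py`) pulls in every module in this package, which is what lets decorators such as `@register_criterion("artst")` in `artst_criterion.py` run. A minimal illustration of the intended side effect:

```python
# Illustrative only: importing the package executes every artst/criterions/*.py module,
# so the criteria they register (e.g. the one named "artst") become resolvable by name
# in fairseq configs/commands such as `--criterion artst`.
import artst.criterions  # noqa: F401  (imported purely for its registration side effects)
```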
artst/criterions/artst_criterion.py ADDED
@@ -0,0 +1,443 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import re
9
+ from dataclasses import dataclass
10
+
11
+ import math
12
+ from fairseq import metrics, utils
13
+ from fairseq.criterions import FairseqCriterion, register_criterion
14
+ from artst.criterions.text_to_speech_loss import TexttoSpeechLoss
15
+ from artst.criterions.text_pretrain_criterion import TextPretrainCriterion, TextPretrainCriterionConfig
16
+ from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterionConfig
17
+ from artst.criterions.speech_pretrain_criterion import SpeechPretrainCriterion, SpeechPretrainCriterionConfig
18
+ from artst.criterions.speech_to_text_loss import SpeechtoTextLoss, SpeechtoTextLossConfig
19
+ from fairseq.logging.meters import safe_round
20
+
21
+ @dataclass
22
+ class ArTSTCriterionConfig(
23
+ LabelSmoothedCrossEntropyCriterionConfig,
24
+ TextPretrainCriterionConfig,
25
+ SpeechPretrainCriterionConfig,
26
+ SpeechtoTextLossConfig
27
+ ):
28
+ pass
29
+
30
+ @register_criterion(
31
+ "artst", dataclass=ArTSTCriterionConfig
32
+ )
33
+ class ArTSTCriterion(FairseqCriterion):
34
+ def __init__(
35
+ self,
36
+ task,
37
+ sentence_avg,
38
+ label_smoothing,
39
+ pred_masked_weight,
40
+ pred_nomask_weight,
41
+ loss_weights=None,
42
+ log_keys=None,
43
+ ignore_prefix_size=0,
44
+ report_accuracy=False,
45
+ use_masking=True,
46
+ use_weighted_masking=False,
47
+ loss_type="L1",
48
+ bce_pos_weight=5.0,
49
+ bce_loss_lambda=1.0,
50
+ use_guided_attn_loss=False,
51
+ num_heads_applied_guided_attn=2,
52
+ ce_weight=1.0,
53
+ ctc_weight=0.0,
54
+ hubert_weight=1.0,
55
+ dec_weight=1.0,
56
+ bart_weight=1.0,
57
+ ):
58
+ super().__init__(task)
59
+ self.speech_criterion = TexttoSpeechLoss(
60
+ task,
61
+ sentence_avg,
62
+ use_masking,
63
+ use_weighted_masking,
64
+ loss_type,
65
+ bce_pos_weight,
66
+ bce_loss_lambda,
67
+ use_guided_attn_loss,
68
+ num_heads_applied_guided_attn=num_heads_applied_guided_attn,
69
+ )
70
+ self.text_criterion = SpeechtoTextLoss(
71
+ SpeechtoTextLossConfig,
72
+ task,
73
+ sentence_avg,
74
+ label_smoothing,
75
+ ignore_prefix_size,
76
+ report_accuracy,
77
+ ce_weight,
78
+ ctc_weight
79
+ )
80
+ self.text_pretrain_criterion = TextPretrainCriterion(
81
+ task,
82
+ sentence_avg,
83
+ bart_weight,
84
+ loss_weights,
85
+ )
86
+ self.speech_pretrain_criterion = SpeechPretrainCriterion(
87
+ task,
88
+ sentence_avg,
89
+ pred_masked_weight,
90
+ pred_nomask_weight,
91
+ loss_weights,
92
+ log_keys,
93
+ use_masking,
94
+ use_weighted_masking,
95
+ loss_type,
96
+ bce_pos_weight,
97
+ hubert_weight,
98
+ dec_weight
99
+ )
100
+
101
+ def forward(self, model, sample, reduce=True):
102
+ """Compute the loss for the given sample.
103
+
104
+ Returns a tuple with three elements:
105
+ 1) the loss
106
+ 2) the sample size, which is used as the denominator for the gradient
107
+ 3) logging outputs to display while training
108
+ """
109
+
110
+ task_name = sample['task_name']
111
+ if task_name == 's2t' or task_name == 's2c':
112
+ return self.text_criterion(model, sample, reduce)
113
+ elif task_name == 't2s' or task_name == 's2s':
114
+ return self.speech_criterion(model, sample)
115
+ elif task_name == 'text_pretrain':
116
+ return self.text_pretrain_criterion(model, sample, reduce)
117
+ elif task_name == 'speech_pretrain':
118
+ return self.speech_pretrain_criterion(model, sample, reduce)
119
+
120
+ @classmethod
121
+ def reduce_metrics(cls, logging_outputs):
122
+ """Aggregate logging outputs from data parallel training."""
123
+ logging_outputs_dict = {}
124
+ for logging_output in logging_outputs:
125
+ for task_name in logging_output:
126
+ if task_name not in ['s2t', 't2s', 's2c', 's2s', 'text_pretrain', 'speech_pretrain']:
127
+ continue
128
+
129
+ if task_name not in logging_outputs_dict:
130
+ logging_outputs_dict[task_name] = []
131
+ logging_outputs_dict[task_name].append(logging_output[task_name])
132
+
133
+ for task_name in logging_outputs_dict:
134
+ if task_name == 's2t':
135
+ # LabelSmoothedCrossEntropyCriterion.reduce_metrics([logging_output['s2t'] for logging_output in logging_outputs])
136
+ s2t_logging_output = logging_outputs_dict[task_name]
137
+ # s2t_sum = sum(log.get("ce_loss", 0) for log in logging_outputs)
138
+ loss_sum = sum(log.get("loss", 0) for log in s2t_logging_output)
139
+ nll_loss_sum = sum(log.get("nll_loss", 0) for log in s2t_logging_output)
140
+ ntokens = sum(log.get("ntokens", 0) for log in s2t_logging_output)
141
+ ce_loss_sum = sum(log.get("ce_loss", 0) for log in s2t_logging_output)
142
+ ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in s2t_logging_output)
143
+
144
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in s2t_logging_output))
145
+ metrics.log_scalar(
146
+ "s2t_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3
147
+ )
148
+
149
+ metrics.log_scalar(
150
+ "s2t_nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
151
+ )
152
+ metrics.log_derived(
153
+ "s2t_ppl", lambda meters: utils.get_perplexity(meters["s2t_nll_loss"].avg, 2)
154
+ )
155
+ metrics.log_scalar(
156
+ "ctc_loss", ctc_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
157
+ )
158
+ metrics.log_scalar(
159
+ "ce_loss", ce_loss_sum / ntokens, ntokens, 2, round=3
160
+ )
161
+
162
+ total = utils.item(sum(log.get("total", 0) for log in s2t_logging_output))
163
+ if total > 0:
164
+ metrics.log_scalar("s2t_total", total)
165
+ n_correct = utils.item(
166
+ sum(log.get("n_correct", 0) for log in s2t_logging_output)
167
+ )
168
+ metrics.log_scalar("s2t_n_correct", n_correct)
169
+ metrics.log_derived(
170
+ "s2t_accuracy",
171
+ lambda meters: round(
172
+ meters["s2t_n_correct"].sum * 100.0 / meters["s2t_total"].sum, 3
173
+ )
174
+ if meters["s2t_total"].sum > 0
175
+ else float("nan"),
176
+ 2
177
+ )
178
+ c_errors = sum(log.get("c_errors", 0) for log in s2t_logging_output)
179
+ metrics.log_scalar("_c_errors", c_errors)
180
+ c_total = sum(log.get("c_total", 0) for log in s2t_logging_output)
181
+ metrics.log_scalar("_c_total", c_total)
182
+ w_errors = sum(log.get("w_errors", 0) for log in s2t_logging_output)
183
+ metrics.log_scalar("_w_errors", w_errors)
184
+ wv_errors = sum(log.get("wv_errors", 0) for log in s2t_logging_output)
185
+ metrics.log_scalar("_wv_errors", wv_errors)
186
+ w_total = sum(log.get("w_total", 0) for log in s2t_logging_output)
187
+ metrics.log_scalar("_w_total", w_total)
188
+ if c_total > 0:
189
+ metrics.log_derived(
190
+ "uer",
191
+ lambda meters: safe_round(
192
+ meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
193
+ )
194
+ if meters["_c_total"].sum > 0
195
+ else float("nan"),
196
+ )
197
+ if w_total > 0:
198
+ metrics.log_derived(
199
+ "wer",
200
+ lambda meters: safe_round(
201
+ meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
202
+ )
203
+ if meters["_w_total"].sum > 0
204
+ else float("nan"),
205
+ )
206
+ metrics.log_derived(
207
+ "raw_wer",
208
+ lambda meters: safe_round(
209
+ meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
210
+ )
211
+ if meters["_w_total"].sum > 0
212
+ else float("nan"),
213
+ )
214
+
215
+ if task_name == 't2s':
216
+ # TTSLossCriterion.reduce_metrics([logging_output['t2s'] for logging_output in logging_outputs])
217
+ # t2s_sum = sum(log.get("speech_loss", 0) for log in logging_outputs)
218
+ t2s_logging_output = logging_outputs_dict[task_name]
219
+ loss_sum = sum(log.get("loss", 0) for log in t2s_logging_output)
220
+ l1_loss_sum = sum(log.get("l1_loss", 0) for log in t2s_logging_output)
221
+ l2_loss_sum = sum(log.get("l2_loss", 0) for log in t2s_logging_output)
222
+ bce_loss_sum = sum(log.get("bce_loss", 0) for log in t2s_logging_output)
223
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in t2s_logging_output))
224
+ metrics.log_scalar(
225
+ "t2s_loss", loss_sum / sample_size, sample_size, 1, round=5
226
+ )
227
+ encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in t2s_logging_output)
228
+ decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in t2s_logging_output)
229
+ ngpu = sum(log.get("ngpu", 0) for log in t2s_logging_output)
230
+
231
+ metrics.log_scalar(
232
+ "t2s_l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
233
+ )
234
+ metrics.log_scalar(
235
+ "t2s_l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
236
+ )
237
+ metrics.log_scalar(
238
+ "t2s_bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
239
+ )
240
+ metrics.log_scalar(
241
+ "t2s_encoder_alpha", encoder_alpha_sum / sample_size, sample_size, round=5
242
+ )
243
+ metrics.log_scalar(
244
+ "t2s_decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
245
+ )
246
+
247
+ if "enc_dec_attn_loss" in t2s_logging_output[0]:
248
+ enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in t2s_logging_output)
249
+ metrics.log_scalar(
250
+ "t2s_enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
251
+ )
252
+
253
+ if task_name == 's2c':
254
+ s2c_logging_output = logging_outputs_dict[task_name]
255
+ loss_sum = sum(log.get("loss", 0) for log in s2c_logging_output)
256
+ nll_loss_sum = sum(log.get("nll_loss", 0) for log in s2c_logging_output)
257
+ ntokens = sum(log.get("ntokens", 0) for log in s2c_logging_output)
258
+
259
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in s2c_logging_output))
260
+ metrics.log_scalar(
261
+ "s2c_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3
262
+ )
263
+
264
+ metrics.log_scalar(
265
+ "s2c_nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
266
+ )
267
+
268
+ total = utils.item(sum(log.get("total", 0) for log in s2c_logging_output))
269
+ if total > 0:
270
+ metrics.log_scalar("s2c_total", total)
271
+ n_correct = utils.item(sum(log.get("n_correct", 0) for log in s2c_logging_output))
272
+ metrics.log_scalar("s2c_n_correct", n_correct)
273
+ metrics.log_derived(
274
+ "s2c_accuracy",
275
+ lambda meters: round(
276
+ meters["s2c_n_correct"].sum * 100.0 / meters["s2c_total"].sum, 3
277
+ )
278
+ if meters["s2c_total"].sum > 0
279
+ else float("nan"),
280
+ 2
281
+ )
282
+
283
+ if task_name == 's2s':
284
+ s2s_logging_output = logging_outputs_dict[task_name]
285
+ loss_sum = sum(log.get("loss", 0) for log in s2s_logging_output)
286
+ l1_loss_sum = sum(log.get("l1_loss", 0) for log in s2s_logging_output)
287
+ l2_loss_sum = sum(log.get("l2_loss", 0) for log in s2s_logging_output)
288
+ bce_loss_sum = sum(log.get("bce_loss", 0) for log in s2s_logging_output)
289
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in s2s_logging_output))
290
+ metrics.log_scalar(
291
+ "s2s_loss", loss_sum / sample_size, sample_size, 1, round=5
292
+ )
293
+ encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in s2s_logging_output)
294
+ decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in s2s_logging_output)
295
+ ngpu = sum(log.get("ngpu", 0) for log in s2s_logging_output)
296
+
297
+ metrics.log_scalar(
298
+ "s2s_l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
299
+ )
300
+ metrics.log_scalar(
301
+ "s2s_l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
302
+ )
303
+ metrics.log_scalar(
304
+ "s2s_bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
305
+ )
306
+ metrics.log_scalar(
307
+ "s2s_decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
308
+ )
309
+
310
+ if "enc_dec_attn_loss" in s2s_logging_output[0]:
311
+ enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in s2s_logging_output)
312
+ metrics.log_scalar(
313
+ "s2s_enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
314
+ )
315
+
316
+ if task_name == 'text_pretrain':
317
+ bart_logging_output = logging_outputs_dict[task_name]
318
+ loss_sum = sum(log.get("loss", 0) for log in bart_logging_output)
319
+ ntokens = sum(log.get("ntokens", 0) for log in bart_logging_output)
320
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in bart_logging_output))
321
+ bart_loss_sum = sum(log.get("bart_loss", 0) for log in bart_logging_output)
322
+
323
+ # we divide by log(2) to convert the loss from base e to base 2
324
+ metrics.log_scalar(
325
+ "text_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
326
+ )
327
+ metrics.log_scalar(
328
+ "bart_loss", bart_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
329
+ )
330
+ if sample_size != ntokens:
331
+ metrics.log_scalar(
332
+ "bart_nll_loss", bart_loss_sum / ntokens / math.log(2), ntokens, round=3
333
+ )
334
+ metrics.log_derived(
335
+ "bart_ppl", lambda meters: utils.get_perplexity(meters["bart_nll_loss"].avg)
336
+ )
337
+ else:
338
+ metrics.log_derived(
339
+ "bart_ppl", lambda meters: utils.get_perplexity(meters["bart_loss"].avg)
340
+ )
341
+ metrics.log_scalar("bart_wpb", ntokens, priority=180, round=1)
342
+
343
+ val_prob_perplexity = 0
344
+ val_code_perplexity = 0
345
+ sample_size_pp = 0
346
+ count_log_cp = 0
347
+ for log in bart_logging_output:
348
+ if "loss_prob_perplexity" in log:
349
+ val_prob_perplexity = val_prob_perplexity + log["loss_prob_perplexity"]
350
+ sample_size_pp = sample_size_pp + log["sample_size"]
351
+ if "code_perplexity" in log:
352
+ val_code_perplexity = val_code_perplexity + log["code_perplexity"]
353
+ count_log_cp = count_log_cp + 1
354
+ if val_prob_perplexity > 0:
355
+ metrics.log_scalar("text_loss_prob_perplexity", val_prob_perplexity / sample_size_pp / math.log(2), round=3)
356
+ if val_code_perplexity > 0:
357
+ metrics.log_scalar("text_code_perplexity", val_code_perplexity / count_log_cp, round=3)
358
+
359
+ if task_name == 'speech_pretrain':
360
+ hubert_logging_output = logging_outputs_dict[task_name]
361
+ loss_sum = sum(log.get("loss", 0) for log in hubert_logging_output)
362
+ ntokens = sum(log.get("ntokens", 0) for log in hubert_logging_output)
363
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in hubert_logging_output))
364
+ dec_loss_sum = sum(log.get("dec_loss", 0) for log in hubert_logging_output)
365
+ l1_loss_sum = sum(log.get("l1_loss", 0) for log in hubert_logging_output)
366
+ l2_loss_sum = sum(log.get("l2_loss", 0) for log in hubert_logging_output)
367
+ bce_loss_sum = sum(log.get("bce_loss", 0) for log in hubert_logging_output)
368
+ ngpu = sum(log.get("ngpu", 0) for log in hubert_logging_output)
369
+
370
+ metrics.log_scalar("hubert_loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
371
+ if sample_size != ntokens:
372
+ metrics.log_scalar("hubert_nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
373
+ metrics.log_derived("hubert_ppl", lambda meters: utils.get_perplexity(meters["hubert_nll_loss"].avg))
374
+ else:
375
+ metrics.log_derived("hubert_ppl", lambda meters: utils.get_perplexity(meters["hubert_loss"].avg))
376
+
377
+ counts = {}
378
+ for lk in hubert_logging_output[0].keys():
379
+ if lk.startswith("count_"):
380
+ val = sum(log[lk] for log in hubert_logging_output)
381
+ metrics.log_scalar("hubert_" + lk, val)
382
+ counts[lk] = val
383
+
384
+ for lk in hubert_logging_output[0].keys():
385
+ if lk.startswith("loss_") and lk != 'loss_prob_perplexity':
386
+ val = sum(log[lk] for log in hubert_logging_output)
387
+ metrics.log_scalar("hubert_" + lk, val / sample_size / math.log(2), round=3)
388
+ elif lk.startswith("correct_"):
389
+ val = sum(log[lk] for log in hubert_logging_output)
390
+ metrics.log_scalar("hubert_" + lk, val / counts[re.sub("correct", "count", lk)])
391
+ # elif lk == 'code_perplexity':
392
+ # val = sum(log[lk] for log in hubert_logging_output)
393
+ # metrics.log_scalar("hubert_" + lk, val / len(hubert_logging_output), round=3)
394
+
395
+ val_prob_perplexity = 0
396
+ val_code_perplexity = 0
397
+ sample_size_pp = 0
398
+ count_log_cp = 0
399
+ for log in hubert_logging_output:
400
+ if "loss_prob_perplexity" in log:
401
+ val_prob_perplexity = val_prob_perplexity + log["loss_prob_perplexity"]
402
+ sample_size_pp = sample_size_pp + log["sample_size"]
403
+ if "code_perplexity" in log:
404
+ val_code_perplexity = val_code_perplexity + log["code_perplexity"]
405
+ count_log_cp = count_log_cp + 1
406
+ if val_prob_perplexity > 0:
407
+ metrics.log_scalar("hubert_loss_prob_perplexity", val_prob_perplexity / sample_size_pp / math.log(2), round=3)
408
+ if val_code_perplexity > 0:
409
+ metrics.log_scalar("hubert_code_perplexity", val_code_perplexity / count_log_cp, round=3)
410
+
411
+ metrics.log_scalar(
412
+ "hubert_dec_loss", dec_loss_sum / ngpu, sample_size, 2, round=5
413
+ )
414
+ metrics.log_scalar(
415
+ "hubert_l1_loss", l1_loss_sum / ngpu, sample_size, 2, round=5
416
+ )
417
+ metrics.log_scalar(
418
+ "hubert_l2_loss", l2_loss_sum / ngpu, sample_size, 2, round=5
419
+ )
420
+ metrics.log_scalar(
421
+ "hubert_bce_loss", bce_loss_sum / ngpu, sample_size, 2, round=5
422
+ )
423
+ if "enc_dec_attn_loss" in hubert_logging_output[0]:
424
+ enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in hubert_logging_output)
425
+ metrics.log_scalar(
426
+ "hubert_enc_dec_attn_loss", enc_dec_attn_loss_sum / ngpu, sample_size, round=8
427
+ )
428
+ metrics.log_scalar("hubert_wpb", ntokens, priority=180, round=1)
429
+
430
+ loss = sum(log.get("loss", 0) for log in logging_outputs)
431
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in logging_outputs))
432
+ metrics.log_scalar(
433
+ "loss", loss / sample_size, sample_size, 1, round=5
434
+ )
435
+
436
+ @staticmethod
437
+ def logging_outputs_can_be_summed() -> bool:
438
+ """
439
+ Whether the logging outputs returned by `forward` can be summed
440
+ across workers prior to calling `reduce_metrics`. Setting this
441
+ to True will improve distributed training speed.
442
+ """
443
+ return False
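
A note on the repeated `/ math.log(2)` in `reduce_metrics` above: fairseq cross-entropy values are accumulated in nats, and dividing by `log(2)` rescales them to bits so the derived perplexities can be exponentiated with base 2. A quick numeric check of that identity:

```python
import math

nll_nats = 2.0                     # a cross-entropy value in nats (base e)
nll_bits = nll_nats / math.log(2)  # the same value expressed in bits (base 2)

# Perplexity is unchanged by the rebasing: 2 ** (x / ln 2) == e ** x.
assert abs(2 ** nll_bits - math.exp(nll_nats)) < 1e-9
```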
artst/criterions/speech_pretrain_criterion.py ADDED
@@ -0,0 +1,265 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ import re
10
+ from dataclasses import dataclass, field
11
+ from typing import List, Optional
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from fairseq import metrics, utils
16
+ from fairseq.criterions import FairseqCriterion
17
+ from artst.criterions.text_to_speech_loss import TexttoSpeechLoss, TexttoSpeechLossConfig
18
+
19
+
20
+ @dataclass
21
+ class SpeechPretrainCriterionConfig(TexttoSpeechLossConfig):
22
+ pred_masked_weight: float = field(
23
+ default=1.0,
24
+ metadata={"help": "weight for predictive loss for masked frames"},
25
+ )
26
+ pred_nomask_weight: float = field(
27
+ default=0.0,
28
+ metadata={"help": "weight for predictive loss for unmasked frames"},
29
+ )
30
+ loss_weights: Optional[List[float]] = field(
31
+ default_factory=lambda: [10,],
32
+ metadata={"help": "weights for additional loss terms (not first one)"},
33
+ )
34
+ log_keys: List[str] = field(
35
+ default_factory=lambda: [],
36
+ metadata={"help": "output keys to log"},
37
+ )
38
+ hubert_weight: float = field(
39
+ default=1.0,
40
+ metadata={"help": "weight of hubert loss"},
41
+ )
42
+ dec_weight: float = field(
43
+ default=1.0,
44
+ metadata={"help": "weight of decoder loss"},
45
+ )
46
+
47
+
48
+ class SpeechPretrainCriterion(FairseqCriterion):
49
+ def __init__(
50
+ self,
51
+ task,
52
+ sentence_avg,
53
+ pred_masked_weight,
54
+ pred_nomask_weight,
55
+ loss_weights=None,
56
+ log_keys=None,
57
+ use_masking=True,
58
+ use_weighted_masking=False,
59
+ loss_type="L1",
60
+ bce_pos_weight=5.0,
61
+ hubert_weight=1.0,
62
+ dec_weight=1.0,
63
+ ):
64
+ super().__init__(task)
65
+ self.pred_masked_weight = pred_masked_weight
66
+ self.pred_nomask_weight = pred_nomask_weight
67
+ self.loss_weights = loss_weights
68
+ self.log_keys = [] if log_keys is None else log_keys
69
+ self.hubert_weight = hubert_weight
70
+ self.dec_weight = dec_weight
71
+
72
+ self.speech_criterion = TexttoSpeechLoss(
73
+ task,
74
+ sentence_avg,
75
+ use_masking,
76
+ use_weighted_masking,
77
+ loss_type,
78
+ bce_pos_weight,
79
+ )
80
+
81
+ def forward(self, model, sample, reduce=True, log_pred=False):
82
+ """Compute the loss for the given sample.
83
+ Returns a tuple with three elements:
84
+ 1) the loss
85
+ 2) the sample size, which is used as the denominator for the gradient
86
+ 3) logging outputs to display while training
87
+ """
88
+ if self.dec_weight == 0:
89
+ sample["net_input"]["only_hubert"] = True
90
+ net_output, net_output_dec = model(target_list=sample["target_list"], **sample["net_input"])
91
+ loss = 0.
92
+ sample_size = 0
93
+ logging_output = {}
94
+ reduction = "sum" if reduce else "none"
95
+
96
+ loss_m_list = []
97
+ logp_m_list = model.get_logits(net_output, True)
98
+ targ_m_list = model.get_targets(None, net_output, True)
99
+ assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
100
+ for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
101
+ loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
102
+ loss_m_list.append(loss_m)
103
+ logging_output[f"loss_m_{i}"] = loss_m.detach().item()
104
+ if self.pred_masked_weight > 0:
105
+ loss += self.pred_masked_weight * sum(loss_m_list)
106
+ sample_size += targ_m_list[0].numel()
107
+
108
+ loss_u_list = []
109
+ logp_u_list = model.get_logits(net_output, False)
110
+ targ_u_list = model.get_targets(None, net_output, False)
111
+ assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
112
+ for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
113
+ loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
114
+ loss_u_list.append(loss_u)
115
+ logging_output[f"loss_u_{i}"] = loss_u.detach().item()
116
+ if self.pred_nomask_weight > 0:
117
+ loss += self.pred_nomask_weight * sum(loss_u_list)
118
+ sample_size += targ_u_list[0].numel()
119
+
120
+ if self.loss_weights is not None:
121
+ assert hasattr(model, "get_extra_losses")
122
+ extra_losses, names = model.get_extra_losses(net_output)
123
+ if torch.is_tensor(extra_losses):
124
+ extra_losses = [extra_losses]
125
+ names = [names]
126
+ if len(self.loss_weights) == 1 and len(extra_losses) != 1:
127
+ self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
128
+ if len(self.loss_weights) > len(extra_losses):
129
+ modified_loss_weight = self.loss_weights[:len(extra_losses)]
130
+ else:
131
+ modified_loss_weight = self.loss_weights
132
+
133
+ # assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
134
+ for p, n, coef in zip(extra_losses, names, modified_loss_weight):
135
+ # print(n + str(coef))
136
+ if coef != 0 and p is not None:
137
+ p = coef * p.float() * sample_size
138
+ loss += p
139
+ logging_output[f"loss_{n}"] = p.detach().item()
140
+
141
+ logging_output = {
142
+ "ntokens": sample_size,
143
+ "nsentences": sample["id"].numel(),
144
+ "sample_size": sample_size,
145
+ "ngpu": 1,
146
+ **logging_output,
147
+ }
148
+
149
+ if 'loss_prob_perplexity' in logging_output:
150
+ logging_output['code_perplexity'] = net_output['code_perplexity'].detach().item()
151
+
152
+ for lk in self.log_keys:
153
+ if lk in net_output:
154
+ logging_output[lk] = float((net_output[lk].item()))
155
+
156
+ def compute_correct(logits):
157
+ if logits.numel() == 0:
158
+ return 0, 0
159
+ else:
160
+ assert logits.dim() > 1, logits.shape
161
+ max = logits.argmax(-1) == 0
162
+ min = logits.argmin(-1) == 0
163
+ both = max & min
164
+ corr = max.long().sum().item() - both.long().sum().item()
165
+ count = max.numel()
166
+ return corr, count
167
+
168
+ with torch.no_grad():
169
+ for i, logp_m in enumerate(logp_m_list):
170
+ corr_m, count_m = compute_correct(logp_m)
171
+ logging_output[f"correct_m_{i}"] = corr_m
172
+ logging_output[f"count_m_{i}"] = count_m
173
+
174
+ for i, logp_u in enumerate(logp_u_list):
175
+ corr_u, count_u = compute_correct(logp_u)
176
+ logging_output[f"correct_u_{i}"] = corr_u
177
+ logging_output[f"count_u_{i}"] = count_u
178
+
179
+ if self.dec_weight == 0.0:
180
+ logging_output["loss"] = loss.item() if reduce else loss
181
+ return loss, sample_size, logging_output
182
+
183
+ # ## dec loss
184
+ dec_loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss = self.speech_criterion.compute_loss(model, net_output_dec, sample)
185
+
186
+ # Log tts loss
187
+ logging_output['dec_loss'] = dec_loss.item()
188
+ logging_output['l1_loss'] = l1_loss.item()
189
+ logging_output['l2_loss'] = l2_loss.item()
190
+ logging_output['bce_loss'] = bce_loss.item()
191
+ if enc_dec_attn_loss is not None:
192
+ logging_output['enc_dec_attn_loss'] = enc_dec_attn_loss.item()
193
+
194
+ loss = self.hubert_weight * loss + self.dec_weight * sample_size * dec_loss
195
+ logging_output["loss"] = loss.item() if reduce else loss
196
+ return loss, sample_size, logging_output
197
+
198
+ @staticmethod
199
+ def reduce_metrics(logging_outputs) -> None:
200
+ """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
201
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
202
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
203
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
204
+ dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
205
+ l1_loss_sum = sum(log.get("l1_loss", 0) for log in logging_outputs)
206
+ l2_loss_sum = sum(log.get("l2_loss", 0) for log in logging_outputs)
207
+ bce_loss_sum = sum(log.get("bce_loss", 0) for log in logging_outputs)
208
+ ngpu = sum(log.get("ngpu", 0) for log in logging_outputs)
209
+
210
+ metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
211
+ if sample_size != ntokens:
212
+ metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
213
+ metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
214
+ else:
215
+ metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))
216
+
217
+ counts = {}
218
+ for lk in logging_outputs[0].keys():
219
+ if lk.startswith("count_"):
220
+ val = sum(log[lk] for log in logging_outputs)
221
+ metrics.log_scalar(lk, val)
222
+ counts[lk] = val
223
+
224
+ for lk in logging_outputs[0].keys():
225
+ if lk.startswith("loss_"):
226
+ val = sum(log[lk] for log in logging_outputs)
227
+ metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
228
+ elif lk.startswith("correct_"):
229
+ val = sum(log[lk] for log in logging_outputs)
230
+ metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
231
+ elif lk == 'code_perplexity':
232
+ val = sum(log[lk] for log in logging_outputs)
233
+ metrics.log_scalar(lk, val / len(logging_outputs), round=3)
234
+
235
+ metrics.log_scalar(
236
+ "dec_loss", dec_loss_sum / ngpu, sample_size, 2, round=5
237
+ )
238
+ metrics.log_scalar(
239
+ "l1_loss", l1_loss_sum / ngpu, sample_size, 2, round=5
240
+ )
241
+ metrics.log_scalar(
242
+ "l2_loss", l2_loss_sum / ngpu, sample_size, 2, round=5
243
+ )
244
+ metrics.log_scalar(
245
+ "bce_loss", bce_loss_sum / ngpu, sample_size, 2, round=5
246
+ )
247
+ if "enc_dec_attn_loss" in logging_outputs[0]:
248
+ enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in logging_outputs)
249
+ metrics.log_scalar(
250
+ "enc_dec_attn_loss", enc_dec_attn_loss_sum / ngpu, sample_size, round=8
251
+ )
252
+
253
+ @staticmethod
254
+ def aggregate_logging_outputs(logging_outputs):
255
+ """Aggregate logging outputs from data parallel training."""
256
+ raise NotImplementedError()
257
+
258
+ @staticmethod
259
+ def logging_outputs_can_be_summed() -> bool:
260
+ """
261
+ Whether the logging outputs returned by `forward` can be summed
262
+ across workers prior to calling `reduce_metrics`. Setting this
263
+ to True will improve distributed training speed.
264
+ """
265
+ return False
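
For orientation, the total returned by `forward()` above is `hubert_weight * loss + dec_weight * sample_size * dec_loss`, i.e. the frame-summed HuBERT-style prediction loss plus a decoder reconstruction term scaled by `sample_size`. A self-contained sketch of that final combination step (function and argument names are chosen here for illustration):

```python
def combine_pretrain_losses(hubert_loss: float, dec_loss: float, sample_size: int,
                            hubert_weight: float = 1.0, dec_weight: float = 1.0) -> float:
    # Mirrors the last step of SpeechPretrainCriterion.forward(): the decoder loss is
    # scaled by sample_size before being added to the (already frame-summed) HuBERT loss.
    return hubert_weight * hubert_loss + dec_weight * sample_size * dec_loss
```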
artst/criterions/speech_to_text_loss.py ADDED
@@ -0,0 +1,473 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ from argparse import Namespace
10
+ from dataclasses import dataclass, field
11
+ from omegaconf import II
12
+ from typing import Optional
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from fairseq import metrics, utils
17
+ from fairseq.criterions import FairseqCriterion, register_criterion
18
+ from fairseq.dataclass import FairseqDataclass
19
+ from fairseq.data.data_utils import post_process
20
+ from fairseq.tasks import FairseqTask
21
+ from fairseq.logging.meters import safe_round
22
+
23
+ import logging
24
+ logger = logging.getLogger(__name__)
25
+
26
+ @dataclass
27
+ class SpeechtoTextLossConfig(FairseqDataclass):
28
+ zero_infinity: bool = field(
29
+ default=False,
30
+ metadata={"help": "zero inf loss when source length <= target length"},
31
+ )
32
+ sentence_avg: bool = II("optimization.sentence_avg")
33
+ post_process: Optional[str] = field(
34
+ default="sentencepiece",
35
+ metadata={
36
+ "help": "how to post process predictions into words. can be letter, "
37
+ "wordpiece, BPE symbols, etc. "
38
+ "See fairseq.data.data_utils.post_process() for full list of options"
39
+ },
40
+ )
41
+ wer_kenlm_model: Optional[str] = field(
42
+ default=None,
43
+ metadata={
44
+ "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
45
+ },
46
+ )
47
+ wer_lexicon: Optional[str] = field(
48
+ default=None,
49
+ metadata={"help": "lexicon to use with wer_kenlm_model"},
50
+ )
51
+ wer_lm_weight: float = field(
52
+ default=2.0,
53
+ metadata={"help": "lm weight to use with wer_kenlm_model"},
54
+ )
55
+ wer_word_score: float = field(
56
+ default=-1.0,
57
+ metadata={"help": "lm word score to use with wer_kenlm_model"},
58
+ )
59
+
60
+ wer_args: Optional[str] = field(
61
+ default=None,
62
+ metadata={
63
+ "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
64
+ },
65
+ )
66
+
67
+ label_smoothing: float = field(
68
+ default=0.0,
69
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
70
+ )
71
+ report_accuracy: bool = field(
72
+ default=False,
73
+ metadata={"help": "report accuracy metric"},
74
+ )
75
+ ignore_prefix_size: int = field(
76
+ default=0,
77
+ metadata={"help": "Ignore first N tokens"},
78
+ )
79
+ #: bool = II("optimization.sentence_avg")
80
+
81
+ ce_weight: float = field(
82
+ default=1.0,
83
+ metadata={"help": "loss weight for cross entropy"},
84
+ )
85
+ ctc_weight: float = field(
86
+ default=0.0,
87
+ metadata={"help": "loss weiehgt for ctc in ASR"},
88
+ )
89
+
90
+
91
+ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
92
+ if target.dim() == lprobs.dim() - 1:
93
+ target = target.unsqueeze(-1)
94
+ nll_loss = -lprobs.gather(dim=-1, index=target)
95
+ smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
96
+ if ignore_index is not None:
97
+ pad_mask = target.eq(ignore_index)
98
+ nll_loss.masked_fill_(pad_mask, 0.0)
99
+ smooth_loss.masked_fill_(pad_mask, 0.0)
100
+ else:
101
+ nll_loss = nll_loss.squeeze(-1)
102
+ smooth_loss = smooth_loss.squeeze(-1)
103
+ if reduce:
104
+ nll_loss = nll_loss.sum()
105
+ smooth_loss = smooth_loss.sum()
106
+ eps_i = epsilon / (lprobs.size(-1) - 1)
107
+ loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
108
+ return loss, nll_loss
109
+
110
+
111
+ class SpeechtoTextLoss(FairseqCriterion):
112
+ def __init__(
113
+ self,
114
+ cfg: SpeechtoTextLossConfig,
115
+ task: FairseqTask,
116
+ sentence_avg=True,
117
+ label_smoothing=0.1,
118
+ ignore_prefix_size=0,
119
+ report_accuracy=False,
120
+ ce_weight=1.0,
121
+ ctc_weight=0.0,
122
+ ):
123
+
124
+ super().__init__(task)
125
+ self.blank_idx = (
126
+ task.target_dictionary.index(task.blank_symbol)
127
+ if hasattr(task, "blank_symbol")
128
+ else 0
129
+ )
130
+ #print ("self.blank_idx: ", self.blank_idx)
131
+
132
+ self.pad_idx = task.target_dictionary.pad()
133
+ self.eos_idx = task.target_dictionary.eos()
134
+ self.post_process = cfg.post_process
135
+ self.ce_weight = ce_weight
136
+ self.ctc_weight = ctc_weight
137
+
138
+ ## for ce
139
+ self.sentence_avg = sentence_avg
140
+ self.eps = label_smoothing
141
+ self.ignore_prefix_size = ignore_prefix_size
142
+ self.report_accuracy = report_accuracy
143
+
144
+ if cfg.wer_args is not None:
145
+ (
146
+ cfg.wer_kenlm_model,
147
+ cfg.wer_lexicon,
148
+ cfg.wer_lm_weight,
149
+ cfg.wer_word_score,
150
+ ) = eval(cfg.wer_args)
151
+
152
+ if cfg.wer_kenlm_model is not None:
153
+ from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
154
+
155
+ dec_args = Namespace()
156
+ dec_args.nbest = 1
157
+ dec_args.criterion = "ctc"
158
+ dec_args.kenlm_model = cfg.wer_kenlm_model
159
+ dec_args.lexicon = cfg.wer_lexicon
160
+ dec_args.beam = 50
161
+ dec_args.beam_size_token = min(50, len(task.target_dictionary))
162
+ dec_args.beam_threshold = min(50, len(task.target_dictionary))
163
+ dec_args.lm_weight = cfg.wer_lm_weight
164
+ dec_args.word_score = cfg.wer_word_score
165
+ dec_args.unk_weight = -math.inf
166
+ dec_args.sil_weight = 0
167
+
168
+ self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
169
+ else:
170
+ self.w2l_decoder = None
171
+
172
+ self.zero_infinity = cfg.zero_infinity
173
+ #self.sentence_avg = cfg.sentence_avg
174
+
175
+ if self.ce_weight > 0 and self.ctc_weight > 0:
176
+ logger.info("Using cross entropy loss and CTC loss for ASR")
177
+ elif self.ce_weight > 0:
178
+ logger.info("Only using CE loss")
179
+ elif self.ctc_weight > 0:
180
+ logger.info("Only using CTC loss for ASR")
181
+ else:
182
+ logger.info("ERROR")
183
+
184
+ def forward(self, model, sample, reduce=True):
185
+
186
+ if self.ce_weight == 0 and self.ctc_weight > 0:
187
+ sample["only_ctc"] = True
188
+
189
+ net_output_decoder, net_output = model(**sample["net_input"])
190
+
191
+ if self.ce_weight > 0:
192
+ loss_ce, nll_loss_ce = self.compute_loss(model, net_output_decoder, sample, reduce=reduce)
193
+ #print ("loss_ce: ", loss_ce)
194
+ else:
195
+ nll_loss_ce = None
196
+
197
+ if self.ctc_weight > 0:
198
+ loss_ctc, lprobs, input_lengths = self.compute_loss_ctc(model, net_output, sample)
199
+
200
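+ # The total loss is a weighted sum of the attention-decoder cross-entropy and the
+ # encoder-side CTC loss; either term can be disabled by setting its weight to 0.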
+ if self.ce_weight > 0 and self.ctc_weight > 0:
201
+ loss = self.ce_weight * loss_ce + self.ctc_weight * loss_ctc
202
+ elif self.ce_weight > 0:
203
+ loss = loss_ce
204
+ elif self.ctc_weight > 0:
205
+ loss = loss_ctc
206
+ else:
207
+ logger.info("ERROR: must ce_weight > 0 or ctc_weight > 0")
208
+
209
+ ntokens = (
210
+ sample["ntokens"] if "ntokens" in sample else sample["target_lengths"].sum().item()
211
+ )
212
+
213
+ sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
214
+
215
+ logging_output = {
216
+ "loss": loss.item(),
217
+ "ce_loss": loss_ce.item() if self.ce_weight > 0 else 0,
218
+ "ctc_loss": loss_ctc.item() if self.ctc_weight > 0 else 0,
219
+ "nll_loss": nll_loss_ce.item() if nll_loss_ce is not None else 0,
220
+ "ntokens": sample["ntokens"],
221
+ "nsentences": sample["target"].size(0),
222
+ "sample_size": sample_size,
223
+ }
224
+
225
+ if self.ce_weight > 0 and self.report_accuracy:
226
+ n_correct, total = self.compute_accuracy(model, net_output_decoder, sample)
227
+ logging_output["n_correct"] = utils.item(n_correct.item())
228
+ logging_output["total"] = utils.item(total.data)
229
+
230
+ if self.ctc_weight > 0 and not model.training:
231
+ import editdistance
232
+
233
+ with torch.no_grad():
234
+ lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
235
+
236
+ c_err = 0
237
+ c_len = 0
238
+ w_errs = 0
239
+ w_len = 0
240
+ wv_errs = 0
241
+ for lp, t, inp_l in zip(
242
+ lprobs_t,
243
+ sample["target_label"]
244
+ if "target_label" in sample
245
+ else sample["target"],
246
+ input_lengths,
247
+ ):
248
+ lp = lp[:inp_l].unsqueeze(0)
249
+
250
+ decoded = None
251
+ if self.w2l_decoder is not None:
252
+ decoded = self.w2l_decoder.decode(lp)
253
+ if len(decoded) < 1:
254
+ decoded = None
255
+ else:
256
+ decoded = decoded[0]
257
+ if len(decoded) < 1:
258
+ decoded = None
259
+ else:
260
+ decoded = decoded[0]
261
+
262
+ p = (t != self.task.target_dictionary.pad()) & (
263
+ t != self.task.target_dictionary.eos()
264
+ )
265
+ targ = t[p]
266
+ targ_units = self.task.target_dictionary.string(targ)
267
+ targ_units_arr = targ.tolist()
268
+
269
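+ # Greedy CTC decoding for error-rate logging: take the argmax path, collapse
+ # consecutive repeats, then drop blank tokens.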
+ toks = lp.argmax(dim=-1).unique_consecutive()
270
+ pred_units_arr = toks[toks != self.blank_idx].tolist()
271
+
272
+ c_err += editdistance.eval(pred_units_arr, targ_units_arr)
273
+ c_len += len(targ_units_arr)
274
+
275
+ targ_words = post_process(targ_units, self.post_process).split()
276
+
277
+ pred_units = self.task.target_dictionary.string(pred_units_arr)
278
+ pred_words_raw = post_process(pred_units, self.post_process).split()
279
+
280
+ if decoded is not None and "words" in decoded:
281
+ pred_words = decoded["words"]
282
+ w_errs += editdistance.eval(pred_words, targ_words)
283
+ wv_errs += editdistance.eval(pred_words_raw, targ_words)
284
+ else:
285
+ dist = editdistance.eval(pred_words_raw, targ_words)
286
+ w_errs += dist
287
+ wv_errs += dist
288
+
289
+ w_len += len(targ_words)
290
+
291
+ logging_output["wv_errors"] = wv_errs
292
+ logging_output["w_errors"] = w_errs
293
+ logging_output["w_total"] = w_len
294
+ logging_output["c_errors"] = c_err
295
+ logging_output["c_total"] = c_len
296
+
297
+ return loss, sample_size, logging_output
298
+
299
+ def compute_loss_ctc(self, model, net_output, sample):
300
+ lprobs = model.get_normalized_probs_for_ctc(
301
+ net_output, log_probs=True
302
+ ).contiguous() # (T, B, C) from the encoder
303
+
304
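+ # CTC input lengths come from the encoder padding mask; when the batch carries no
+ # mask, every item is assumed to span the full time dimension T.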
+ if net_output["encoder_padding_mask"] is not None:
305
+ non_padding_mask = ~net_output["encoder_padding_mask"][0]
306
+ input_lengths = non_padding_mask.long().sum(-1)
307
+ else:
308
+ input_lengths = lprobs.new_full(
309
+ (lprobs.size(1),), lprobs.size(0), dtype=torch.long
310
+ )
311
+
312
+ pad_mask = (sample["target"] != self.pad_idx) & (
313
+ sample["target"] != self.eos_idx
314
+ )
315
+ targets_flat = sample["target"].masked_select(pad_mask)
316
+ if "target_lengths" in sample:
317
+ target_lengths = sample["target_lengths"]
318
+ else:
319
+ target_lengths = pad_mask.sum(-1)
320
+
321
+ # the target lengths include the EOS token, which pad_mask filtered out of targets_flat above
322
+ target_lengths = target_lengths - 1
323
+
324
+ with torch.backends.cudnn.flags(enabled=False):
325
+ loss_ctc = F.ctc_loss(
326
+ lprobs,
327
+ targets_flat,
328
+ input_lengths,
329
+ target_lengths,
330
+ blank=self.blank_idx,
331
+ reduction="sum",
332
+ zero_infinity=True,
333
+ )
334
+
335
+ return loss_ctc, lprobs, input_lengths
336
+
337
+ ## for ce
338
+ def get_lprobs_and_target(self, model, net_output, sample):
339
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
340
+ target = model.get_targets(sample, net_output)
341
+ if self.ignore_prefix_size > 0:
342
+ if getattr(lprobs, "batch_first", False):
343
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
344
+ target = target[:, self.ignore_prefix_size :].contiguous()
345
+ else:
346
+ lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
347
+ target = target[self.ignore_prefix_size :, :].contiguous()
348
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
349
+
350
+ def compute_loss(self, model, net_output, sample, reduce=True):
351
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
352
+ loss, nll_loss = label_smoothed_nll_loss(
353
+ lprobs,
354
+ target,
355
+ self.eps,
356
+ ignore_index=self.padding_idx,
357
+ reduce=reduce,
358
+ )
359
+ return loss, nll_loss
360
+
361
+ def compute_accuracy(self, model, net_output, sample):
362
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
363
+ mask = target.ne(self.padding_idx)
364
+ n_correct = torch.sum(
365
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
366
+ )
367
+ total = torch.sum(mask)
368
+ return n_correct, total
369
+
370
+
371
+ @staticmethod
372
+ def reduce_metrics(logging_outputs) -> None:
373
+ """Aggregate logging outputs from data parallel training."""
374
+
375
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
376
+ nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
377
+ ce_loss_sum = sum(log.get("ce_loss", 0) for log in logging_outputs)
378
+ ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
379
+ ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
380
+ nsentences = utils.item(
381
+ sum(log.get("nsentences", 0) for log in logging_outputs)
382
+ )
383
+ sample_size = utils.item(
384
+ sum(log.get("sample_size", 0) for log in logging_outputs)
385
+ )
386
+
387
+ metrics.log_scalar(
388
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
389
+ )
390
+
391
+ metrics.log_scalar(
392
+ "ctc_loss", ctc_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
393
+ )
394
+ metrics.log_scalar(
395
+ "ce_loss", ce_loss_sum / ntokens, ntokens, 2, round=3
396
+ )
397
+ metrics.log_scalar(
398
+ "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
399
+ )
400
+ metrics.log_derived(
401
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg, 2)
402
+ )
403
+
404
+ total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
405
+ if total > 0:
406
+ metrics.log_scalar("total", total)
407
+ n_correct = utils.item(
408
+ sum(log.get("n_correct", 0) for log in logging_outputs)
409
+ )
410
+ metrics.log_scalar("n_correct", n_correct)
411
+ metrics.log_derived(
412
+ "accuracy",
413
+ lambda meters: round(
414
+ meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
415
+ )
416
+ if meters["total"].sum > 0
417
+ else float("nan"),
418
+ 2
419
+ )
420
+
421
+ metrics.log_scalar("ntokens", ntokens)
422
+ metrics.log_scalar("nsentences", nsentences)
423
+ if sample_size != ntokens:
424
+ metrics.log_scalar(
425
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
426
+ )
427
+
428
+ c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
429
+ metrics.log_scalar("_c_errors", c_errors)
430
+ c_total = sum(log.get("c_total", 0) for log in logging_outputs)
431
+ metrics.log_scalar("_c_total", c_total)
432
+ w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
433
+ metrics.log_scalar("_w_errors", w_errors)
434
+ wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
435
+ metrics.log_scalar("_wv_errors", wv_errors)
436
+ w_total = sum(log.get("w_total", 0) for log in logging_outputs)
437
+ metrics.log_scalar("_w_total", w_total)
438
+
439
+ if c_total > 0:
440
+ metrics.log_derived(
441
+ "uer",
442
+ lambda meters: safe_round(
443
+ meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
444
+ )
445
+ if meters["_c_total"].sum > 0
446
+ else float("nan"),
447
+ )
448
+ if w_total > 0:
449
+ metrics.log_derived(
450
+ "wer",
451
+ lambda meters: safe_round(
452
+ meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
453
+ )
454
+ if meters["_w_total"].sum > 0
455
+ else float("nan"),
456
+ )
457
+ metrics.log_derived(
458
+ "raw_wer",
459
+ lambda meters: safe_round(
460
+ meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
461
+ )
462
+ if meters["_w_total"].sum > 0
463
+ else float("nan"),
464
+ )
465
+
466
+ @staticmethod
467
+ def logging_outputs_can_be_summed() -> bool:
468
+ """
469
+ Whether the logging outputs returned by `forward` can be summed
470
+ across workers prior to calling `reduce_metrics`. Setting this
471
+ to True will improve distributed training speed.
472
+ """
473
+ return True
artst/criterions/text_pretrain_criterion.py ADDED
@@ -0,0 +1,142 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ from dataclasses import dataclass, field
10
+ from typing import List, Optional
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from fairseq import metrics, utils
15
+ from fairseq.criterions import FairseqCriterion, register_criterion
16
+ from fairseq.dataclass import FairseqDataclass
17
+ from omegaconf import II
18
+
19
+
20
+ @dataclass
21
+ class TextPretrainCriterionConfig(FairseqDataclass):
22
+ sentence_avg: bool = II("optimization.sentence_avg")
23
+ loss_weights: Optional[List[float]] = field(
24
+ default_factory=lambda: [0.1,],
25
+ metadata={"help": "weights for additional loss terms (not first one)"},
26
+ )
27
+ bart_weight: float = field(
28
+ default=1.0,
29
+ metadata={"help": "loss weight for cross entropy"},
30
+ )
31
+
32
+
33
+ class TextPretrainCriterion(FairseqCriterion):
34
+ def __init__(self, task, sentence_avg, bart_weight, loss_weights=None):
35
+ super().__init__(task)
36
+ self.sentence_avg = sentence_avg
37
+ self.loss_weights = loss_weights
38
+ self.bart_weight = bart_weight
39
+
40
+ def forward(self, model, sample, reduce=True):
41
+ """Compute the loss for the given sample.
42
+
43
+ Returns a tuple with three elements:
44
+ 1) the loss
45
+ 2) the sample size, which is used as the denominator for the gradient
46
+ 3) logging outputs to display while training
47
+ """
48
+ net_output, codebook_out, encoder_output = model(**sample["net_input"])
49
+ bart_loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
50
+ sample_size = (
51
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
52
+ )
53
+
54
+ loss = self.bart_weight * bart_loss
55
+ logging_output = {
56
+ "loss": loss.item(),
57
+ "ntokens": sample["ntokens"],
58
+ "nsentences": sample["target"].size(0),
59
+ "bart_loss": bart_loss.item(),
60
+ "sample_size": sample_size,
61
+ }
62
+
63
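+ # If the model quantizes encoder outputs, add its auxiliary losses (e.g. the codebook
+ # diversity penalty), each scaled by its loss weight and by sample_size so it reduces
+ # the same way as the main loss.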
+ if "prob_perplexity" in codebook_out:
64
+ assert hasattr(model, "get_extra_losses")
65
+ extra_losses, names = model.get_extra_losses(codebook_out)
66
+ if torch.is_tensor(extra_losses):
67
+ extra_losses = [extra_losses]
68
+ names = [names]
69
+ if len(self.loss_weights) == 1 and len(extra_losses) != 1:
70
+ self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
71
+ if len(self.loss_weights) > len(extra_losses):
72
+ modified_loss_weight = self.loss_weights[len(extra_losses):]
73
+ else:
74
+ modified_loss_weight = self.loss_weights
75
+
76
+ # assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
77
+ for p, n, coef in zip(extra_losses, names, modified_loss_weight):
78
+ # print(n + str(coef))
79
+ if coef != 0 and p is not None:
80
+ p = coef * p.float() * sample_size
81
+ loss += p
82
+ logging_output[f"loss_{n}"] = p.item()
83
+
84
+ if 'loss_prob_perplexity' in logging_output:
85
+ logging_output['code_perplexity'] = codebook_out['code_perplexity'].item()
86
+
87
+ return loss, sample_size, logging_output
88
+
89
+ def compute_loss(self, model, net_output, sample, reduce=True):
90
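+ # Plain (unsmoothed) negative log-likelihood over the target tokens, summed so that
+ # reduce_metrics can normalize it by sample_size.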
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
91
+ lprobs = lprobs.view(-1, lprobs.size(-1))
92
+ target = model.get_targets(sample, net_output).view(-1)
93
+ loss = F.nll_loss(
94
+ lprobs,
95
+ target,
96
+ ignore_index=self.padding_idx,
97
+ reduction="sum" if reduce else "none",
98
+ )
99
+ return loss, loss
100
+
101
+ @staticmethod
102
+ def reduce_metrics(logging_outputs) -> None:
103
+ """Aggregate logging outputs from data parallel training."""
104
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
105
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
106
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
107
+ bart_loss_sum = sum(log.get("bart_loss", 0) for log in logging_outputs)
108
+
109
+ # we divide by log(2) to convert the loss from base e to base 2
110
+ metrics.log_scalar(
111
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
112
+ )
113
+ metrics.log_scalar(
114
+ "bart_loss", bart_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
115
+ )
116
+ if sample_size != ntokens:
117
+ metrics.log_scalar(
118
+ "nll_loss", bart_loss_sum / ntokens / math.log(2), ntokens, round=3
119
+ )
120
+ metrics.log_derived(
121
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
122
+ )
123
+ else:
124
+ metrics.log_derived(
125
+ "ppl", lambda meters: utils.get_perplexity(meters["bart_loss"].avg)
126
+ )
127
+
128
+ if "loss_prob_perplexity" in logging_outputs[0].keys():
129
+ val = sum(log["loss_prob_perplexity"] for log in logging_outputs)
130
+ metrics.log_scalar("loss_prob_perplexity", val / sample_size / math.log(2), round=3)
131
+ if "code_perplexity" in logging_outputs[0].keys():
132
+ val = sum(log["code_perplexity"] for log in logging_outputs)
133
+ metrics.log_scalar("code_perplexity", val / len(logging_outputs), round=3)
134
+
135
+ @staticmethod
136
+ def logging_outputs_can_be_summed() -> bool:
137
+ """
138
+ Whether the logging outputs returned by `forward` can be summed
139
+ across workers prior to calling `reduce_metrics`. Setting this
140
+ to True will improve distributed training speed.
141
+ """
142
+ return True
artst/criterions/text_to_speech_loss.py ADDED
@@ -0,0 +1,425 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ from dataclasses import dataclass, field
9
+
10
+ import torch
11
+ from fairseq import metrics, utils
12
+ from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
13
+ from fairseq.criterions import FairseqCriterion, register_criterion
14
+ from fairseq.dataclass import FairseqDataclass
15
+ from artst.models.modules.speech_encoder_prenet import SpeechEncoderPrenet
16
+ from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss
17
+ from omegaconf import II
18
+ from typing import Any
19
+
20
+
21
+ @dataclass
22
+ class TexttoSpeechLossConfig(FairseqDataclass):
23
+ use_masking: bool = field(
24
+ default=True,
25
+ metadata={"help": "Whether to use masking in calculation of loss"},
26
+ )
27
+ use_weighted_masking: bool = field(
28
+ default=False,
29
+ metadata={"help": "Whether to use weighted masking in calculation of loss"},
30
+ )
31
+ loss_type: str = field(
32
+ default="L1",
33
+ metadata={"help": "How to calc loss"},
34
+ )
35
+ bce_pos_weight: float = field(
36
+ default=5.0,
37
+ metadata={"help": "Positive sample weight in BCE calculation (only for use-masking=True)"},
38
+ )
39
+ bce_loss_lambda: float = field(
40
+ default=1.0,
41
+ metadata={"help": "Lambda in bce loss"},
42
+ )
43
+ use_guided_attn_loss: bool = field(
44
+ default=False,
45
+ metadata={"help": "Whether to use guided attention loss"},
46
+ )
47
+ guided_attn_loss_sigma: float = field(
48
+ default=0.4,
49
+ metadata={"help": "Sigma in guided attention loss"},
50
+ )
51
+ guided_attn_loss_lambda: float = field(
52
+ default=10.0,
53
+ metadata={"help": "Lambda in guided attention loss"},
54
+ )
55
+ num_layers_applied_guided_attn: int = field(
56
+ default=2,
57
+ metadata={"help": "Number of layers to be applied guided attention loss, if set -1, all of the layers will be applied."},
58
+ )
59
+ num_heads_applied_guided_attn: int = field(
60
+ default=2,
61
+ metadata={"help": "Number of heads in each layer to be applied guided attention loss, if set -1, all of the heads will be applied."},
62
+ )
63
+ modules_applied_guided_attn: Any = field(
64
+ default=("encoder-decoder",),
65
+ metadata={"help": "Module name list to be applied guided attention loss"},
66
+ )
67
+ sentence_avg: bool = II("optimization.sentence_avg")
68
+
69
+
70
+ class TexttoSpeechLoss(FairseqCriterion):
71
+ def __init__(
72
+ self,
73
+ task,
74
+ sentence_avg,
75
+ use_masking=True,
76
+ use_weighted_masking=False,
77
+ loss_type="L1",
78
+ bce_pos_weight=5.0,
79
+ bce_loss_lambda=1.0,
80
+ use_guided_attn_loss=False,
81
+ guided_attn_loss_sigma=0.4,
82
+ guided_attn_loss_lambda=1.0,
83
+ num_layers_applied_guided_attn=2,
84
+ num_heads_applied_guided_attn=2,
85
+ modules_applied_guided_attn=["encoder-decoder"],
86
+ ):
87
+ super().__init__(task)
88
+ self.sentence_avg = sentence_avg
89
+ self.use_masking = use_masking
90
+ self.use_weighted_masking = use_weighted_masking
91
+ self.loss_type = loss_type
92
+ self.bce_pos_weight = bce_pos_weight
93
+ self.bce_loss_lambda = bce_loss_lambda
94
+ self.use_guided_attn_loss = use_guided_attn_loss
95
+ self.guided_attn_loss_sigma = guided_attn_loss_sigma
96
+ self.guided_attn_loss_lambda = guided_attn_loss_lambda
97
+ # define loss function
98
+ self.criterion = Tacotron2Loss(
99
+ use_masking=use_masking,
100
+ use_weighted_masking=use_weighted_masking,
101
+ bce_pos_weight=bce_pos_weight,
102
+ )
103
+ if self.use_guided_attn_loss:
104
+ self.num_layers_applied_guided_attn = num_layers_applied_guided_attn
105
+ self.num_heads_applied_guided_attn = num_heads_applied_guided_attn
106
+ self.modules_applied_guided_attn = modules_applied_guided_attn
107
+ if self.use_guided_attn_loss:
108
+ self.attn_criterion = GuidedMultiHeadAttentionLoss(
109
+ sigma=guided_attn_loss_sigma,
110
+ alpha=guided_attn_loss_lambda,
111
+ )
112
+
113
+ def forward(self, model, sample):
114
+ """Compute the loss for the given sample.
115
+
116
+ Returns a tuple with three elements:
117
+ 1) the loss
118
+ 2) the sample size, which is used as the denominator for the gradient
119
+ 3) logging outputs to display while training
120
+ """
121
+ net_output = model(**sample["net_input"])
122
+ loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss = self.compute_loss(model, net_output, sample)
123
+ # sample_size = (
124
+ # sample["target"].size(0) if self.sentence_avg else sample["nframes"]
125
+ # )
126
+ sample_size = 1
127
+ logging_output = {
128
+ "loss": loss.item(),
129
+ "l1_loss": l1_loss.item(),
130
+ "l2_loss": l2_loss.item(),
131
+ "bce_loss": bce_loss.item(),
132
+ "sample_size": 1,
133
+ "ntokens": sample["ntokens"],
134
+ "nsentences": sample["target"].size(0),
135
+ }
136
+
137
+ if enc_dec_attn_loss is not None:
138
+ logging_output['enc_dec_attn_loss'] = enc_dec_attn_loss.item()
139
+
140
+ if hasattr(model, 'text_encoder_prenet'):
141
+ logging_output["encoder_alpha"] = model.text_encoder_prenet.encoder_prenet[-1].alpha.item()
142
+ logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item()
143
+ elif hasattr(model, "speech_encoder_prenet"):
144
+ logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item()
145
+ else:
146
+ if 'task' not in sample:
147
+ logging_output["encoder_alpha"] = model.encoder_prenet.encoder_prenet[-1].alpha.item()
148
+ logging_output["decoder_alpha"] = model.decoder_prenet.decoder_prenet[-1].alpha.item()
149
+
150
+ return loss, sample_size, logging_output
151
+
152
+ def compute_loss(self, model, net_output, sample):
153
+ before_outs, after_outs, logits, attn = net_output
154
+ labels = sample["labels"]
155
+ ys = sample["dec_target"]
156
+ olens = sample["dec_target_lengths"]
157
+ ilens = sample["src_lengths"]
158
+
159
+ # trim the ground truth so its length is a multiple of the reduction factor
160
+ if model.reduction_factor > 1:
161
+ olens_in = olens.new([torch.div(olen, model.reduction_factor, rounding_mode='floor') for olen in olens])
162
+ olens = olens.new([olen - olen % model.reduction_factor for olen in olens])
163
+ max_olen = max(olens)
164
+ ys = ys[:, :max_olen]
165
+ labels = labels[:, :max_olen]
166
+ labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0) # make sure at least one frame has 1
167
+ # labels[:, -1] = 1.0
168
+ else:
169
+ olens_in = olens
170
+
171
+ # calculate loss values
172
+ l1_loss, l2_loss, bce_loss = self.criterion(
173
+ after_outs, before_outs, logits, ys, labels, olens
174
+ )
175
+
176
+ # l1_loss = l1_loss / ys.size(2)
177
+ # l2_loss = l2_loss / ys.size(2)
178
+
179
+ if self.loss_type == "L1":
180
+ loss = l1_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l1_loss
181
+ elif self.loss_type == "L2":
182
+ loss = l2_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l2_loss
183
+ elif self.loss_type == "L1+L2":
184
+ loss = l1_loss + l2_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l1_loss + l2_loss
185
+ else:
186
+ raise ValueError("unknown --loss-type " + self.loss_type)
187
+
188
+ # calculate guided attention loss
189
+ enc_dec_attn_loss = None
190
+ if self.use_guided_attn_loss:
191
+ # calculate the input lengths of encoder, which is determined by encoder prenet
192
+ if hasattr(model, 'encoder_reduction_factor') and model.encoder_reduction_factor > 1:
193
+ ilens_in = ilens.new([ilen // model.encoder_reduction_factor for ilen in ilens])
194
+ else:
195
+ ilens_in = ilens
196
+ # for the speech-to-speech task, the speech encoder prenet may change the source lengths
197
+ if "task_name" in sample and sample["task_name"] == "s2s":
198
+ m = None
199
+ if hasattr(model, 'encoder_prenet'):
200
+ m = model.encoder_prenet
201
+ elif hasattr(model, 'speech_encoder_prenet'):
202
+ m = model.speech_encoder_prenet
203
+ if m is not None and isinstance(m, SpeechEncoderPrenet):
204
+ ilens_in = m.get_src_lengths(ilens_in)
205
+ # calculate for encoder-decoder
206
+ if "encoder-decoder" in self.modules_applied_guided_attn:
207
+ attn = [att_l[:, : self.num_heads_applied_guided_attn] for att_l in attn]
208
+ att_ws = torch.cat(attn, dim=1) # (B, H*L, T_out, T_in)
209
+ enc_dec_attn_loss = self.attn_criterion(att_ws, ilens_in, olens_in)
210
+ loss = loss + enc_dec_attn_loss
211
+
212
+ return loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss
213
+
214
+ @classmethod
215
+ def reduce_metrics(cls, logging_outputs) -> None:
216
+ """Aggregate logging outputs from data parallel training."""
217
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
218
+ l1_loss_sum = sum(log.get("l1_loss", 0) for log in logging_outputs)
219
+ l2_loss_sum = sum(log.get("l2_loss", 0) for log in logging_outputs)
220
+ bce_loss_sum = sum(log.get("bce_loss", 0) for log in logging_outputs)
221
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in logging_outputs))
222
+ metrics.log_scalar(
223
+ "loss", loss_sum / sample_size, sample_size, 1, round=5
224
+ )
225
+ encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in logging_outputs)
226
+ decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in logging_outputs)
227
+ ngpu = sum(log.get("ngpu", 0) for log in logging_outputs)
228
+
229
+ metrics.log_scalar(
230
+ "l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
231
+ )
232
+ metrics.log_scalar(
233
+ "l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
234
+ )
235
+ metrics.log_scalar(
236
+ "bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
237
+ )
238
+ metrics.log_scalar(
239
+ "encoder_alpha", encoder_alpha_sum / sample_size, sample_size, round=5
240
+ )
241
+ metrics.log_scalar(
242
+ "decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
243
+ )
244
+
245
+ if "enc_dec_attn_loss" in logging_outputs[0]:
246
+ enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in logging_outputs)
247
+ metrics.log_scalar(
248
+ "enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
249
+ )
250
+
251
+
252
+ @staticmethod
253
+ def logging_outputs_can_be_summed() -> bool:
254
+ """
255
+ Whether the logging outputs returned by `forward` can be summed
256
+ across workers prior to calling `reduce_metrics`. Setting this
257
+ to True will improve distributed training speed.
258
+ """
259
+ return True
260
+
261
+ class Tacotron2Loss(torch.nn.Module):
262
+ """Loss function module for Tacotron2."""
263
+
264
+ def __init__(
265
+ self, use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0
266
+ ):
267
+ """Initialize Tactoron2 loss module.
268
+
269
+ Args:
270
+ use_masking (bool): Whether to apply masking
271
+ for padded part in loss calculation.
272
+ use_weighted_masking (bool):
273
+ Whether to apply weighted masking in loss calculation.
274
+ bce_pos_weight (float): Weight of positive sample of stop token.
275
+
276
+ """
277
+ super(Tacotron2Loss, self).__init__()
278
+ assert (use_masking != use_weighted_masking) or not use_masking
279
+ self.use_masking = use_masking
280
+ self.use_weighted_masking = use_weighted_masking
281
+
282
+ # define criterions
283
+ # reduction = "none" if self.use_weighted_masking else "sum"
284
+ reduction = "none" if self.use_weighted_masking else "mean"
285
+ self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
286
+ self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
287
+ self.bce_criterion = torch.nn.BCEWithLogitsLoss(
288
+ reduction=reduction, pos_weight=torch.tensor(bce_pos_weight)
289
+ )
290
+
291
+ # NOTE(kan-bayashi): register pre hook function for the compatibility
292
+ self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
293
+
294
+ def forward(self, after_outs, before_outs, logits, ys, labels, olens):
295
+ """Calculate forward propagation.
296
+
297
+ Args:
298
+ after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
299
+ before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
300
+ logits (Tensor): Batch of stop logits (B, Lmax).
301
+ ys (Tensor): Batch of padded target features (B, Lmax, odim).
302
+ labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
303
+ olens (LongTensor): Batch of the lengths of each target (B,).
304
+
305
+ Returns:
306
+ Tensor: L1 loss value.
307
+ Tensor: Mean square error loss value.
308
+ Tensor: Binary cross entropy loss value.
309
+
310
+ """
311
+ # make mask and apply it
312
+ if self.use_masking:
313
+ masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
314
+ ys = ys.masked_select(masks)
315
+ after_outs = after_outs.masked_select(masks)
316
+ before_outs = before_outs.masked_select(masks)
317
+ labels = labels.masked_select(masks[:, :, 0])
318
+ logits = logits.masked_select(masks[:, :, 0])
319
+
320
+ # calculate loss
321
+ l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(before_outs, ys)
322
+ mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(
323
+ before_outs, ys
324
+ )
325
+ bce_loss = self.bce_criterion(logits, labels)
326
+
327
+ # make weighted mask and apply it
328
+ if self.use_weighted_masking:
329
+ masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
330
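+ # Each valid frame is weighted by 1/olen of its utterance so short and long targets
+ # contribute equally; out_weights/logit_weights further divide by batch size (and
+ # feature dimension) because the criterions were built with reduction="none".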
+ weights = masks.float() / masks.sum(dim=1, keepdim=True).float()
331
+ out_weights = weights.div(ys.size(0) * ys.size(2))
332
+ logit_weights = weights.div(ys.size(0))
333
+
334
+ # apply weight
335
+ l1_loss = l1_loss.mul(out_weights).masked_select(masks).sum()
336
+ mse_loss = mse_loss.mul(out_weights).masked_select(masks).sum()
337
+ bce_loss = (
338
+ bce_loss.mul(logit_weights.squeeze(-1))
339
+ .masked_select(masks.squeeze(-1))
340
+ .sum()
341
+ )
342
+
343
+ return l1_loss, mse_loss, bce_loss
344
+
345
+ def _load_state_dict_pre_hook(
346
+ self,
347
+ state_dict,
348
+ prefix,
349
+ local_metadata,
350
+ strict,
351
+ missing_keys,
352
+ unexpected_keys,
353
+ error_msgs,
354
+ ):
355
+ """Apply pre hook fucntion before loading state dict.
356
+
357
+ From v.0.6.1 `bce_criterion.pos_weight` param is registered as a parameter but
358
+ old models do not include it and as a result, it causes missing key error when
359
+ loading old model parameters. This function solves the issue by adding the param to the
360
+ state dict before loading as a pre hook function
361
+ of the `load_state_dict` method.
362
+
363
+ """
364
+ key = prefix + "bce_criterion.pos_weight"
365
+ if key not in state_dict:
366
+ state_dict[key] = self.bce_criterion.pos_weight
367
+
368
+ class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
369
+ """Guided attention loss function module for multi head attention.
370
+ Args:
371
+ sigma (float, optional): Standard deviation to control
372
+ how close the attention is to the diagonal.
373
+ alpha (float, optional): Scaling coefficient (lambda).
374
+ reset_always (bool, optional): Whether to always reset masks.
375
+ """
376
+
377
+ def forward(self, att_ws, ilens, olens):
378
+ """Calculate forward propagation.
379
+ Args:
380
+ att_ws (Tensor):
381
+ Batch of multi head attention weights (B, H, T_max_out, T_max_in).
382
+ ilens (LongTensor): Batch of input lengths (B,).
383
+ olens (LongTensor): Batch of output lengths (B,).
384
+ Returns:
385
+ Tensor: Guided attention loss value.
386
+ """
387
+ if self.guided_attn_masks is None:
388
+ self.guided_attn_masks = (
389
+ self._make_guided_attention_masks(ilens, olens)
390
+ .to(att_ws.device)
391
+ .unsqueeze(1)
392
+ )
393
+ if self.masks is None:
394
+ self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
395
+ losses = self.guided_attn_masks * att_ws
396
+ loss = torch.mean(losses.masked_select(self.masks))
397
+ if self.reset_always:
398
+ self._reset_masks()
399
+
400
+ return self.alpha * loss
401
+
402
+ def _make_guided_attention_masks(self, ilens, olens):
403
+ n_batches = len(ilens)
404
+ max_ilen = max(ilens)
405
+ max_olen = max(olens)
406
+ guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen), device=olens.device)
407
+ for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
408
+ guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(
409
+ ilen, olen, self.sigma
410
+ )
411
+ return guided_attn_masks
412
+
413
+ @staticmethod
414
+ def _make_guided_attention_mask(ilen, olen, sigma):
415
+ grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device), torch.arange(ilen, device=olen.device))
416
+ grid_x, grid_y = grid_x.float(), grid_y.float()
417
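+ # Guided-attention penalty: W[t_out, t_in] = 1 - exp(-((t_in/ilen - t_out/olen)^2) / (2*sigma^2)),
+ # close to 0 on the diagonal and approaching 1 as attention drifts away from it.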
+ return 1.0 - torch.exp(
418
+ -((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2))
419
+ )
420
+
421
+ @staticmethod
422
+ def _make_masks(ilens, olens):
423
+ in_masks = make_non_pad_mask(ilens).to(ilens.device) # (B, T_in)
424
+ out_masks = make_non_pad_mask(olens).to(olens.device) # (B, T_out)
425
+ return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in)
artst/data/__init__.py ADDED
File without changes
artst/data/multitask_dataset.py ADDED
@@ -0,0 +1,263 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import bisect
9
+
10
+ import logging
11
+ import numpy as np
12
+ from torch.utils.data.dataloader import default_collate
13
+ from fairseq.data import data_utils
14
+
15
+ from fairseq.data.fairseq_dataset import FairseqDataset
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ class MultitaskDataset(FairseqDataset):
20
+ @staticmethod
21
+ def cumsum(sequence):
22
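+ # Running totals of the per-dataset lengths; bisect over this list maps a flat index
+ # back to (dataset_idx, sample_idx) in _get_dataset_and_sample_index.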
+ r, s = [], 0
23
+ for e in sequence:
24
+ curr_len = len(e)
25
+ r.append(curr_len + s)
26
+ s += curr_len
27
+ return r
28
+
29
+ def __init__(self, datasets, sample_ratios=1, batch_ratio=None):
30
+ super(MultitaskDataset, self).__init__()
31
+ assert len(datasets) > 0, "datasets should not be an empty iterable"
32
+ self.datasets = list(datasets)
33
+ if isinstance(sample_ratios, int):
34
+ sample_ratios = [sample_ratios] * len(self.datasets)
35
+ if batch_ratio is not None:
36
+ logger.info('batch ratio is ' + str(batch_ratio))
37
+ self.batch_ratio = batch_ratio
38
+ else:
39
+ self.batch_ratio = None
40
+ else:
41
+ logger.info('set sample ratio to ' + str(sample_ratios))
42
+ if batch_ratio is not None:
43
+ logger.info('batch ratio is ' + str(batch_ratio))
44
+ self.batch_ratio = batch_ratio
45
+ else:
46
+ self.batch_ratio = None
47
+ self.sample_ratios = sample_ratios
48
+ self._ordered_indices = None
49
+ self._update_size()
50
+
51
+ def __len__(self):
52
+ return self.cumulative_sizes[-1]
53
+
54
+ def __getitem__(self, idx):
55
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
56
+ sample = self.datasets[dataset_idx][sample_idx]
57
+ if isinstance(sample, dict):
58
+ sample["dataset_idx"] = dataset_idx
59
+ else:
60
+ sample = sample + (dataset_idx,)
61
+ return sample
62
+
63
+ def _update_size(self):
64
+ self.cumulative_sizes = self.cumsum(self.datasets)
65
+ self.real_sizes = [len(d) for d in self.datasets]
66
+
67
+ def _get_dataset_and_sample_index(self, idx: int):
68
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
69
+ if dataset_idx == 0:
70
+ sample_idx = idx
71
+ else:
72
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
73
+ sample_idx = sample_idx % self.real_sizes[dataset_idx]
74
+ return dataset_idx, sample_idx
75
+
76
+ def collater(self, samples, **extra_args):
77
+ # For now this only supports datasets whose underlying collater implementations match
78
+ if samples is not None and len(samples) > 0:
79
+ if isinstance(samples[0], dict):
80
+ dataset_idx = samples[0]["dataset_idx"]
81
+ else:
82
+ dataset_idx = samples[0][-1]
83
+ samples = [sample[:-1] for sample in samples]
84
+ else:
85
+ dataset_idx = 0
86
+
87
+ if hasattr(self.datasets[dataset_idx], "collater"):
88
+ return self.datasets[dataset_idx].collater(samples, **extra_args)
89
+ else:
90
+ return default_collate(samples, **extra_args)
91
+
92
+ def size(self, idx: int):
93
+ """
94
+ Return an example's size as a float or tuple.
95
+ """
96
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
97
+ return self.datasets[dataset_idx].size(sample_idx)
98
+
99
+ def num_tokens(self, index: int):
100
+ return np.max(self.size(index))
101
+
102
+ def attr(self, attr: str, index: int):
103
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
104
+ return getattr(self.datasets[dataset_idx], attr, None)
105
+
106
+ @property
107
+ def sizes(self):
108
+ _dataset_sizes = []
109
+ for ds in self.datasets:
110
+ if isinstance(ds.sizes, np.ndarray):
111
+ _dataset_sizes.append(ds.sizes)
112
+ else:
113
+ # Only support underlying dataset with single size array.
114
+ assert isinstance(ds.sizes, list)
115
+ _dataset_sizes.append(ds.sizes[0])
116
+ return np.concatenate(_dataset_sizes)
117
+
118
+ @property
119
+ def supports_prefetch(self):
120
+ return all(d.supports_prefetch for d in self.datasets)
121
+
122
+ def ordered_indices(self):
123
+ # ordered_indices = []
124
+ # for i, dataset in enumerate(self.datasets):
125
+ # indice = dataset.ordered_indices()
126
+ # ordered_indices.append(indice)
127
+ if self._ordered_indices is None:
128
+ # Call the underlying dataset's ordered_indices() here, so that we
129
+ # get the same random ordering as we would have from using the
130
+ # underlying sub-datasets directly.
131
+ self._ordered_indices = [
132
+ dataset.ordered_indices()
133
+ for dataset in self.datasets
134
+ ]
135
+ return np.arange(len(self))
136
+
137
+ def prefetch(self, indices):
138
+ frm = 0
139
+ for to, ds in zip(self.cumulative_sizes, self.datasets):
140
+ real_size = len(ds)
141
+ if getattr(ds, "supports_prefetch", False):
142
+ ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
143
+ frm = to
144
+
145
+ def batch_by_size(
146
+ self,
147
+ indices,
148
+ max_tokens=None,
149
+ max_sentences=None,
150
+ required_batch_size_multiple=1,
151
+ ):
152
+ if not hasattr(self, "max_tokens"):
153
+ self.max_tokens = max_tokens
154
+ if not hasattr(self, "max_sentences"):
155
+ self.max_sentences = max_sentences
156
+ if not hasattr(self, "required_batch_size_multiple"):
157
+ self.required_batch_size_multiple = required_batch_size_multiple
158
+ batch_samplers = []
159
+ for i, dataset in enumerate(self.datasets):
160
+ batch_sampler = dataset.batch_by_size(
161
+ self._ordered_indices[i],
162
+ max_tokens=max_tokens if self.batch_ratio is None else max_tokens * self.batch_ratio[i],
163
+ max_sentences=max_sentences,
164
+ required_batch_size_multiple=required_batch_size_multiple,
165
+ )
166
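+ # Indices returned by each sub-dataset are local to it; shift them by the cumulative
+ # size of the preceding datasets so they address this concatenated dataset.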
+ if i > 0:
167
+ for batch in batch_sampler:
168
+ batch += self.cumulative_sizes[i - 1]
169
+ if self.sample_ratios[i] != 1.0:
170
+ batch_sampler = np.array(batch_sampler)
171
+ batch_sampler = np.random.choice(batch_sampler, int(len(batch_sampler) * self.sample_ratios[i]))
172
+ batch_sampler = list(batch_sampler)
173
+ logger.info('Adjust batch by ratio ' + str(self.sample_ratios[i]) + ' and the number of batch is ' + str(int(len(batch_sampler))) + ' for dataset ' + str(i))
174
+ batch_samplers.extend(batch_sampler)
175
+ return batch_samplers
176
+
177
+ def filter_indices_by_size(self, indices, max_positions):
178
+ """
179
+ Filter each sub-dataset independently, then update the round robin to work
180
+ on the filtered sub-datasets.
181
+ """
182
+ if not hasattr(self, "max_positions"):
183
+ self.max_positions = max_positions
184
+ ignored_some = False
185
+ for i in range(len(self.datasets)):
186
+ # ignored = []
187
+ self._ordered_indices[i], ignored = self.datasets[i].filter_indices_by_size(
188
+ self._ordered_indices[i], self.max_positions[i]
189
+ )
190
+ if len(ignored) > 0:
191
+ ignored_some = True
192
+ logger.warning(
193
+ f"{len(ignored)} samples from {i} have invalid sizes and will be skipped, "
194
+ f"max_positions={self.max_positions[i]}, first few sample ids={ignored[:10]}"
195
+ )
196
+
197
+ logger.info('update dataset size')
198
+ self._update_size()
199
+
200
+ # Since we are modifying in place the _ordered_indices,
201
+ # it's not possible anymore to return valid ignored indices.
202
+ # Hopefully the extra debug information print above should be enough to debug.
203
+ # Ideally we would receive ignore_invalid_inputs so that we could have
204
+ # a proper error message.
205
+ return (np.arange(len(self)), [0] if ignored_some else [])
206
+
207
+ @property
208
+ def can_reuse_epoch_itr_across_epochs(self):
209
+ return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
210
+
211
+ def set_epoch(self, epoch):
212
+ super().set_epoch(epoch)
213
+ for ds in self.datasets:
214
+ if hasattr(ds, "set_epoch"):
215
+ ds.set_epoch(epoch)
216
+
217
+ def shuffle_batches(self, batches, seed):
218
+ logger.info("shuffle batches")
219
+ new_batches_fromlist = []
220
+ new_batches_notlist = []
221
+ new_batches = []
222
+ with data_utils.numpy_seed(seed):
223
+ np.random.shuffle(batches)
224
+ for batch in batches:
225
+ if isinstance(batch, list):
226
+ # np.random.shuffle(batch)
227
+ new_batches_fromlist.append(batch)
228
+ else:
229
+ new_batches_notlist.append(batch)
230
+ logger.info("Get " + str(len(new_batches_fromlist)) + " chunk from speech sides")
231
+ logger.info("Get " + str(sum([len(batch_list) for batch_list in new_batches_fromlist])) + " batches from speech sides")
232
+ logger.info("Get " + str(len(new_batches_notlist)) + " batches from text sides")
233
+ if len(new_batches_fromlist) == 0:
234
+ return new_batches_notlist
235
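+ # Interleave the single (text-side) batches between the speech-side chunks at a fixed
+ # ratio, so both modalities stay mixed after the global shuffle above.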
+ st_ratio = int(len(new_batches_notlist) / len(new_batches_fromlist))
236
+ logger.info("Get st_ratio " + str(st_ratio))
237
+ last_idx = 0
238
+ for i in range(len(new_batches_fromlist)):
239
+ if i == len(new_batches_fromlist) - 1:
240
+ new_batches_fromlist[i].extend(new_batches_notlist[last_idx:])
241
+ else:
242
+ new_batches_fromlist[i].extend(new_batches_notlist[last_idx : last_idx + st_ratio])
243
+ np.random.shuffle(new_batches_fromlist[i])
244
+ new_batches.extend(new_batches_fromlist[i])
245
+ last_idx = last_idx + st_ratio
246
+ logger.info("Finish shuffle")
247
+ return new_batches
248
+
249
+ def reset_batch_sampler(self):
250
+ logger.info("reset batch sampler")
251
+ self._ordered_indices = [
252
+ self.datasets[i].ordered_indices()
253
+ for i in range(len(self.datasets))
254
+ ]
255
+ self.filter_indices_by_size(None, None)
256
+
257
+ batch_samplers = self.batch_by_size(
258
+ None,
259
+ self.max_tokens,
260
+ self.max_sentences,
261
+ self.required_batch_size_multiple
262
+ )
263
+ return batch_samplers
artst/data/speech_dataset.py ADDED
@@ -0,0 +1,475 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ import itertools
10
+ import logging
11
+ import os
12
+ import sys
13
+ from typing import Any, List, Optional, Union
14
+
15
+ import numpy as np
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import librosa
20
+ from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform
21
+ from fairseq.data import data_utils
22
+ from fairseq.data.fairseq_dataset import FairseqDataset
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ def _collate_frames(
27
+ frames: List[torch.Tensor], is_audio_input: bool = False
28
+ ):
29
+ """
30
+ Convert a list of 2D frames into a padded 3D tensor
31
+ Args:
32
+ frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
33
+ length of i-th frame and f_dim is static dimension of features
34
+ Returns:
35
+ 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
36
+ """
37
+ max_len = max(frame.size(0) for frame in frames)
38
+ if is_audio_input:
39
+ out = frames[0].new_zeros((len(frames), max_len))
40
+ else:
41
+ out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
42
+ for i, v in enumerate(frames):
43
+ out[i, : v.size(0)] = v
44
+ return out
45
+
46
+ def add_first_frame_and_remove_last_frame(ys):
47
+ ys_in = torch.cat(
48
+ [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1
49
+ )
50
+ return ys_in
51
+
52
+ def load_audio(manifest_path, max_keep, min_keep):
53
+ n_long, n_short = 0, 0
54
+ names, inds, sizes, spk_embeds = [], [], [], []
55
+ with open(manifest_path) as f:
56
+ root = f.readline().strip()
57
+ for ind, line in enumerate(f):
58
+ items = line.strip().split("\t")
59
+ assert len(items) == 3, line
60
+ sz = int(items[1])
61
+ if min_keep is not None and sz < min_keep:
62
+ n_short += 1
63
+ elif max_keep is not None and sz > max_keep:
64
+ n_long += 1
65
+ else:
66
+ names.append(items[0])
67
+ spk_embeds.append(items[2])
68
+ inds.append(ind)
69
+ sizes.append(sz)
70
+ tot = ind + 1
71
+ logger.info(
72
+ (
73
+ f"max_keep={max_keep}, min_keep={min_keep}, "
74
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
75
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
76
+ )
77
+ )
78
+ return root, names, inds, tot, sizes, spk_embeds
79
+
80
+
81
+ def load_label(label_path, inds, tot):
82
+ with open(label_path) as f:
83
+ labels = [line.rstrip() for line in f]
84
+ assert (
85
+ len(labels) == tot
86
+ ), f"number of labels does not match ({len(labels)} != {tot})"
87
+ labels = [labels[i] for i in inds]
88
+ return labels
89
+
90
+
91
+ def load_label_offset(label_path, inds, tot):
92
+ with open(label_path) as f:
93
+ code_lengths = [len(line.encode("utf-8")) for line in f]
94
+ assert (
95
+ len(code_lengths) == tot
96
+ ), f"number of labels does not match ({len(code_lengths)} != {tot})"
97
+ offsets = list(itertools.accumulate([0] + code_lengths))
98
+ offsets = [(offsets[i], offsets[i + 1]) for i in inds]
99
+ return offsets
100
+
101
+
102
+ def verify_label_lengths(
103
+ audio_sizes,
104
+ audio_rate,
105
+ label_path,
106
+ label_rate,
107
+ inds,
108
+ tot,
109
+ tol=0.1, # tolerance in seconds
110
+ ):
111
+ if label_rate < 0:
112
+ logger.info(f"{label_path} is sequence label. skipped")
113
+ return
114
+
115
+ with open(label_path) as f:
116
+ lengths = [len(line.rstrip().split()) for line in f]
117
+ assert len(lengths) == tot
118
+ lengths = [lengths[i] for i in inds]
119
+ num_invalid = 0
120
+ for i, ind in enumerate(inds):
121
+ dur_from_audio = audio_sizes[i] / audio_rate
122
+ dur_from_label = lengths[i] / label_rate
123
+ if abs(dur_from_audio - dur_from_label) > tol:
124
+ logger.warning(
125
+ (
126
+ f"audio and label duration differ too much "
127
+ f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
128
+ f"in line {ind+1} of {label_path}. Check if `label_rate` "
129
+ f"is correctly set (currently {label_rate}). "
130
+ f"num. of samples = {audio_sizes[i]}; "
131
+ f"label length = {lengths[i]}"
132
+ )
133
+ )
134
+ num_invalid += 1
135
+ if num_invalid > 0:
136
+ logger.warning(
137
+ f"total {num_invalid} (audio, label) pairs with mismatched lengths"
138
+ )
139
+
140
+
141
+ def logmelfilterbank(
142
+ audio,
143
+ sampling_rate,
144
+ fft_size=1024,
145
+ hop_size=256,
146
+ win_length=None,
147
+ window="hann",
148
+ num_mels=80,
149
+ fmin=80,
150
+ fmax=7600,
151
+ eps=1e-10,
152
+ ):
153
+ """Compute log-Mel filterbank feature.
154
+ (https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py)
155
+
156
+ Args:
157
+ audio (ndarray): Audio signal (T,).
158
+ sampling_rate (int): Sampling rate.
159
+ fft_size (int): FFT size.
160
+ hop_size (int): Hop size.
161
+ win_length (int): Window length. If set to None, it will be the same as fft_size.
162
+ window (str): Window function type.
163
+ num_mels (int): Number of mel basis.
164
+ fmin (int): Minimum frequency in mel basis calculation.
165
+ fmax (int): Maximum frequency in mel basis calculation.
166
+ eps (float): Epsilon value to avoid inf in log calculation.
167
+
168
+ Returns:
169
+ ndarray: Log Mel filterbank feature (#frames, num_mels).
170
+
171
+ """
172
+ # get amplitude spectrogram
173
+ x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
174
+ win_length=win_length, window=window, pad_mode="reflect")
175
+ spc = np.abs(x_stft).T # (#frames, #bins)
176
+
177
+ # get mel basis
178
+ fmin = 0 if fmin is None else fmin
179
+ fmax = sampling_rate / 2 if fmax is None else fmax
180
+ mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
181
+
182
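+ # Project the magnitude spectrogram onto the mel basis and take log10, clamping at eps
+ # so that silent frames do not produce -inf.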
+ return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
183
+
184
+
185
+ class SpeechPretrainDataset(FairseqDataset):
186
+ def __init__(
187
+ self,
188
+ manifest_path: str,
189
+ sample_rate: float,
190
+ label_paths: List[str],
191
+ label_rates: Union[List[float], float], # -1 for sequence labels
192
+ pad_list: List[str],
193
+ eos_list: List[str],
194
+ label_processors: Optional[List[Any]] = None,
195
+ max_keep_sample_size: Optional[int] = None,
196
+ min_keep_sample_size: Optional[int] = None,
197
+ max_sample_size: Optional[int] = None,
198
+ shuffle: bool = True,
199
+ pad_audio: bool = False,
200
+ normalize: bool = False,
201
+ store_labels: bool = True,
202
+ random_crop: bool = False,
203
+ single_target: bool = False,
204
+ reduction_factor: int = 1,
205
+ ):
206
+ self.audio_root, self.audio_names, inds, tot, self.sizes, self.spk_embeds = load_audio(
207
+ manifest_path, max_keep_sample_size, min_keep_sample_size
208
+ )
209
+ self.sample_rate = sample_rate
210
+ self.shuffle = shuffle
211
+ self.random_crop = random_crop
212
+
213
+ self.num_labels = len(label_paths)
214
+ self.pad_list = pad_list
215
+ self.eos_list = eos_list
216
+ self.label_processors = label_processors
217
+ self.single_target = single_target
218
+ self.label_rates = (
219
+ [label_rates for _ in range(len(label_paths))]
220
+ if isinstance(label_rates, float)
221
+ else label_rates
222
+ )
223
+ self.store_labels = store_labels
224
+ if store_labels:
225
+ self.label_list = [load_label(p, inds, tot) for p in label_paths]
226
+ else:
227
+ self.label_paths = label_paths
228
+ self.label_offsets_list = [
229
+ load_label_offset(p, inds, tot) for p in label_paths
230
+ ]
231
+ assert label_processors is None or len(label_processors) == self.num_labels
232
+ for label_path, label_rate in zip(label_paths, self.label_rates):
233
+ verify_label_lengths(
234
+ self.sizes, sample_rate, label_path, label_rate, inds, tot
235
+ )
236
+
237
+ self.max_sample_size = (
238
+ max_sample_size if max_sample_size is not None else sys.maxsize
239
+ )
240
+ self.pad_audio = pad_audio
241
+ self.normalize = normalize
242
+ self.reduction_factor = reduction_factor
243
+ logger.info(
244
+ f"pad_audio={pad_audio}, random_crop={random_crop}, reduction_factor={reduction_factor}, "
245
+ f"normalize={normalize}, max_sample_size={self.max_sample_size}"
246
+ )
247
+
248
+ def get_audio(self, index):
249
+ import soundfile as sf
250
+
251
+ wav_path = os.path.join(self.audio_root, self.audio_names[index])
252
+ wav, cur_sample_rate = sf.read(wav_path)
253
+ wav = torch.from_numpy(wav).float()
254
+ fbank = logmelfilterbank(
255
+ wav.view(-1).cpu().numpy(), 16000
256
+ )
257
+ fbank = torch.from_numpy(fbank).float()
258
+ wav = self.postprocess(wav, cur_sample_rate)
259
+ return wav, fbank
260
+
261
+ def get_label(self, index, label_idx):
262
+ if self.store_labels:
263
+ label = self.label_list[label_idx][index]
264
+ else:
265
+ with open(self.label_paths[label_idx]) as f:
266
+ offset_s, offset_e = self.label_offsets_list[label_idx][index]
267
+ f.seek(offset_s)
268
+ label = f.read(offset_e - offset_s)
269
+
270
+ if self.label_processors is not None:
271
+ label = self.label_processors[label_idx](label)
272
+ return label
273
+
274
+ def get_labels(self, index):
275
+ return [self.get_label(index, i) for i in range(self.num_labels)]
276
+
277
+ def __getitem__(self, index):
278
+ wav, fbank = self.get_audio(index)
279
+ labels = self.get_labels(index)
280
+ spkembs = get_features_or_waveform(
281
+ os.path.join(self.audio_root, self.spk_embeds[index])
282
+ )
283
+ spkembs = torch.from_numpy(spkembs).float()
284
+ return {"id": index, "source": wav, "target": fbank, "label_list": labels, 'spkembs': spkembs}
285
+
286
+ def __len__(self):
287
+ return len(self.sizes)
288
+
289
+ def crop_to_max_size(self, wav, target_size):
290
+ size = len(wav)
291
+ diff = size - target_size
292
+ if diff <= 0:
293
+ return wav, 0
294
+
295
+ start, end = 0, target_size
296
+ if self.random_crop:
297
+ start = np.random.randint(0, diff + 1)
298
+ end = size - diff + start
299
+ return wav[start:end], start
300
+
301
+ def collater(self, samples):
302
+ # target = max(sizes) -> random_crop not used
303
+ # target = max_sample_size -> random_crop used for long
304
+ samples = [s for s in samples if s["source"] is not None]
305
+ if len(samples) == 0:
306
+ return {}
307
+
308
+ audios = [s["source"] for s in samples]
309
+ audio_sizes = [len(s) for s in audios]
310
+
311
+ fbanks = [s["target"] for s in samples]
312
+ fbank_sizes = [len(s) for s in fbanks]
313
+
314
+ if self.pad_audio:
315
+ audio_size = min(max(audio_sizes), self.max_sample_size)
316
+ else:
317
+ audio_size = min(min(audio_sizes), self.max_sample_size)
318
+ collated_audios, padding_mask, audio_starts = self.collater_audio(
319
+ audios, audio_size
320
+ )
321
+
322
+ collated_fbanks = []
323
+ collated_audios_size = []
324
+ for i in range(len(fbanks)):
325
+ fbank_start = int(audio_starts[i] / (audio_sizes[i] / fbank_sizes[i]))
326
+ fbank_size = int(audio_size / (audio_sizes[i] / fbank_sizes[i]))
327
+ fbank_end = min(fbank_start + fbank_size, fbank_sizes[i])
328
+ collated_fbanks.append(fbanks[i][fbank_start : fbank_end])
329
+ collated_audios_size.append(audio_size)
330
+ collated_fbanks_size = [len(s) for s in collated_fbanks]
331
+ collated_fbanks = _collate_frames(collated_fbanks)
332
+ collated_fbanks_size = torch.tensor(collated_fbanks_size, dtype=torch.long)
333
+
334
+ # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
335
+ if self.reduction_factor > 1:
336
+ collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor]
337
+ collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size])
338
+ else:
339
+ collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size
340
+
341
+ prev_output_tokens = torch.cat(
342
+ [collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1
343
+ )
344
+
345
+ # make labels for stop prediction
346
+ labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1))
347
+ for i, l in enumerate(fbank_sizes):
348
+ labels[i, l - 1 :] = 1.0
349
+
350
+ spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True)
351
+
352
+ targets_by_label = [
353
+ [s["label_list"][i] for s in samples] for i in range(self.num_labels)
354
+ ]
355
+ targets_list, lengths_list, ntokens_list = self.collater_label(
356
+ targets_by_label, audio_size, audio_starts
357
+ )
358
+
359
+ net_input = {
360
+ "source": collated_audios,
361
+ "padding_mask": padding_mask,
362
+ "prev_output_tokens": prev_output_tokens,
363
+ "spkembs": spkembs,
364
+ "tgt_lengths": collated_fbanks_size_in,
365
+ }
366
+
367
+ batch = {
368
+ "id": torch.LongTensor([s["id"] for s in samples]),
369
+ "net_input": net_input,
370
+ "labels": labels,
371
+ "dec_target": collated_fbanks,
372
+ "dec_target_lengths": collated_fbanks_size,
373
+ "src_lengths": collated_audios_size,
374
+ "task_name": 'speech_pretrain',
375
+ }
376
+
377
+ if self.single_target:
378
+ batch["target_lengths"] = lengths_list[0]
379
+ batch["ntokens"] = ntokens_list[0]
380
+ batch["target"] = targets_list[0]
381
+ else:
382
+ batch["target_lengths_list"] = lengths_list
383
+ batch["ntokens_list"] = ntokens_list
384
+ batch["target_list"] = targets_list
385
+ return batch
386
+
387
+ def collater_audio(self, audios, audio_size):
388
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
389
+ padding_mask = (
390
+ torch.BoolTensor(collated_audios.shape).fill_(False)
391
+ # if self.pad_audio else None
392
+ )
393
+ audio_starts = [0 for _ in audios]
394
+ for i, audio in enumerate(audios):
395
+ diff = len(audio) - audio_size
396
+ if diff == 0:
397
+ collated_audios[i] = audio
398
+ elif diff < 0:
399
+ assert self.pad_audio
400
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
401
+ padding_mask[i, diff:] = True
402
+ else:
403
+ collated_audios[i], audio_starts[i] = self.crop_to_max_size(
404
+ audio, audio_size
405
+ )
406
+ return collated_audios, padding_mask, audio_starts
407
+
408
+ def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
409
+ assert label_rate > 0
410
+ s2f = label_rate / self.sample_rate
411
+ frm_starts = [int(round(s * s2f)) for s in audio_starts]
412
+ frm_size = int(round(audio_size * s2f))
413
+ if not self.pad_audio:
414
+ rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
415
+ frm_size = min(frm_size, *rem_size)
416
+ targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
417
+ logger.debug(f"audio_starts={audio_starts}")
418
+ logger.debug(f"frame_starts={frm_starts}")
419
+ logger.debug(f"frame_size={frm_size}")
420
+
421
+ lengths = torch.LongTensor([len(t) for t in targets])
422
+ ntokens = lengths.sum().item()
423
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
424
+ return targets, lengths, ntokens
425
+
426
+ def collater_seq_label(self, targets, pad):
427
+ lengths = torch.LongTensor([len(t) for t in targets])
428
+ ntokens = lengths.sum().item()
429
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
430
+ return targets, lengths, ntokens
431
+
432
+ def collater_label(self, targets_by_label, audio_size, audio_starts):
433
+ targets_list, lengths_list, ntokens_list = [], [], []
434
+ itr = zip(targets_by_label, self.label_rates, self.pad_list)
435
+ for targets, label_rate, pad in itr:
436
+ if label_rate == -1.0:
437
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
438
+ else:
439
+ targets, lengths, ntokens = self.collater_frm_label(
440
+ targets, audio_size, audio_starts, label_rate, pad
441
+ )
442
+ targets_list.append(targets)
443
+ lengths_list.append(lengths)
444
+ ntokens_list.append(ntokens)
445
+ return targets_list, lengths_list, ntokens_list
446
+
447
+ def num_tokens(self, index):
448
+ return self.size(index)
449
+
450
+ def size(self, index):
451
+ if self.pad_audio:
452
+ return self.sizes[index]
453
+ return min(self.sizes[index], self.max_sample_size)
454
+
455
+ def ordered_indices(self):
456
+ if self.shuffle:
457
+ order = [np.random.permutation(len(self))]
458
+ else:
459
+ order = [np.arange(len(self))]
460
+
461
+ order.append(self.sizes)
462
+ return np.lexsort(order)[::-1]
463
+
464
+ def postprocess(self, wav, cur_sample_rate):
465
+ if wav.dim() == 2:
466
+ wav = wav.mean(-1)
467
+ assert wav.dim() == 1, wav.dim()
468
+
469
+ if cur_sample_rate != self.sample_rate:
470
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
471
+
472
+ if self.normalize:
473
+ with torch.no_grad():
474
+ wav = F.layer_norm(wav, wav.shape)
475
+ return wav
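Illustrative note on the collater above (a minimal sketch, not part of the commit): the target filterbanks are thinned by reduction_factor and right-shifted by one all-zero frame to form prev_output_tokens. Tensor names and sizes below are hypothetical.

# Sketch of the reduction-factor thinning and teacher-forcing shift used in collater()
import torch

reduction_factor = 2
collated_fbanks = torch.randn(4, 10, 80)        # (B, Lmax, odim)
fbank_sizes = torch.tensor([10, 8, 7, 5])

# keep every r-th frame: (B, Lmax, odim) -> (B, Lmax // r, odim)
fbanks_in = collated_fbanks[:, reduction_factor - 1 :: reduction_factor]
sizes_in = torch.div(fbank_sizes, reduction_factor, rounding_mode="floor")

# decoder inputs: prepend a zero frame and drop the last frame
prev_output_tokens = torch.cat(
    [fbanks_in.new_zeros((fbanks_in.size(0), 1, fbanks_in.size(2))), fbanks_in[:, :-1]],
    dim=1,
)
assert prev_output_tokens.shape == fbanks_in.shape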
artst/data/speech_to_class_dataset.py ADDED
@@ -0,0 +1,260 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import logging
9
+ import os
10
+ from typing import Any, List, Optional
11
+
12
+ import numpy as np
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from fairseq.data import data_utils, Dictionary
17
+ from fairseq.data.fairseq_dataset import FairseqDataset
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ def load_audio(manifest_path, max_keep, min_keep):
23
+ """manifest tsv: wav_path, wav_nframe, wav_class
24
+
25
+ Args
26
+ manifest_path: str
27
+ max_keep: int
28
+ min_keep: int
29
+
30
+ Return
31
+ root, names, inds, tot, sizes, classes
32
+ """
33
+ n_long, n_short = 0, 0
34
+ names, inds, sizes, classes = [], [], [], []
35
+ with open(manifest_path) as f:
36
+ root = f.readline().strip()
37
+ for ind, line in enumerate(f):
38
+ items = line.strip().split("\t")
39
+ assert len(items) >= 2, line
40
+ sz = int(items[1])
41
+ if min_keep is not None and sz < min_keep:
42
+ n_short += 1
43
+ elif max_keep is not None and sz > max_keep:
44
+ n_long += 1
45
+ else:
46
+ names.append(items[0])
47
+ if len(items) > 2:
48
+ classes.append(items[2])
49
+ inds.append(ind)
50
+ sizes.append(sz)
51
+ tot = ind + 1
52
+ logger.info(
53
+ (
54
+ f"max_keep={max_keep}, min_keep={min_keep}, "
55
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
56
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
57
+ )
58
+ )
59
+ if len(classes) == 0:
60
+ logger.warning("no classes loaded; this is expected only at inference time")
61
+ return root, names, inds, tot, sizes, classes
62
+
63
+
64
+ def sample_from_feature(x: np.ndarray, max_segment_length: int = 300):
65
+ """Load a segment within 300-400/51200-76800 frames or the corresponding samples from a utterance.
66
+
67
+ Args:
68
+ x (np.ndarray): feature or waveform (frames[, features]), e.g., log mel filter bank or waveform
69
+ max_segment_length (int, optional): maximum segment length. Defaults to 300.
70
+
71
+ Returns:
72
+ np.ndarray: segmented features
73
+ """
74
+ if len(x) <= max_segment_length:
75
+ return x
76
+ start = np.random.randint(0, x.shape[0] - max_segment_length)
77
+ return x[start: start + max_segment_length]
78
+
79
+
80
+ class SpeechToClassDataset(FairseqDataset):
81
+ def __init__(
82
+ self,
83
+ manifest_path: str,
84
+ sample_rate: float,
85
+ label_processors: Optional[List[Any]] = None,
86
+ max_keep_sample_size: Optional[int] = None,
87
+ min_keep_sample_size: Optional[int] = None,
88
+ shuffle: bool = True,
89
+ normalize: bool = False,
90
+ tgt_dict: Optional[Dictionary] = None,
91
+ max_length: Optional[int] = None
92
+ ):
93
+ self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.wav_classes = load_audio(
94
+ manifest_path, max_keep_sample_size, min_keep_sample_size
95
+ )
96
+ self.sample_rate = sample_rate
97
+ self.shuffle = shuffle
98
+
99
+ self.label_processors = label_processors
100
+
101
+ self.normalize = normalize
102
+ self.tgt_dict = tgt_dict
103
+ self.max_length = max_length
104
+ logger.info(
105
+ f"max_length={max_length}, normalize={normalize}"
106
+ )
107
+
108
+ def get_audio(self, index):
109
+ import soundfile as sf
110
+
111
+ wav_path = os.path.join(self.audio_root, self.audio_names[index])
112
+ wav, cur_sample_rate = sf.read(wav_path)
113
+ if self.max_length is not None:
114
+ wav = sample_from_feature(wav, self.max_length)
115
+ wav = torch.from_numpy(wav).float()
116
+ wav = self.postprocess(wav, cur_sample_rate)
117
+ return wav
118
+
119
+ def get_label(self, index):
120
+ label = self.wav_classes[index]
121
+
122
+ if self.label_processors is not None:
123
+ label = self.label_processors(label)
124
+ return label
125
+
126
+ def __getitem__(self, index):
127
+ wav = self.get_audio(index)
128
+ label = None
129
+ if len(self.wav_classes) == len(self.audio_names):
130
+ label = self.get_label(index)
131
+ return {"id": index, "source": wav, "label": label}
132
+
133
+ def __len__(self):
134
+ return len(self.wav_sizes)
135
+
136
+ def collater(self, samples):
137
+ samples = [s for s in samples if s["source"] is not None]
138
+ if len(samples) == 0:
139
+ return {}
140
+
141
+ audios = [s["source"] for s in samples]
142
+ audio_sizes = [len(s) for s in audios]
143
+
144
+ audio_size = max(audio_sizes)
145
+ collated_audios, padding_mask = self.collater_audio(
146
+ audios, audio_size
147
+ )
148
+
149
+ decoder_label = None
150
+ decoder_target = None
151
+ decoder_target_lengths = None
152
+ if samples[0]["label"] is not None:
153
+ targets_by_label = [
154
+ [s["label"] for s in samples]
155
+ ]
156
+ targets_list, lengths_list, ntokens_list = self.collater_label(targets_by_label)
157
+
158
+ decoder_label = [
159
+ (targets_list[0][i, :lengths_list[0][i]]).long()
160
+ for i in range(targets_list[0].size(0))
161
+ ]
162
+
163
+ decoder_target = data_utils.collate_tokens(
164
+ decoder_label,
165
+ self.tgt_dict.pad(),
166
+ self.tgt_dict.eos(),
167
+ left_pad=False,
168
+ move_eos_to_beginning=False,
169
+ )
170
+ decoder_target_lengths = torch.tensor(
171
+ [x.size(0) for x in decoder_label], dtype=torch.long
172
+ )
173
+ prev_output_tokens = data_utils.collate_tokens(
174
+ [torch.LongTensor([-1]) for _ in samples],
175
+ self.tgt_dict.pad(),
176
+ self.tgt_dict.eos(),
177
+ left_pad=False,
178
+ move_eos_to_beginning=True,
179
+ )
180
+
181
+ net_input = {
182
+ "source": collated_audios,
183
+ "padding_mask": padding_mask,
184
+ "prev_output_tokens": prev_output_tokens,
185
+ "task_name": "s2c",
186
+ }
187
+ batch = {
188
+ "id": torch.LongTensor([s["id"] for s in samples]),
189
+ "net_input": net_input,
190
+ "target": decoder_target,
191
+ "target_lengths": decoder_target_lengths,
192
+ "task_name": "s2c",
193
+ "ntokens": len(samples),
194
+ }
195
+
196
+ return batch
197
+
198
+ def collater_audio(self, audios, audio_size):
199
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
200
+ padding_mask = (
201
+ torch.BoolTensor(collated_audios.shape).fill_(False)
202
+ )
203
+ for i, audio in enumerate(audios):
204
+ diff = len(audio) - audio_size
205
+ if diff == 0:
206
+ collated_audios[i] = audio
207
+ elif diff < 0:
208
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
209
+ padding_mask[i, diff:] = True
210
+ else:
211
+ raise Exception("Diff should not be larger than 0")
212
+ return collated_audios, padding_mask
213
+
214
+ def collater_seq_label(self, targets, pad):
215
+ lengths = torch.LongTensor([len(t) for t in targets])
216
+ ntokens = lengths.sum().item()
217
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
218
+ return targets, lengths, ntokens
219
+
220
+ def collater_label(self, targets_by_label):
221
+ targets_list, lengths_list, ntokens_list = [], [], []
222
+ itr = zip(targets_by_label, [self.tgt_dict.pad()])
223
+ for targets, pad in itr:
224
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
225
+ targets_list.append(targets)
226
+ lengths_list.append(lengths)
227
+ ntokens_list.append(ntokens)
228
+ return targets_list, lengths_list, ntokens_list
229
+
230
+ def num_tokens(self, index):
231
+ return self.size(index)
232
+
233
+ def size(self, index):
234
+ return self.wav_sizes[index]
235
+
236
+ @property
237
+ def sizes(self):
238
+ return np.array(self.wav_sizes)
239
+
240
+ def ordered_indices(self):
241
+ if self.shuffle:
242
+ order = [np.random.permutation(len(self))]
243
+ else:
244
+ order = [np.arange(len(self))]
245
+
246
+ order.append(self.wav_sizes)
247
+ return np.lexsort(order)[::-1]
248
+
249
+ def postprocess(self, wav, cur_sample_rate):
250
+ if wav.dim() == 2:
251
+ wav = wav.mean(-1)
252
+ assert wav.dim() == 1, wav.dim()
253
+
254
+ if cur_sample_rate != self.sample_rate:
255
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
256
+
257
+ if self.normalize:
258
+ with torch.no_grad():
259
+ wav = F.layer_norm(wav, wav.shape)
260
+ return wav
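Illustrative note (a minimal sketch, not part of the commit): collater_audio above zero-pads every waveform to the longest one in the batch and marks the padded tail in a Boolean padding mask. Sizes below are hypothetical.

# Sketch of batch padding and mask construction, mirroring collater_audio()
import torch

audios = [torch.randn(n) for n in (16000, 12000, 9000)]
audio_size = max(len(a) for a in audios)

collated_audios = audios[0].new_zeros(len(audios), audio_size)
padding_mask = torch.zeros(len(audios), audio_size, dtype=torch.bool)
for i, audio in enumerate(audios):
    diff = len(audio) - audio_size      # <= 0 because audio_size is the batch max
    collated_audios[i, : len(audio)] = audio
    if diff < 0:
        padding_mask[i, diff:] = True   # negative index flags the zero-padded tail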
artst/data/speech_to_speech_dataset.py ADDED
@@ -0,0 +1,280 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import logging
9
+ import os
10
+ from typing import Any, List, Optional
11
+
12
+ import librosa
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from fairseq.data.fairseq_dataset import FairseqDataset
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ def _collate_frames(
21
+ frames: List[torch.Tensor], is_audio_input: bool = False
22
+ ):
23
+ """
24
+ Convert a list of 2D frames into a padded 3D tensor
25
+ Args:
26
+ frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
27
+ length of i-th frame and f_dim is static dimension of features
28
+ Returns:
29
+ 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
30
+ """
31
+ max_len = max(frame.size(0) for frame in frames)
32
+ if is_audio_input:
33
+ out = frames[0].new_zeros((len(frames), max_len))
34
+ else:
35
+ out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
36
+ for i, v in enumerate(frames):
37
+ out[i, : v.size(0)] = v
38
+ return out
39
+
40
+ def load_audio(manifest_path, max_keep, min_keep):
41
+ """manifest tsv: src_wav, src_nframe, tgt_wav, tgt_nframe, tgt_spkemb"""
42
+ n_long, n_short = 0, 0
43
+ src_names, tgt_names, inds, sizes, tgt_sizes, spk_embeds = [], [], [], [], [], []
44
+ with open(manifest_path) as f:
45
+ root = f.readline().strip()
46
+ for ind, line in enumerate(f):
47
+ items = line.strip().split("\t")
48
+ assert len(items) >= 2, line
49
+ sz = int(items[1])
50
+ if min_keep is not None and sz < min_keep:
51
+ n_short += 1
52
+ elif max_keep is not None and sz > max_keep:
53
+ n_long += 1
54
+ else:
55
+ src_names.append(items[0])
56
+ tgt_names.append(items[2])
57
+ tgt_sizes.append(items[3])
58
+ spk_embeds.append(items[4])
59
+ inds.append(ind)
60
+ sizes.append(sz)
61
+ tot = ind + 1
62
+ logger.info(
63
+ (
64
+ f"max_keep={max_keep}, min_keep={min_keep}, "
65
+ f"loaded {len(src_names)}, skipped {n_short} short and {n_long} long, "
66
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
67
+ )
68
+ )
69
+ return root, src_names, inds, tot, sizes, tgt_names, tgt_sizes, spk_embeds
70
+
71
+
72
+ def logmelfilterbank(
73
+ audio,
74
+ sampling_rate,
75
+ fft_size=1024,
76
+ hop_size=256,
77
+ win_length=None,
78
+ window="hann",
79
+ num_mels=80,
80
+ fmin=80,
81
+ fmax=7600,
82
+ eps=1e-10,
83
+ ):
84
+ """Compute log-Mel filterbank feature.
85
+ (https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py)
86
+
87
+ Args:
88
+ audio (ndarray): Audio signal (T,).
89
+ sampling_rate (int): Sampling rate.
90
+ fft_size (int): FFT size.
91
+ hop_size (int): Hop size.
92
+ win_length (int): Window length. If set to None, it will be the same as fft_size.
93
+ window (str): Window function type.
94
+ num_mels (int): Number of mel basis.
95
+ fmin (int): Minimum frequency in mel basis calculation.
96
+ fmax (int): Maximum frequency in mel basis calculation.
97
+ eps (float): Epsilon value to avoid inf in log calculation.
98
+
99
+ Returns:
100
+ ndarray: Log Mel filterbank feature (#frames, num_mels).
101
+
102
+ """
103
+ # get amplitude spectrogram
104
+ x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
105
+ win_length=win_length, window=window, pad_mode="reflect")
106
+ spc = np.abs(x_stft).T # (#frames, #bins)
107
+
108
+ # get mel basis
109
+ fmin = 0 if fmin is None else fmin
110
+ fmax = sampling_rate / 2 if fmax is None else fmax
111
+ mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
112
+
113
+ return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
114
+
115
+
116
+ class SpeechToSpeechDataset(FairseqDataset):
117
+ def __init__(
118
+ self,
119
+ manifest_path: str,
120
+ sample_rate: float,
121
+ max_keep_sample_size: Optional[int] = None,
122
+ min_keep_sample_size: Optional[int] = None,
123
+ shuffle: bool = True,
124
+ normalize: bool = False,
125
+ reduction_factor: int = 1,
126
+ ):
127
+ self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.tgt_audios, self.tgt_sizes, self.tgt_spkembs = load_audio(
128
+ manifest_path, max_keep_sample_size, min_keep_sample_size
129
+ )
130
+ self.sample_rate = sample_rate
131
+ self.shuffle = shuffle
132
+
133
+ self.normalize = normalize
134
+ self.reduction_factor = reduction_factor
135
+ logger.info(
136
+ f"reduction_factor={reduction_factor}, normalize={normalize}"
137
+ )
138
+
139
+ def get_audio(self, index):
140
+ import soundfile as sf
141
+
142
+ wav_fbank = []
143
+ for name in [self.audio_names[index], self.tgt_audios[index]]:
144
+ wav_path = os.path.join(self.audio_root, name)
145
+ wav, cur_sample_rate = sf.read(wav_path)
146
+ wav = torch.from_numpy(wav).float()
147
+ fbank = logmelfilterbank(
148
+ wav.view(-1).cpu().numpy(), 16000
149
+ )
150
+ fbank = torch.from_numpy(fbank).float()
151
+ wav = self.postprocess(wav, cur_sample_rate)
152
+ wav_fbank.append(wav)
153
+ wav_fbank.append(fbank)
154
+ src_wav, src_fbank, tgt_wav, tgt_fbank = wav_fbank
155
+ return src_wav, src_fbank, tgt_wav, tgt_fbank
156
+
157
+ def __getitem__(self, index):
158
+ src_wav, src_fbank, tgt_wav, tgt_fbank = self.get_audio(index)
159
+ spkembs = np.load(os.path.join(self.audio_root, self.tgt_spkembs[index]))
160
+ spkembs = torch.from_numpy(spkembs).float()
161
+ name = self.audio_names[index].replace("/", ".").replace(".wav", "") + "-" + self.tgt_audios[index].replace("/", ".").replace(".wav", "") + ".wav"
162
+ return {"id": index, "source": src_wav, "target": tgt_fbank, "spkembs": spkembs, "audio_name": name, "tgt_name": self.tgt_audios[index]}
163
+
164
+ def __len__(self):
165
+ return len(self.wav_sizes)
166
+
167
+ def collater(self, samples):
168
+ samples = [s for s in samples if s["source"] is not None]
169
+ if len(samples) == 0:
170
+ return {}
171
+
172
+ audios = [s["source"] for s in samples]
173
+ audio_sizes = [len(s) for s in audios]
174
+
175
+ audio_size = max(audio_sizes)
176
+ collated_audios, padding_mask = self.collater_audio(
177
+ audios, audio_size
178
+ )
179
+
180
+ fbanks = [s["target"] for s in samples]
181
+ fbank_sizes = [len(s) for s in fbanks]
182
+
183
+ collated_fbanks = _collate_frames(fbanks)
184
+ collated_fbanks_size = torch.tensor(fbank_sizes, dtype=torch.long)
185
+
186
+ # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
187
+ if self.reduction_factor > 1:
188
+ collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor]
189
+ collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size])
190
+ else:
191
+ collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size
192
+
193
+ prev_output_tokens = torch.cat(
194
+ [collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1
195
+ )
196
+
197
+ # make labels for stop prediction
198
+ labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1))
199
+ for i, l in enumerate(fbank_sizes):
200
+ labels[i, l - 1 :] = 1.0
201
+
202
+ spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True)
203
+
204
+ net_input = {
205
+ "source": collated_audios,
206
+ "padding_mask": padding_mask,
207
+ "prev_output_tokens": prev_output_tokens,
208
+ "tgt_lengths": collated_fbanks_size_in,
209
+ "spkembs": spkembs,
210
+ "task_name": "s2s",
211
+ }
212
+ batch = {
213
+ "id": torch.LongTensor([s["id"] for s in samples]),
214
+ "name": [s["audio_name"] for s in samples],
215
+ "tgt_name": [s["tgt_name"] for s in samples],
216
+ "net_input": net_input,
217
+ "labels": labels,
218
+ "dec_target": collated_fbanks,
219
+ "dec_target_lengths": collated_fbanks_size,
220
+ "src_lengths": torch.LongTensor(audio_sizes),
221
+ "task_name": "s2s",
222
+ "ntokens": sum(audio_sizes),
223
+ "target": collated_fbanks,
224
+ }
225
+
226
+ return batch
227
+
228
+ def collater_audio(self, audios, audio_size):
229
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
230
+ padding_mask = (
231
+ torch.BoolTensor(collated_audios.shape).fill_(False)
232
+ )
233
+ for i, audio in enumerate(audios):
234
+ diff = len(audio) - audio_size
235
+ if diff == 0:
236
+ collated_audios[i] = audio
237
+ elif diff < 0:
238
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
239
+ padding_mask[i, diff:] = True
240
+ else:
241
+ raise Exception("Diff should not be larger than 0")
242
+ return collated_audios, padding_mask
243
+
244
+
245
+ def num_tokens(self, index):
246
+ return self.wav_sizes[index]
247
+
248
+ def size(self, index):
249
+ return self.wav_sizes[index], self.tgt_sizes[index]
250
+
251
+ @property
252
+ def sizes(self):
253
+ return np.array(self.wav_sizes)
254
+
255
+ @property
256
+ def can_reuse_epoch_itr_across_epochs(self):
257
+ """No cache dataset if dataset is large-scale. Cache dataset for small dataset."""
258
+ return True
259
+
260
+ def ordered_indices(self):
261
+ if self.shuffle:
262
+ order = [np.random.permutation(len(self))]
263
+ else:
264
+ order = [np.arange(len(self))]
265
+
266
+ order.append(self.wav_sizes)
267
+ return np.lexsort(order)[::-1]
268
+
269
+ def postprocess(self, wav, cur_sample_rate):
270
+ if wav.dim() == 2:
271
+ wav = wav.mean(-1)
272
+ assert wav.dim() == 1, wav.dim()
273
+
274
+ if cur_sample_rate != self.sample_rate:
275
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
276
+
277
+ if self.normalize:
278
+ with torch.no_grad():
279
+ wav = F.layer_norm(wav, wav.shape)
280
+ return wav
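Illustrative note (a minimal sketch, not part of the commit): the logmelfilterbank helper above returns an (n_frames, 80) matrix of log10 mel energies. The usage below runs it on a synthetic 16 kHz signal; the import path is an assumption based on this repository layout.

# Sketch: computing 80-dim log-Mel features with the helper defined above
import numpy as np
from artst.data.speech_to_speech_dataset import logmelfilterbank  # assumed path

sr = 16000
audio = np.random.randn(sr).astype(np.float32)   # one second of noise
fbank = logmelfilterbank(audio, sr)              # defaults: fft 1024, hop 256, 80 mels
print(fbank.shape)                               # roughly (sr // 256 + 1, 80)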
artst/data/speech_to_text_dataset.py ADDED
@@ -0,0 +1,298 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import itertools
9
+ import logging
10
+ import os
11
+ import mmap
12
+ from typing import Any, List, Optional
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ torch.set_printoptions(profile="full")
18
+ import torch.nn.functional as F
19
+ from fairseq.data import data_utils, Dictionary
20
+ from fairseq.data.fairseq_dataset import FairseqDataset
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def load_audio(manifest_path, max_keep, min_keep):
26
+ n_long, n_short = 0, 0
27
+ names, inds, sizes = [], [], []
28
+ with open(manifest_path) as f:
29
+ root = f.readline().strip()
30
+ for ind, line in enumerate(f):
31
+ items = line.strip().split("\t")
32
+ assert len(items) >= 2, line
33
+ sz = int(items[1])
34
+ if min_keep is not None and sz < min_keep:
35
+ n_short += 1
36
+ elif max_keep is not None and sz > max_keep:
37
+ n_long += 1
38
+ else:
39
+ names.append(items[0])
40
+ inds.append(ind)
41
+ sizes.append(sz)
42
+ tot = ind + 1
43
+ logger.info(
44
+ (
45
+ f"max_keep={max_keep}, min_keep={min_keep}, "
46
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
47
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
48
+ )
49
+ )
50
+ return root, names, inds, tot, sizes
51
+
52
+
53
+ def load_label(label_path, inds, tot):
54
+ with open(label_path) as f:
55
+ labels = [line.rstrip() for line in f]
56
+ assert (
57
+ len(labels) == tot
58
+ ), f"number of labels does not match ({len(labels)} != {tot})"
59
+ labels = [labels[i] for i in inds]
60
+ return labels
61
+
62
+
63
+ def load_label_offset(label_path, inds, tot):
64
+ with open(label_path) as f:
65
+ # Hawau:
+ # use UTF-8 byte lengths (not character counts) so the stored offsets match
+ # the byte positions read back via mmap in get_label()
+ code_lengths = [len(line.encode("utf-8")) for line in f]  # byte lengths (correct)
+ # code_lengths = [len(line) for line in f]  # character counts (wrong for non-ASCII text)
69
+ assert (
70
+ len(code_lengths) == tot
71
+ ), f"number of labels does not match ({len(code_lengths)} != {tot})"
72
+ offsets = list(itertools.accumulate([0] + code_lengths))
73
+ offsets = [(offsets[i], offsets[i + 1]) for i in inds]
74
+ return offsets
75
+
76
+
77
+ class SpeechToTextDataset(FairseqDataset):
78
+ def __init__(
79
+ self,
80
+ manifest_path: str,
81
+ sample_rate: float,
82
+ label_paths: List[str],
83
+ label_processors: Optional[List[Any]] = None,
84
+ max_keep_sample_size: Optional[int] = None,
85
+ min_keep_sample_size: Optional[int] = None,
86
+ shuffle: bool = True,
87
+ normalize: bool = False,
88
+ store_labels: bool = True,
89
+ tgt_dict: Optional[Dictionary] = None,
90
+ tokenizer = None,
91
+ ):
92
+ self.audio_root, self.audio_names, inds, tot, self.wav_sizes = load_audio(
93
+ manifest_path, max_keep_sample_size, min_keep_sample_size
94
+ )
95
+
96
+ self.sample_rate = sample_rate
97
+ self.shuffle = shuffle
98
+ self.tgt_dict = tgt_dict
99
+ self.tokenizer = tokenizer
100
+
101
+ self.num_labels = len(label_paths)
102
+ self.label_processors = label_processors
103
+ self.store_labels = store_labels
104
+
105
+ if store_labels:
106
+ self.label_list = [load_label(p, inds, tot) for p in label_paths]
107
+ logger.info(f"label_list: {self.label_list}")
108
+ else:
109
+ self.label_paths = label_paths
110
+ self.label_offsets_list = [
111
+ load_label_offset(p, inds, tot) for p in label_paths
112
+ ]
113
+ # logger.info(f"label_offsets_list: {self.label_offsets_list}")
114
+ assert label_processors is None or len(label_processors) == self.num_labels
115
+
116
+ self.normalize = normalize
117
+ logger.info(
118
+ f"normalize={normalize}"
119
+ )
120
+
121
+ def get_audio(self, index):
122
+ import soundfile as sf
123
+ # Hawau:
124
+ # logger.info(f"loaded_audio: {self.audio_names[index]}")
125
+ wav_path = os.path.join(self.audio_root, self.audio_names[index])
126
+ wav, cur_sample_rate = sf.read(wav_path)
127
+ wav = torch.from_numpy(wav).float()
128
+ wav = self.postprocess(wav, cur_sample_rate)
129
+ return wav
130
+
131
+ def get_label(self, index, label_idx):
132
+ if self.store_labels:
133
+ label = self.label_list[label_idx][index]
134
+ else:
135
+ # list slicing method
136
+ # with open(self.label_paths[label_idx]) as f:
137
+ # offset_s, offset_e = self.label_offsets_list[label_idx][index]
138
+ # # Hawau:
139
+ # # f.seek(offset_s)
140
+ # # label = f.read(offset_e - offset_s)
141
+ # label = f.read()[offset_s : offset_e]
142
+ # Hawau:
143
+ # mmap method
144
+ with open(self.label_paths[label_idx], encoding='utf-8') as f:
145
+ with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
146
+ offset_s, offset_e = self.label_offsets_list[label_idx][index]
147
+ label = mm[offset_s:offset_e].decode("utf-8")
148
+
149
+
150
+ # Hawau:
151
+ # logger.info(f"loaded_label: {label}")
152
+ if self.tokenizer is not None:
153
+ label = self.tokenizer.encode(label)
154
+
155
+ if self.label_processors is not None:
156
+ label = self.label_processors[label_idx](label)
157
+ # logger.info(f"processed_label: {label}")
158
+ return label
159
+
160
+ def get_labels(self, index):
161
+ return [self.get_label(index, i) for i in range(self.num_labels)]
162
+
163
+ def __getitem__(self, index):
164
+ wav = self.get_audio(index)
165
+ labels = self.get_labels(index)
166
+ return {"id": index, "source": wav, "label_list": labels}
167
+
168
+ def __len__(self):
169
+ return len(self.wav_sizes)
170
+
171
+ def collater(self, samples):
172
+ samples = [s for s in samples if s["source"] is not None]
173
+ if len(samples) == 0:
174
+ return {}
175
+
176
+ audios = [s["source"] for s in samples]
177
+ audio_sizes = [len(s) for s in audios]
178
+
179
+ audio_size = max(audio_sizes)
180
+ collated_audios, padding_mask = self.collater_audio(
181
+ audios, audio_size
182
+ )
183
+
184
+ targets_by_label = [
185
+ [s["label_list"][i] for s in samples] for i in range(self.num_labels)
186
+ ]
187
+ targets_list, lengths_list, ntokens_list = self.collater_label(targets_by_label)
188
+
189
+ # Hawau:
190
+ # logger.info(f'targets_list: {targets_list}')
191
+
192
+
193
+ decoder_label = [
194
+ torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
195
+ for i in range(targets_list[0].size(0))
196
+ ]
197
+
198
+ decoder_target = data_utils.collate_tokens(
199
+ decoder_label,
200
+ self.tgt_dict.pad(),
201
+ self.tgt_dict.eos(),
202
+ left_pad=False,
203
+ move_eos_to_beginning=False,
204
+ )
205
+ decoder_target_lengths = torch.tensor(
206
+ [x.size(0) for x in decoder_label], dtype=torch.long
207
+ )
208
+ prev_output_tokens = data_utils.collate_tokens(
209
+ decoder_label,
210
+ self.tgt_dict.pad(),
211
+ self.tgt_dict.eos(),
212
+ left_pad=False,
213
+ move_eos_to_beginning=True,
214
+ )
215
+
216
+ net_input = {
217
+ "source": collated_audios,
218
+ "padding_mask": padding_mask,
219
+ "prev_output_tokens": prev_output_tokens,
220
+ "task_name": "s2t",
221
+ }
222
+ batch = {
223
+ "id": torch.LongTensor([s["id"] for s in samples]),
224
+ "net_input": net_input,
225
+ "target": decoder_target,
226
+ "target_lengths": decoder_target_lengths,
227
+ "task_name": "s2t",
228
+ "ntokens": ntokens_list[0]
229
+ }
230
+
231
+ return batch
232
+
233
+ def collater_audio(self, audios, audio_size):
234
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
235
+ padding_mask = (
236
+ torch.BoolTensor(collated_audios.shape).fill_(False)
237
+ )
238
+ for i, audio in enumerate(audios):
239
+ diff = len(audio) - audio_size
240
+ if diff == 0:
241
+ collated_audios[i] = audio
242
+ elif diff < 0:
243
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
244
+ padding_mask[i, diff:] = True
245
+ else:
246
+ raise Exception("Diff should not be larger than 0")
247
+ return collated_audios, padding_mask
248
+
249
+ def collater_seq_label(self, targets, pad):
250
+ lengths = torch.LongTensor([len(t) for t in targets])
251
+ ntokens = lengths.sum().item()
252
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
253
+ return targets, lengths, ntokens
254
+
255
+ def collater_label(self, targets_by_label):
256
+ targets_list, lengths_list, ntokens_list = [], [], []
257
+ itr = zip(targets_by_label, [self.tgt_dict.pad()])
258
+
259
+ for targets, pad in itr:
260
+ # Hawau:
261
+ # logger.info(f'targets: {targets}')
262
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
263
+ targets_list.append(targets)
264
+ lengths_list.append(lengths)
265
+ ntokens_list.append(ntokens)
266
+ return targets_list, lengths_list, ntokens_list
267
+
268
+ def num_tokens(self, index):
269
+ return self.size(index)
270
+
271
+ def size(self, index):
272
+ return self.wav_sizes[index]
273
+
274
+ @property
275
+ def sizes(self):
276
+ return np.array(self.wav_sizes)
277
+
278
+ def ordered_indices(self):
279
+ if self.shuffle:
280
+ order = [np.random.permutation(len(self))]
281
+ else:
282
+ order = [np.arange(len(self))]
283
+
284
+ order.append(self.wav_sizes)
285
+ return np.lexsort(order)[::-1]
286
+
287
+ def postprocess(self, wav, cur_sample_rate):
288
+ if wav.dim() == 2:
289
+ wav = wav.mean(-1)
290
+ assert wav.dim() == 1, wav.dim()
291
+
292
+ if cur_sample_rate != self.sample_rate:
293
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
294
+
295
+ if self.normalize:
296
+ with torch.no_grad():
297
+ wav = F.layer_norm(wav, wav.shape)
298
+ return wav
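Illustrative note (a minimal sketch, not part of the commit): load_label_offset above stores UTF-8 byte offsets so that get_label can slice one transcription out of the label file through mmap without reading the whole file; byte lengths (not character counts) are required because Arabic text is multi-byte in UTF-8. The file name below is hypothetical.

# Sketch of byte-offset indexing of a label file, mirroring load_label_offset()/get_label()
import itertools
import mmap

with open("labels.txt", encoding="utf-8") as f:
    byte_lengths = [len(line.encode("utf-8")) for line in f]   # byte, not character, lengths
offsets = list(itertools.accumulate([0] + byte_lengths))

index = 2  # read back the third label only
with open("labels.txt", encoding="utf-8") as f:
    with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        label = mm[offsets[index] : offsets[index + 1]].decode("utf-8").rstrip("\n")
print(label)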
artst/data/text_dataset.py ADDED
@@ -0,0 +1,474 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from fairseq.data import FairseqDataset, data_utils
14
+
15
+
16
+ def collate(
17
+ samples,
18
+ pad_idx,
19
+ eos_idx,
20
+ vocab,
21
+ left_pad_source=False,
22
+ left_pad_target=False,
23
+ input_feeding=True,
24
+ pad_to_length=None,
25
+ ):
26
+ assert input_feeding
27
+ if len(samples) == 0:
28
+ return {}
29
+
30
+ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
31
+ return data_utils.collate_tokens(
32
+ [s[key] for s in samples],
33
+ pad_idx,
34
+ eos_idx=None, # use eos_idx of each sample instead of vocab.eos()
35
+ left_pad=left_pad,
36
+ move_eos_to_beginning=move_eos_to_beginning,
37
+ pad_to_length=pad_to_length,
38
+ )
39
+
40
+ id = torch.LongTensor([s["id"] for s in samples])
41
+ src_tokens = merge(
42
+ "source",
43
+ left_pad=left_pad_source,
44
+ pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
45
+ )
46
+ # sort by descending source length
47
+ src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
48
+ src_lengths, sort_order = src_lengths.sort(descending=True)
49
+ id = id.index_select(0, sort_order)
50
+ src_tokens = src_tokens.index_select(0, sort_order)
51
+
52
+ prev_output_tokens = None
53
+ target = None
54
+ if samples[0].get("target", None) is not None:
55
+ target = merge(
56
+ "target",
57
+ left_pad=left_pad_target,
58
+ pad_to_length=pad_to_length["target"]
59
+ if pad_to_length is not None
60
+ else None,
61
+ )
62
+ target = target.index_select(0, sort_order)
63
+ ntokens = sum(len(s["target"]) for s in samples)
64
+
65
+ if input_feeding:
66
+ # we create a shifted version of targets for feeding the
67
+ # previous output token(s) into the next decoder step
68
+ prev_output_tokens = merge(
69
+ "target",
70
+ left_pad=left_pad_target,
71
+ move_eos_to_beginning=True,
72
+ pad_to_length=pad_to_length["target"]
73
+ if pad_to_length is not None
74
+ else None,
75
+ )
76
+ prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
77
+ else:
78
+ ntokens = sum(len(s["source"]) for s in samples)
79
+
80
+ batch = {
81
+ "id": id,
82
+ "ntokens": ntokens,
83
+ "net_input": {
84
+ "src_tokens": src_tokens,
85
+ "src_lengths": src_lengths,
86
+ },
87
+ "target": target,
88
+ "nsentences": samples[0]["source"].size(0),
89
+ "sort_order": sort_order,
90
+ "task_name": 'text_pretrain',
91
+ }
92
+ if prev_output_tokens is not None:
93
+ batch["net_input"]["prev_output_tokens"] = prev_output_tokens
94
+
95
+ return batch
96
+
97
+
98
+ class TextPretrainDataset(FairseqDataset):
99
+ """
100
+ A wrapper around TokenBlockDataset for BART dataset.
101
+
102
+ Args:
103
+ dataset (TokenBlockDataset): dataset to wrap
104
+ sizes (List[int]): sentence lengths
105
+ vocab (~fairseq.data.Dictionary): vocabulary
106
+ mask_idx (int): dictionary index used for masked token
107
+ mask_whole_words: only mask whole words. This should be a byte mask
108
+ over vocab indices, indicating whether it is the beginning of a
109
+ word. We will extend any mask to encompass the whole word.
110
+ shuffle (bool, optional): shuffle the elements before batching.
111
+ Default: ``True``
112
+ seed: Seed for random number generator for reproducibility.
113
+ args: argparse arguments.
114
+ """
115
+
116
+ def __init__(
117
+ self,
118
+ dataset,
119
+ sizes,
120
+ vocab,
121
+ mask_idx,
122
+ mask_whole_words,
123
+ shuffle,
124
+ seed,
125
+ args,
126
+ eos=None,
127
+ item_transform_func=None,
128
+ iid_noise_target=False,
129
+ uni_mask_idxs=None,
130
+ ):
131
+ self.dataset = dataset
132
+
133
+ self.sizes = sizes
134
+
135
+ self.vocab = vocab
136
+ self.shuffle = shuffle
137
+ self.seed = seed
138
+ if iid_noise_target:
139
+ assert isinstance(uni_mask_idxs, torch.Tensor), "if use iid_noise_target, the uni_mask_idxs must be a tensor which contain the mask indexs"
140
+ self.iid_noise_target = iid_noise_target
141
+ self.uni_mask_idxs = uni_mask_idxs
142
+ self.mask_idx = mask_idx
143
+ self.mask_whole_word = mask_whole_words
144
+ self.mask_ratio = args.mask
145
+ self.random_ratio = args.mask_random
146
+ self.insert_ratio = args.insert
147
+ self.rotate_ratio = args.rotate
148
+ self.permute_sentence_ratio = args.permute_sentences
149
+ self.eos = eos if eos is not None else vocab.eos()
150
+ self.item_transform_func = item_transform_func
151
+
152
+ if args.bpe != "gpt2":
153
+ self.full_stop_index = self.vocab.eos()
154
+ else:
155
+ assert args.bpe == "gpt2"
156
+ self.full_stop_index = self.vocab.index("13")
157
+
158
+ self.replace_length = args.replace_length
159
+ if self.replace_length not in [-1, 0, 1]:
160
+ raise ValueError(f"invalid arg: replace_length={self.replace_length}")
161
+ if args.mask_length not in ["subword", "word", "span-poisson"]:
162
+ raise ValueError(f"invalid arg: mask-length={args.mask_length}")
163
+ if args.mask_length == "subword" and args.replace_length not in [0, 1]:
164
+ raise ValueError(f"if using subwords, use replace-length=1 or 0")
165
+
166
+ self.mask_span_distribution = None
167
+ if args.mask_length == "span-poisson":
168
+ _lambda = args.poisson_lambda
169
+
170
+ lambda_to_the_k = 1
171
+ e_to_the_minus_lambda = math.exp(-_lambda)
172
+ k_factorial = 1
173
+ ps = []
174
+ for k in range(0, 128):
175
+ ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
176
+ lambda_to_the_k *= _lambda
177
+ k_factorial *= k + 1
178
+ if ps[-1] < 0.0000001:
179
+ break
180
+ ps = torch.FloatTensor(ps)
181
+ self.mask_span_distribution = torch.distributions.Categorical(ps)
182
+
183
+ self.epoch = 0
184
+
185
+ @property
186
+ def can_reuse_epoch_itr_across_epochs(self):
187
+ return True # only the noise changes, not item sizes
188
+
189
+ def set_epoch(self, epoch, **unused):
190
+ self.epoch = epoch
191
+
192
+ def __getitem__(self, index):
193
+ with data_utils.numpy_seed(self.seed, self.epoch, index):
194
+ tokens = self.dataset[index]
195
+ assert tokens[-1] == self.eos
196
+ source, target = tokens, tokens.clone()
197
+
198
+ if self.permute_sentence_ratio > 0.0:
199
+ source = self.permute_sentences(source, self.permute_sentence_ratio)
200
+
201
+ if self.mask_ratio > 0:
202
+ source, new_target = self.add_whole_word_mask(source, self.mask_ratio)
203
+ if new_target is not None:
204
+ target = new_target
205
+
206
+ if self.insert_ratio > 0:
207
+ source = self.add_insertion_noise(source, self.insert_ratio)
208
+
209
+ if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
210
+ source = self.add_rolling_noise(source)
211
+ # there can be additional changes to make:
212
+ if self.item_transform_func is not None:
213
+ source, target = self.item_transform_func(source, target)
214
+
215
+ assert (source >= 0).all()
216
+ assert (source[1:-1] >= 1).all()
217
+ assert (source <= len(self.vocab)).all()
218
+ assert source[0] == self.vocab.bos()
219
+ assert source[-1] == self.eos
220
+ return {
221
+ "id": index,
222
+ "source": source,
223
+ "target": target,
224
+ }
225
+
226
+ def __len__(self):
227
+ return len(self.dataset)
228
+
229
+ def permute_sentences(self, source, p=1.0):
230
+ full_stops = source == self.full_stop_index
231
+ # Pretend it ends with a full stop so last span is a sentence
232
+ full_stops[-2] = 1
233
+
234
+ # Tokens that are full stops, where the previous token is not
235
+ sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
236
+ result = source.clone()
237
+
238
+ num_sentences = sentence_ends.size(0)
239
+ num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
240
+ substitutions = torch.randperm(num_sentences)[:num_to_permute]
241
+ ordering = torch.arange(0, num_sentences)
242
+ ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]
243
+
244
+ # Ignore <bos> at start
245
+ index = 1
246
+ for i in ordering:
247
+ sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
248
+ result[index : index + sentence.size(0)] = sentence
249
+ index += sentence.size(0)
250
+ return result
251
+
252
+ def word_starts(self, source):
253
+ if self.mask_whole_word is not None:
254
+ is_word_start = self.mask_whole_word.gather(0, source)
255
+ else:
256
+ is_word_start = torch.ones(source.size())
257
+ is_word_start[0] = 0
258
+ is_word_start[-1] = 0
259
+ return is_word_start
260
+
261
+ def add_whole_word_mask(self, source, p):
262
+ source_ori = source.clone()
263
+ is_word_start = self.word_starts(source)
264
+ num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
265
+ num_inserts = 0
266
+ if num_to_mask == 0:
267
+ return source, None  # keep the (source, target) return convention expected by __getitem__
268
+
269
+ if self.mask_span_distribution is not None:
270
+ lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
271
+
272
+ # Make sure we have enough to mask
273
+ cum_length = torch.cumsum(lengths, 0)
274
+ while cum_length[-1] < num_to_mask:
275
+ lengths = torch.cat(
276
+ [
277
+ lengths,
278
+ self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
279
+ ],
280
+ dim=0,
281
+ )
282
+ cum_length = torch.cumsum(lengths, 0)
283
+
284
+ # Trim to masking budget
285
+ i = 0
286
+ while cum_length[i] < num_to_mask:
287
+ i += 1
288
+ lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
289
+ num_to_mask = i + 1
290
+ lengths = lengths[:num_to_mask]
291
+
292
+ # Handle 0-length mask (inserts) separately
293
+ lengths = lengths[lengths > 0]
294
+ num_inserts = num_to_mask - lengths.size(0)
295
+ num_to_mask -= num_inserts
296
+ if num_to_mask == 0:
297
+ return self.add_insertion_noise(source, num_inserts / source.size(0)), None
298
+
299
+ assert (lengths > 0).all()
300
+ else:
301
+ lengths = torch.ones((num_to_mask,)).long()
302
+ assert is_word_start[-1] == 0
303
+ word_starts = is_word_start.nonzero(as_tuple=False)
304
+ indices = word_starts[
305
+ torch.randperm(word_starts.size(0))[:num_to_mask]
306
+ ].squeeze(1)
307
+ mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
308
+
309
+ source_length = source.size(0)
310
+ assert source_length - 1 not in indices
311
+ to_keep = torch.ones(source_length, dtype=torch.bool)
312
+ is_word_start[
313
+ -1
314
+ ] = 255 # acts as a long length, so spans don't go over the end of doc
315
+ if self.replace_length == 0:
316
+ to_keep[indices] = 0
317
+ else:
318
+ # keep index, but replace it with [MASK]
319
+ source[indices] = self.mask_idx
320
+ source[indices[mask_random]] = torch.randint(
321
+ 1, len(self.vocab), size=(mask_random.sum(),)
322
+ )
323
+
324
+ if self.mask_span_distribution is not None:
325
+ assert len(lengths.size()) == 1
326
+ assert lengths.size() == indices.size()
327
+ lengths -= 1
328
+ while indices.size(0) > 0:
329
+ assert lengths.size() == indices.size()
330
+ lengths -= is_word_start[indices + 1].long()
331
+ uncompleted = lengths >= 0
332
+ indices = indices[uncompleted] + 1
333
+ mask_random = mask_random[uncompleted]
334
+ lengths = lengths[uncompleted]
335
+ if self.replace_length != -1:
336
+ # delete token
337
+ to_keep[indices] = 0
338
+ else:
339
+ # keep index, but replace it with [MASK]
340
+ source[indices] = self.mask_idx
341
+ source[indices[mask_random]] = torch.randint(
342
+ 1, len(self.vocab), size=(mask_random.sum(),)
343
+ )
344
+ else:
345
+ # A bit faster when all lengths are 1
346
+ while indices.size(0) > 0:
347
+ uncompleted = is_word_start[indices + 1] == 0
348
+ indices = indices[uncompleted] + 1
349
+ mask_random = mask_random[uncompleted]
350
+ if self.replace_length != -1:
351
+ # delete token
352
+ to_keep[indices] = 0
353
+ else:
354
+ # keep index, but replace it with [MASK]
355
+ source[indices] = self.mask_idx
356
+ source[indices[mask_random]] = torch.randint(
357
+ 1, len(self.vocab), size=(mask_random.sum(),)
358
+ )
359
+
360
+ assert source_length - 1 not in indices
361
+
362
+ if not self.iid_noise_target:
363
+ source = source[to_keep]
364
+ target = None
365
+ else:
366
+ ## Prepare source
367
+ source_mask_idx = (source == self.mask_idx).nonzero().view(-1)
368
+ source[source_mask_idx] = self.uni_mask_idxs[:source_mask_idx.size(0)]
369
+ source = source[to_keep]
370
+
371
+ ## Prepare target
372
+ to_keep[source_mask_idx] = 0
373
+
374
+ # source_mask_idx: from [a, b, c, ...] to [a, b + 1, c + 2, ...]
375
+ source_mask_idx = source_mask_idx + torch.arange(source_mask_idx.size(0))
376
+ # target: source_length + mask_length
377
+ target = source_ori.new_zeros(source_mask_idx.size(0) + source_ori.size(0))
378
+ # target: [0, 0, 0, X, 0, 0, Y, ....]
379
+ target[source_mask_idx] = self.uni_mask_idxs[:source_mask_idx.size(0)]
380
+
381
+ target_to_keep = to_keep.new_zeros(source_mask_idx.size(0) + source_ori.size(0))
382
+
383
+ # Copy original value to target and target_to_keep
384
+ target_to_keep[target == 0] = to_keep
385
+ target_to_keep[-1] = 0
386
+ target[target == 0] = source_ori
387
+
388
+ target = target[~target_to_keep]
389
+
390
+ if num_inserts > 0:
391
+ source = self.add_insertion_noise(source, num_inserts / source.size(0))
392
+
393
+ return source, target
394
+
395
+ def add_permuted_noise(self, tokens, p):
396
+ num_words = len(tokens)
397
+ num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
398
+ substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
399
+ tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
400
+ return tokens
401
+
402
+ def add_rolling_noise(self, tokens):
403
+ offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
404
+ tokens = torch.cat(
405
+ (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
406
+ dim=0,
407
+ )
408
+ return tokens
409
+
410
+ def add_insertion_noise(self, tokens, p):
411
+ if p == 0.0:
412
+ return tokens
413
+
414
+ num_tokens = len(tokens)
415
+ n = int(math.ceil(num_tokens * p))
416
+
417
+ noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
418
+ noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
419
+ noise_mask[noise_indices] = 1
420
+ result = torch.LongTensor(n + len(tokens)).fill_(-1)
421
+
422
+ num_random = int(math.ceil(n * self.random_ratio))
423
+ result[noise_indices[num_random:]] = self.mask_idx
424
+ result[noise_indices[:num_random]] = torch.randint(
425
+ low=1, high=len(self.vocab), size=(num_random,)
426
+ )
427
+
428
+ result[~noise_mask] = tokens
429
+
430
+ assert (result >= 0).all()
431
+ return result
432
+
433
+ def collater(self, samples, pad_to_length=None):
434
+ """Merge a list of samples to form a mini-batch.
435
+ Args:
436
+ samples (List[dict]): samples to collate
437
+ Returns:
438
+ dict: a mini-batch of data
439
+ """
440
+ return collate(
441
+ samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
442
+ )
443
+
444
+ def num_tokens(self, index):
445
+ """Return the number of tokens in a sample. This value is used to
446
+ enforce ``--max-tokens`` during batching."""
447
+ return self.sizes[index]
448
+
449
+ def size(self, index):
450
+ """Return an example's size as a float or tuple. This value is used when
451
+ filtering a dataset with ``--max-positions``."""
452
+ return self.sizes[index]
453
+
454
+ def ordered_indices(self):
455
+ """Return an ordered list of indices. Batches will be constructed based
456
+ on this order."""
457
+ if self.shuffle:
458
+ indices = np.random.permutation(len(self))
459
+ else:
460
+ indices = np.arange(len(self))
461
+ return indices[np.argsort(self.sizes[indices], kind="mergesort")]
462
+
463
+ def prefetch(self, indices):
464
+ self.src.prefetch(indices)
465
+ self.tgt.prefetch(indices)
466
+
467
+ @property
468
+ def supports_prefetch(self):
469
+ return (
470
+ hasattr(self.src, "supports_prefetch")
471
+ and self.src.supports_prefetch
472
+ and hasattr(self.tgt, "supports_prefetch")
473
+ and self.tgt.supports_prefetch
474
+ )
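Illustrative note (a minimal sketch, not part of the commit): for mask_length == "span-poisson", the constructor above tabulates a truncated Poisson distribution and later samples masked-span lengths from it; zero-length samples are treated as pure insertions. The lambda value below is hypothetical.

# Sketch of the truncated Poisson span-length distribution built in TextPretrainDataset.__init__
import math
import torch

_lambda = 3.5                                   # hypothetical args.poisson_lambda
lambda_to_the_k, k_factorial = 1.0, 1.0
e_to_the_minus_lambda = math.exp(-_lambda)
ps = []
for k in range(128):
    ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
    lambda_to_the_k *= _lambda
    k_factorial *= k + 1
    if ps[-1] < 1e-7:                           # truncate the tail
        break

dist = torch.distributions.Categorical(torch.FloatTensor(ps))
span_lengths = dist.sample(sample_shape=(10,))  # lengths in {0, 1, 2, ...}
print(span_lengths)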
artst/data/text_to_speech_dataset.py ADDED
@@ -0,0 +1,344 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import itertools
9
+ import logging
10
+ import os
11
+ from typing import Any, List, Optional
12
+ import mmap
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import librosa
19
+ from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform
20
+ from fairseq.data import data_utils, Dictionary
21
+ from fairseq.data.fairseq_dataset import FairseqDataset
22
+
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ def _collate_frames(
27
+ frames: List[torch.Tensor], is_audio_input: bool = False
28
+ ):
29
+ """
30
+ Convert a list of 2D frames into a padded 3D tensor
31
+ Args:
32
+ frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
33
+ length of i-th frame and f_dim is static dimension of features
34
+ Returns:
35
+ 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
36
+ """
37
+ max_len = max(frame.size(0) for frame in frames)
38
+ if is_audio_input:
39
+ out = frames[0].new_zeros((len(frames), max_len))
40
+ else:
41
+ out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
42
+ for i, v in enumerate(frames):
43
+ out[i, : v.size(0)] = v
44
+ return out
45
+
46
+ def load_audio(manifest_path, max_keep, min_keep):
47
+ n_long, n_short = 0, 0
48
+ names, inds, sizes, spk_embeds = [], [], [], []
49
+ with open(manifest_path) as f:
50
+ root = f.readline().strip()
51
+ for ind, line in enumerate(f):
52
+ items = line.strip().split("\t")
53
+ assert len(items) == 3, line
54
+ sz = int(items[1])
55
+ if min_keep is not None and sz < min_keep:
56
+ n_short += 1
57
+ elif max_keep is not None and sz > max_keep:
58
+ n_long += 1
59
+ else:
60
+ names.append(items[0])
61
+ spk_embeds.append(items[2])
62
+ inds.append(ind)
63
+ sizes.append(sz)
64
+ tot = ind + 1
65
+ logger.info(
66
+ (
67
+ f"max_keep={max_keep}, min_keep={min_keep}, "
68
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
69
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
70
+ )
71
+ )
72
+ return root, names, inds, tot, sizes, spk_embeds
73
+
74
+
75
+ def load_label(label_path, inds, tot):
76
+ with open(label_path) as f:
77
+ labels = [line.rstrip() for line in f]
78
+ assert (
79
+ len(labels) == tot
80
+ ), f"number of labels does not match ({len(labels)} != {tot})"
81
+ labels = [labels[i] for i in inds]
82
+ return labels
83
+
84
+
85
+ def load_label_offset(label_path, inds, tot):
86
+ with open(label_path, encoding='utf-8') as f:
87
+ code_lengths = [len(line.encode("utf-8")) for line in f] #changed as in speech_to_text_dataset.py
88
+ assert (
89
+ len(code_lengths) == tot
90
+ ), f"number of labels does not match ({len(code_lengths)} != {tot})"
91
+ offsets = list(itertools.accumulate([0] + code_lengths))
92
+ offsets = [(offsets[i], offsets[i + 1]) for i in inds]
93
+ return offsets
94
+
95
+
96
+ def logmelfilterbank(
97
+ audio,
98
+ sampling_rate,
99
+ fft_size=1024,
100
+ hop_size=256,
101
+ win_length=None,
102
+ window="hann",
103
+ num_mels=80,
104
+ fmin=80,
105
+ fmax=7600,
106
+ eps=1e-10,
107
+ ):
108
+ """Compute log-Mel filterbank feature.
109
+ (https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py)
110
+
111
+ Args:
112
+ audio (ndarray): Audio signal (T,).
113
+ sampling_rate (int): Sampling rate.
114
+ fft_size (int): FFT size.
115
+ hop_size (int): Hop size.
116
+ win_length (int): Window length. If set to None, it will be the same as fft_size.
117
+ window (str): Window function type.
118
+ num_mels (int): Number of mel basis.
119
+ fmin (int): Minimum frequency in mel basis calculation.
120
+ fmax (int): Maximum frequency in mel basis calculation.
121
+ eps (float): Epsilon value to avoid inf in log calculation.
122
+
123
+ Returns:
124
+ ndarray: Log Mel filterbank feature (#frames, num_mels).
125
+
126
+ """
127
+ # get amplitude spectrogram
128
+ x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
129
+ win_length=win_length, window=window, pad_mode="reflect")
130
+ spc = np.abs(x_stft).T # (#frames, #bins)
131
+
132
+ # get mel basis
133
+ fmin = 0 if fmin is None else fmin
134
+ fmax = sampling_rate / 2 if fmax is None else fmax
135
+ mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
136
+
137
+ return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
138
+
139
+
140
+
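# Usage sketch for logmelfilterbank above; the path is a placeholder and the audio
# is assumed to be mono 16 kHz so the defaults (fmin=80, fmax=7600) make sense.
import soundfile as sf
wav, sr = sf.read("/path/to/audio_16k.wav")      # hypothetical file
feats = logmelfilterbank(wav, sampling_rate=sr)  # ndarray of shape (n_frames, 80)
print(feats.shape)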
141
+ class TextToSpeechDataset(FairseqDataset):
142
+ def __init__(
143
+ self,
144
+ manifest_path: str,
145
+ sample_rate: float,
146
+ label_paths: List[str],
147
+ label_processors: Optional[List[Any]] = None,
148
+ max_keep_sample_size: Optional[int] = None,
149
+ min_keep_sample_size: Optional[int] = None,
150
+ shuffle: bool = True,
151
+ normalize: bool = False,
152
+ store_labels: bool = True,
153
+ src_dict: Optional[Dictionary] = None,
154
+ tokenizer = None,
155
+ reduction_factor: int = 1,
156
+ inference: bool = False,
157
+ ):
158
+
159
+ self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.spk_embeds = load_audio(
160
+ manifest_path, max_keep_sample_size, min_keep_sample_size
161
+ )
162
+ self.inference = inference
163
+
164
+ self.sample_rate = sample_rate
165
+ self.shuffle = shuffle
166
+ self.src_dict = src_dict
167
+ self.tokenizer = tokenizer
168
+
169
+ self.num_labels = len(label_paths)
170
+ self.label_processors = label_processors
171
+ self.store_labels = store_labels
172
+ if store_labels:
173
+ self.label_list = [load_label(p, inds, tot) for p in label_paths]
174
+ else:
175
+ self.label_paths = label_paths
176
+ self.label_offsets_list = [
177
+ load_label_offset(p, inds, tot) for p in label_paths
178
+ ]
179
+ assert label_processors is None or len(label_processors) == self.num_labels
180
+
181
+ self.normalize = normalize
182
+ self.reduction_factor = reduction_factor
183
+ logger.info(
184
+ f"reduction_factor={reduction_factor}, normalize={normalize}"
185
+ )
186
+
187
+ def get_audio(self, index):
188
+ import soundfile as sf
189
+
190
+ wav_path = os.path.join(self.audio_root, self.audio_names[index])
191
+ wav, cur_sample_rate = sf.read(wav_path)
192
+ wav = torch.from_numpy(wav).float()
193
+ fbank = logmelfilterbank(
194
+ wav.view(-1).cpu().numpy(), 16000
195
+ )
196
+ fbank = torch.from_numpy(fbank).float()
197
+ wav = self.postprocess(wav, cur_sample_rate)
198
+ return wav, fbank
199
+
200
+ def get_label(self, index, label_idx):
201
+ if self.store_labels:
202
+ label = self.label_list[label_idx][index]
203
+ else:
204
+ # with open(self.label_paths[label_idx]) as f:
205
+ # offset_s, offset_e = self.label_offsets_list[label_idx][index]
206
+ # f.seek(offset_s)
207
+ # label = f.read(offset_e - offset_s)
208
+
209
+ # Hawau:
210
+ # mmap method
211
+ with open(self.label_paths[label_idx], encoding='utf-8') as f:
212
+ with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
213
+ offset_s, offset_e = self.label_offsets_list[label_idx][index]
214
+ label = mm[offset_s:offset_e].decode("utf-8")
215
+
216
+
217
+ if self.tokenizer is not None:
218
+ label = self.tokenizer.encode(label)
219
+
220
+ if self.label_processors is not None:
221
+ label = self.label_processors[label_idx](label)
222
+ return label
223
+
224
+ def get_labels(self, index):
225
+ return [self.get_label(index, i) for i in range(self.num_labels)]
226
+
227
+ def __getitem__(self, index):
228
+ wav, fbank = self.get_audio(index)
229
+ labels = self.get_labels(index)
230
+ spkembs = get_features_or_waveform(
231
+ os.path.join(self.audio_root, self.spk_embeds[index])
232
+ )
233
+ spkembs = torch.from_numpy(spkembs).float()
234
+
235
+ return {"id": index, "source": labels, "target": fbank, "spkembs": spkembs, "audio_name": self.audio_names[index]}
236
+
237
+
238
+ def __len__(self):
239
+ return len(self.wav_sizes)
240
+
241
+ def collater(self, samples):
242
+ samples = [s for s in samples if s["source"] is not None]
243
+ if len(samples) == 0:
244
+ return {}
245
+
246
+ fbanks = [s["target"] for s in samples]
247
+ fbank_sizes = [len(s) for s in fbanks]
248
+
249
+ collated_fbanks = _collate_frames(fbanks)
250
+ collated_fbanks_size = torch.tensor(fbank_sizes, dtype=torch.long)
251
+
252
+ # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
253
+ if self.reduction_factor > 1:
254
+ collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor]
255
+ collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size])
256
+ else:
257
+ collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size
258
+
259
+ prev_output_tokens = torch.cat(
260
+ [collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1
261
+ )
262
+
263
+ # make labels for stop prediction
264
+ labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1))
265
+ for i, l in enumerate(fbank_sizes):
266
+ labels[i, l - 1 :] = 1.0
267
+
268
+ spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True)
269
+
270
+ sources_by_label = [
271
+ [s["source"][i] for s in samples] for i in range(self.num_labels)
272
+ ]
273
+ sources_list, lengths_list, ntokens_list = self.collater_label(sources_by_label)
274
+
275
+ net_input = {
276
+ "src_tokens": sources_list[0],
277
+ "src_lengths": lengths_list[0],
278
+ "prev_output_tokens": prev_output_tokens,
279
+ "tgt_lengths": collated_fbanks_size_in,
280
+ "spkembs": spkembs,
281
+ "task_name": "t2s",
282
+ }
283
+ batch = {
284
+ "id": torch.LongTensor([s["id"] for s in samples]),
285
+ "name": [s["audio_name"] for s in samples],
286
+ "net_input": net_input,
287
+ "labels": labels,
288
+ "dec_target": collated_fbanks,
289
+ "dec_target_lengths": collated_fbanks_size,
290
+ "src_lengths": lengths_list[0],
291
+ "task_name": "t2s",
292
+ "ntokens": ntokens_list[0],
293
+ "target": collated_fbanks,
294
+ }
295
+
296
+ return batch
297
+
298
+ def collater_seq_label(self, targets, pad):
299
+ lengths = torch.LongTensor([len(t) for t in targets])
300
+ ntokens = lengths.sum().item()
301
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
302
+ return targets, lengths, ntokens
303
+
304
+ def collater_label(self, targets_by_label):
305
+ targets_list, lengths_list, ntokens_list = [], [], []
306
+ itr = zip(targets_by_label, [self.src_dict.pad()])
307
+ for targets, pad in itr:
308
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
309
+ targets_list.append(targets)
310
+ lengths_list.append(lengths)
311
+ ntokens_list.append(ntokens)
312
+ return targets_list, lengths_list, ntokens_list
313
+
314
+ def num_tokens(self, index):
315
+ return self.size(index)
316
+
317
+ def size(self, index):
318
+ return self.wav_sizes[index]
319
+
320
+ @property
321
+ def sizes(self):
322
+ return np.array(self.wav_sizes)
323
+
324
+ def ordered_indices(self):
325
+ if self.shuffle:
326
+ order = [np.random.permutation(len(self))]
327
+ else:
328
+ order = [np.arange(len(self))]
329
+
330
+ order.append(self.wav_sizes)
331
+ return np.lexsort(order)[::-1]
332
+
333
+ def postprocess(self, wav, cur_sample_rate):
334
+ if wav.dim() == 2:
335
+ wav = wav.mean(-1)
336
+ assert wav.dim() == 1, wav.dim()
337
+
338
+ if cur_sample_rate != self.sample_rate:
339
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
340
+
341
+ if self.normalize:
342
+ with torch.no_grad():
343
+ wav = F.layer_norm(wav, wav.shape)
344
+ return wav
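As a concrete illustration of the collater above: the reduction factor keeps every r-th target frame, and the decoder input is that thinned sequence shifted right by one zero frame. A minimal sketch with made-up shapes (reduction_factor=2 and odim=1 are assumptions chosen for readability, not values from this repo):

import torch
r = 2
collated_fbanks = torch.arange(2 * 6 * 1, dtype=torch.float).view(2, 6, 1)  # (B, Lmax, odim)
fbanks_in = collated_fbanks[:, r - 1 :: r]                                  # (2, 3, 1)
# decoder input is the thinned target shifted right with a zero frame prepended
prev_output_tokens = torch.cat(
    [fbanks_in.new_zeros(2, 1, 1), fbanks_in[:, :-1]], dim=1
)
print(fbanks_in.shape, prev_output_tokens.shape)                            # both (2, 3, 1)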
artst/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .artst import * # noqa
2
+ from .t5_transformer_lm import * # noqa
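These wildcard imports exist mainly for their side effects: importing the modules runs the `@register_model(...)` decorators so fairseq can look the architectures up by name. A minimal loading sketch (the checkpoint path is a placeholder, and it assumes fairseq is installed and that the ArTST task module has also been imported so its registration runs):

from fairseq import checkpoint_utils
import artst.models  # noqa: F401 -- triggers @register_model("artst_transformer")

# Hypothetical checkpoint path; load_model_ensemble_and_task rebuilds the task and model.
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(["/path/to/checkpoint.pt"])
model = models[0].eval()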
artst/models/artst.py ADDED
@@ -0,0 +1,1448 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import logging
9
+ from ast import literal_eval
10
+ from typing import Dict, List, Optional, Tuple
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from fairseq import utils
15
+ from fairseq.models import (
16
+ FairseqEncoderDecoderModel,
17
+ FairseqIncrementalDecoder,
18
+ register_model,
19
+ register_model_architecture,
20
+ )
21
+ from .modules.text_encoder_prenet import TextEncoderPrenet
22
+ from .modules.text_decoder_prenet import TextDecoderPrenet
23
+ from .modules.text_decoder_postnet import TextDecoderPostnet
24
+ from .modules.speech_encoder_prenet import SpeechEncoderPrenet
25
+ from .modules.speech_encoder_postnet import SpeechEncoderPostnet
26
+ from .modules.speech_decoder_prenet import SpeechDecoderPrenet
27
+ from .modules.speech_decoder_postnet import SpeechDecoderPostnet
28
+ from .modules.speaker_decoder_postnet import SpeakerDecoderPostnet
29
+ from .modules.encoder import TransformerEncoder
30
+ from .modules.decoder import TransformerDecoder
31
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
32
+ from fairseq.models.transformer import Embedding
33
+ from fairseq.modules import (
34
+ GumbelVectorQuantizer,
35
+ )
36
+ from torch import Tensor
37
+
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+ DEFAULT_MAX_TEXT_POSITIONS = 450
42
+ DEFAULT_MAX_SPEECH_POSITIONS = 4000
43
+
44
+
45
+ @register_model("artst_transformer")
46
+ class ArTSTTransformerModel(FairseqEncoderDecoderModel):
47
+ """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
48
+ speech-to-text tasks. The Transformer encoder/decoder remains the same.
49
+ A trainable input subsampler is prepended to the Transformer encoder to
50
+ project inputs into the encoder dimension as well as downsample input
51
+ sequence for computational efficiency."""
52
+
53
+ def __init__(
54
+ self,
55
+ args,
56
+ encoder, decoder,
57
+ text_encoder_prenet, speech_encoder_prenet,
58
+ text_decoder_prenet, speech_decoder_prenet,
59
+ text_decoder_postnet, speech_decoder_postnet,
60
+ speaker_decoder_postnet, speech_encoder_postnet,
61
+ ):
62
+ super().__init__(encoder, decoder)
63
+
64
+ self.encoder = encoder
65
+ self.decoder = decoder
66
+
67
+ self.text_encoder_prenet = text_encoder_prenet
68
+ self.speech_encoder_prenet = speech_encoder_prenet
69
+
70
+ self.text_decoder_prenet = text_decoder_prenet
71
+ self.speech_decoder_prenet = speech_decoder_prenet
72
+
73
+ self.text_decoder_postnet = text_decoder_postnet
74
+ self.speech_decoder_postnet = speech_decoder_postnet
75
+ self.speaker_decoder_postnet = speaker_decoder_postnet
76
+
77
+ self.hubert_layer = speech_encoder_postnet
78
+
79
+ self.reduction_factor = args.reduction_factor
80
+ self.spk_embed_dim = args.spk_embed_dim
81
+
82
+ # define projection layer
83
+ self.spk_embed_integration_type = args.spk_embed_integration_type
84
+ if self.spk_embed_dim is not None and self.spk_embed_integration_type != 'pre':
85
+ if self.spk_embed_integration_type == "add":
86
+ self.projection = torch.nn.Linear(self.spk_embed_dim, args.decoder_embed_dim)
87
+ else:
88
+ self.projection = torch.nn.Linear(
89
+ args.decoder_embed_dim + self.spk_embed_dim, args.decoder_embed_dim
90
+ )
91
+
92
+ # Hawau: here we can add language embedding integration
93
+
94
+ self.use_codebook = args.use_codebook
95
+ self.codebook_prob = getattr(args, "codebook_prob", 0.5) # args.codebook_prob
96
+ if self.use_codebook:
97
+ vq_dim = args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
98
+ self.quantizer = GumbelVectorQuantizer(
99
+ dim=args.encoder_embed_dim,
100
+ num_vars=args.latent_vars,
101
+ temp=args.latent_temp,
102
+ groups=args.latent_groups,
103
+ combine_groups=False,
104
+ vq_dim=vq_dim,
105
+ time_first=True,
106
+ weight_proj_depth=args.quantizer_depth,
107
+ weight_proj_factor=args.quantizer_factor,
108
+ )
109
+
110
+ self.num_updates = 0
111
+
112
+ # # Follow BERT's random weight initialization (for BART)
113
+ if args.bert_init:
114
+ self.apply(init_bert_params)
115
+ self.args = args
116
+ self.prune_modules(args.modules_filter)
117
+
118
+ @staticmethod
119
+ def add_args(parser):
120
+ """Add model-specific arguments to the parser."""
121
+ # Transformer
122
+ parser.add_argument(
123
+ "--activation-fn",
124
+ type=str,
125
+ choices=utils.get_available_activation_fns(),
126
+ help="activation function to use",
127
+ )
128
+ parser.add_argument(
129
+ "--dropout", type=float, metavar="D", help="dropout probability"
130
+ )
131
+ parser.add_argument(
132
+ "--attention-dropout",
133
+ type=float,
134
+ metavar="D",
135
+ help="dropout probability for attention weights",
136
+ )
137
+ parser.add_argument(
138
+ "--activation-dropout",
139
+ "--relu-dropout",
140
+ type=float,
141
+ metavar="D",
142
+ help="dropout probability after activation in FFN.",
143
+ )
144
+ parser.add_argument(
145
+ "--encoder-embed-dim",
146
+ type=int,
147
+ metavar="N",
148
+ help="encoder embedding dimension",
149
+ )
150
+ parser.add_argument(
151
+ "--encoder-ffn-embed-dim",
152
+ type=int,
153
+ metavar="N",
154
+ help="encoder embedding dimension for FFN",
155
+ )
156
+ parser.add_argument(
157
+ "--encoder-layers", type=int, metavar="N", help="num encoder layers"
158
+ )
159
+ parser.add_argument(
160
+ "--encoder-attention-heads",
161
+ type=int,
162
+ metavar="N",
163
+ help="num encoder attention heads",
164
+ )
165
+ parser.add_argument(
166
+ "--encoder-normalize-before",
167
+ action="store_true",
168
+ help="apply layernorm before each encoder block",
169
+ )
170
+ parser.add_argument(
171
+ "--decoder-normalize-before",
172
+ action="store_true",
173
+ help="apply layernorm before each decoder block",
174
+ )
175
+ parser.add_argument(
176
+ "--decoder-embed-dim",
177
+ type=int,
178
+ metavar="N",
179
+ help="decoder embedding dimension",
180
+ )
181
+ parser.add_argument(
182
+ "--decoder-ffn-embed-dim",
183
+ type=int,
184
+ metavar="N",
185
+ help="decoder embedding dimension for FFN",
186
+ )
187
+ parser.add_argument(
188
+ "--decoder-layers", type=int, metavar="N", help="num decoder layers"
189
+ )
190
+ parser.add_argument(
191
+ "--decoder-attention-heads",
192
+ type=int,
193
+ metavar="N",
194
+ help="num decoder attention heads",
195
+ )
196
+ parser.add_argument(
197
+ "--reduction-factor",
198
+ type=int,
199
+ help="reduction factor for decoder",
200
+ )
201
+ parser.add_argument(
202
+ "--spk-embed-dim",
203
+ type=int,
204
+ help="speaker embedding dimension",
205
+ )
206
+ parser.add_argument(
207
+ "--layernorm-embedding",
208
+ action="store_true",
209
+ help="add layernorm to embedding",
210
+ )
211
+ parser.add_argument(
212
+ "--load-pretrained-encoder-from",
213
+ type=str,
214
+ metavar="STR",
215
+ help="model to take encoder weights from (for initialization)",
216
+ )
217
+ parser.add_argument(
218
+ '--freeze-encoder-updates',
219
+ type=int,
220
+ help='number of steps to freeze encoder before finetune'
221
+ )
222
+ parser.add_argument(
223
+ '--freeze-decoder-updates',
224
+ type=int,
225
+ help='number of steps to freeze decoder before finetune'
226
+ )
227
+ parser.add_argument(
228
+ '--no-freeze-encoder-layer',
229
+ type=str,
230
+ help='which encoder layer not freeze during finetune'
231
+ )
232
+ parser.add_argument(
233
+ "--share-input-output-embed",
234
+ action="store_true",
235
+ help="share decoder input and output embeddings",
236
+ )
237
+ parser.add_argument(
238
+ "--share-ctc-embed",
239
+ action="store_true",
240
+ help="share ctc embed and decoder embed",
241
+ )
242
+ parser.add_argument(
243
+ "--encoder-sliding-window-attn",
244
+ default=None,
245
+ type=int,
246
+ help="If not None but a even number, set sliding window attention to encoder's attn_mask, e.g., 4, 10, and 20",
247
+ )
248
+
249
+ # Convolutional subsampler
250
+ parser.add_argument(
251
+ "--encoder-speech-prenet",
252
+ default="conv",
253
+ type=str,
254
+ choices=["conv", "linear"],
255
+ help="The type of encoder speech prenet, e.g., conv or linear."
256
+ )
257
+ parser.add_argument(
258
+ "--conv-kernel-sizes",
259
+ default="5,5",
260
+ type=str,
261
+ help="The layer of convolution of encoder speech prenet."
262
+ )
263
+ parser.add_argument(
264
+ "--conv-channels",
265
+ default=1024,
266
+ type=int,
267
+ help="The channels of encoder speech prenet."
268
+ )
269
+ parser.add_argument(
270
+ "--subsample-stride",
271
+ default="2,2",
272
+ type=str,
273
+ help="The subsample stride for conv1dsubsample."
274
+ )
275
+ parser.add_argument(
276
+ "--spk-embed-integration-type",
277
+ type=str,
278
+ choices=["pre", "add"],
279
+ help="speaker embedding integration type"
280
+ )
281
+ parser.add_argument(
282
+ "--dprenet-dropout-rate",
283
+ default=0.5,
284
+ type=float,
285
+ help="The dropout rate of decoder speech prenet."
286
+ )
287
+
288
+ ## SE
289
+ parser.add_argument(
290
+ "--se-predict",
291
+ default=None,
292
+ choices=["masking", "target", "delta"],
293
+ help="If set, source speech inputs decoder to predict the masking/target/delta of corresponding inputs."
294
+ + "masking is [0, 1], target is predicted output, delta is difference between inputs and outputs",
295
+ )
296
+ parser.add_argument(
297
+ "--se-decoder-input",
298
+ type=str,
299
+ default="previous_target",
300
+ choices=["previous_target", "source"],
301
+ )
302
+
303
+ ## SID
304
+ parser.add_argument(
305
+ "--modules-filter",
306
+ default=None,
307
+ type=str,
308
+ help="Remove unused modules for, e.g., SID.",
309
+ )
310
+ parser.add_argument(
311
+ "--sid-pad-prenet",
312
+ action="store_true",
313
+ help="If set, the size of text dictionary is as small as for <pad> token.",
314
+ )
315
+ parser.add_argument(
316
+ "--encoder-attn-branch",
317
+ type=str,
318
+ default="identity,full",
319
+ help="encoder attention branch sliding window, e.g., 'identity,0,2,4,full'",
320
+ )
321
+ parser.add_argument(
322
+ "--encoder-block-branch",
323
+ type=str,
324
+ help="average the output of encoder, e.g., '4,5,6'",
325
+ )
326
+ parser.add_argument(
327
+ "--sid-encoder-cls",
328
+ default=None,
329
+ choices=["encoder"],
330
+ help="If set, add cls vector to the encoder input, e.g., constant vector.",
331
+ )
332
+ parser.add_argument(
333
+ "--sid-shuffle-encoder-input",
334
+ action="store_true",
335
+ help="If set, shuffle encoder input in time.",
336
+ )
337
+ parser.add_argument(
338
+ "--sid-decoder-speaker",
339
+ action="store_true",
340
+ help="If set, apply speaker decoder as transformer decoder.",
341
+ )
342
+ parser.add_argument(
343
+ "--sid-decoder-attn-dim",
344
+ default=128,
345
+ type=int,
346
+ help="Attention dimension in attensive statistics pooling of speaker decoder.",
347
+ )
348
+ parser.add_argument(
349
+ "--sid-t5-postnet",
350
+ action="store_true",
351
+ help="If set, apply TextDecoderPostnet as speaker classification.",
352
+ )
353
+ parser.add_argument(
354
+ "--sid-embed-dim",
355
+ default=128,
356
+ type=int,
357
+ help="Embedding dimension in speaker postnet for speaker identification if embed postnet.",
358
+ )
359
+ parser.add_argument(
360
+ "--sid-pooling-layer",
361
+ default="decoder",
362
+ type=str,
363
+ choices=["decoder-las", "decoder", "encoder", "encoder-cls", "encoder-speaker"],
364
+ help="The output of decoder or encoder uses as SID pooling layer over temporal dimension.",
365
+ )
366
+ parser.add_argument(
367
+ "--sid-no-pooling-bn",
368
+ action="store_true",
369
+ help="If set, not attention batchnorm.",
370
+ )
371
+ parser.add_argument(
372
+ "--sid-no-embed-postnet",
373
+ action="store_true",
374
+ help="If set, no layer between decoder output and classification layer.",
375
+ )
376
+ parser.add_argument(
377
+ "--sid-normalize-postnet",
378
+ action="store_true",
379
+ help="If set, normalize input and weight in postnet/classifier.",
380
+ )
381
+ parser.add_argument(
382
+ "--sid-softmax-type",
383
+ default="softmax",
384
+ choices=["softmax", "amsoftmax", "aamsoftmax"],
385
+ help="If using amsoftmax or aamsoftmax, the target should be given.",
386
+ )
387
+ parser.add_argument(
388
+ "--softmax-scale",
389
+ default=1.0,
390
+ type=float,
391
+ help="Scale for AMSoftmax or AAMSoftmax.",
392
+ )
393
+ parser.add_argument(
394
+ "--softmax-margin",
395
+ default=0.0,
396
+ type=float,
397
+ help="Margin for AMSoftmax or AAMSoftmax.",
398
+ )
399
+ parser.add_argument(
400
+ "--softmax-easy-margin",
401
+ action="store_true",
402
+ help="Enable easy margin for AAMSoftmax.",
403
+ )
404
+ parser.add_argument(
405
+ "--encoder-layerdrop",
406
+ type=float,
407
+ metavar="D",
408
+ help="LayerDrop probability for encoder",
409
+ )
410
+ parser.add_argument(
411
+ "--decoder-layerdrop",
412
+ type=float,
413
+ metavar="D",
414
+ help="LayerDrop probability for decoder",
415
+ )
416
+
417
+ ## Hubert
418
+ parser.add_argument(
419
+ '--feature-grad-mult',
420
+ type=float,
421
+ help='multiply feature extractor var grads by this'
422
+ )
423
+ parser.add_argument(
424
+ '--logit-temp',
425
+ type=float,
426
+ help='temperature to divide logits by'
427
+ )
428
+ parser.add_argument(
429
+ '--final-dim',
430
+ type=int,
431
+ help="project final representations and targets to this many "
432
+ "dimensions. set to encoder_embed_dim is <= 0"
433
+ )
434
+
435
+ # mask
436
+ parser.add_argument(
437
+ '--hubert-mask-length',
438
+ type=int,
439
+ help='mask length'
440
+ )
441
+ parser.add_argument(
442
+ '--mask-prob',
443
+ type=float,
444
+ help='probability of replacing a token with mask'
445
+ )
446
+ parser.add_argument(
447
+ "--mask-selection",
448
+ choices=["static", "uniform", "normal", "poisson"],
449
+ help="how to choose mask length",
450
+ )
451
+ parser.add_argument(
452
+ '--mask-other',
453
+ type=float,
454
+ help="secondary mask argument "
455
+ "(used for more complex distributions), "
456
+ "see help in compute_mask_indices"
457
+ )
458
+ parser.add_argument(
459
+ '--mask-min-space',
460
+ type=int,
461
+ help='min space between spans (if no overlap is enabled)'
462
+ )
463
+
464
+ # channel masking
465
+ parser.add_argument(
466
+ '--mask-channel-length',
467
+ type=int,
468
+ help='length of the mask for features (channels)'
469
+ )
470
+ parser.add_argument(
471
+ '--mask-channel-prob',
472
+ type=float,
473
+ help="probability of replacing a feature with 0"
474
+ )
475
+ parser.add_argument(
476
+ "--mask-channel-selection",
477
+ choices=["static", "uniform", "normal", "poisson"],
478
+ help="how to choose mask length for channel masking",
479
+ )
480
+ parser.add_argument(
481
+ '--mask-channel-other',
482
+ type=float,
483
+ help="secondary mask argument "
484
+ "(used for more complex distributions), "
485
+ "see help in compute_mask_indices"
486
+ )
487
+ parser.add_argument(
488
+ '--mask-channel-min-space',
489
+ type=int,
490
+ help='min space between spans (if no overlap is enabled)'
491
+ )
492
+
493
+ # abs positional embeddings
494
+ parser.add_argument(
495
+ '--conv-pos',
496
+ type=int,
497
+ help='number of filters for convolutional positional embeddings'
498
+ )
499
+ parser.add_argument(
500
+ '--conv-pos-groups',
501
+ type=int,
502
+ help='number of groups for convolutional positional embedding'
503
+ )
504
+
505
+ # codebook related
506
+ parser.add_argument(
507
+ "--use-codebook",
508
+ action="store_true",
509
+ help="whether to use codebook",
510
+ )
511
+ parser.add_argument(
512
+ "--codebook-prob",
513
+ type=float,
514
+ help="probability to use codebook",
515
+ )
516
+ parser.add_argument(
517
+ "--latent-vars",
518
+ type=int,
519
+ help="number of latent variables V in each group of the codebook",
520
+ )
521
+ parser.add_argument(
522
+ "--latent-groups",
523
+ type=int,
524
+ help="number of groups G of latent variables in the codebook",
525
+ )
526
+ parser.add_argument(
527
+ "--latent-dim",
528
+ type=int,
529
+ help="if > 0, uses this dimensionality for latent variables. "
530
+ "otherwise uses final_dim / latent_groups",
531
+ )
532
+ parser.add_argument(
533
+ "--latent-temp",
534
+ type=literal_eval,
535
+ help="temperature for latent variable sampling. "
536
+ "can be tuple of 3 values (start, end, decay)",
537
+ )
538
+ parser.add_argument(
539
+ "--quantizer-depth",
540
+ type=int,
541
+ help="number of quantizer layers",
542
+ )
543
+ parser.add_argument(
544
+ "--quantizer-factor",
545
+ type=int,
546
+ help="number of quantizer layers",
547
+ )
548
+ parser.add_argument(
549
+ "--get-code-distribution",
550
+ action='store_true',
551
+ help="whether to get the code distribution (for test)",
552
+ )
553
+
554
+ # relative pos enc
555
+ parser.add_argument(
556
+ "--relative-position-embedding",
557
+ action='store_true',
558
+ help="whether to use relative position embedding",
559
+ )
560
+ parser.add_argument(
561
+ "--num-buckets",
562
+ type=int,
563
+ default=320,
564
+ help="num of buckets for relative position embedding",
565
+ )
566
+ parser.add_argument(
567
+ "--max-distance",
568
+ type=int,
569
+ default=1280,
570
+ help="max distance for relative position embedding",
571
+ )
572
+ parser.add_argument(
573
+ "--encoder-max-relative-position",
574
+ type=int,
575
+ help="max distance for relative position embedding in encoder",
576
+ )
577
+ parser.add_argument(
578
+ "--decoder-max-relative-position",
579
+ type=int,
580
+ help="max distance for relative position embedding in decoder",
581
+ )
582
+
583
+ # hubert feature extractor
584
+ parser.add_argument(
585
+ "--conv-feature-layers",
586
+ type=str,
587
+ help= "string describing convolutional feature extraction "
588
+ "layers in form of a python list that contains "
589
+ "[(dim, kernel_size, stride), ...]",
590
+ )
591
+ parser.add_argument(
592
+ "--conv-bias",
593
+ action='store_true',
594
+ help="include bias in conv encoder",
595
+ )
596
+ parser.add_argument(
597
+ "--extractor-mode",
598
+ choices=["default", "layer_norm"],
599
+ help="mode for feature extractor. default has a single group "
600
+ "norm with d groups in the first conv block, whereas layer_norm "
601
+ "has layer norms in every block (meant to use with normalize=True)"
602
+ )
603
+
604
+ # others
605
+ parser.add_argument(
606
+ "--bert-init",
607
+ action='store_true',
608
+ help="initilize as bert",
609
+ )
610
+ parser.add_argument(
611
+ "--unb-enc-layer",
612
+ type=int,
613
+ default=-1,
614
+ help="which layer's output is used as the input of decoder",
615
+ )
616
+
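# Note on --latent-temp above: type=literal_eval turns the CLI string into a real
# Python tuple. A self-contained sketch (the temperature values are illustrative only):
import argparse
from ast import literal_eval
p = argparse.ArgumentParser()
p.add_argument("--latent-temp", type=literal_eval)
ns = p.parse_args(["--latent-temp", "(2.0, 0.5, 0.999995)"])
assert ns.latent_temp == (2.0, 0.5, 0.999995)    # (start, end, decay)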
617
+ # Encoder, Decoder
618
+ @classmethod
619
+ def build_encoder(cls, args, dictionary=None, embed_tokens=None):
620
+ return TransformerEncoder(args, dictionary, embed_tokens)
621
+
622
+ @classmethod
623
+ def build_decoder(cls, args):
624
+ return TransformerDecoder(args)
625
+
626
+ # Encoder Prenet
627
+ @classmethod
628
+ def build_text_encoder_prenet(cls, embed_tokens, args):
629
+ return TextEncoderPrenet(embed_tokens, args)
630
+
631
+ @classmethod
632
+ def build_speech_encoder_prenet(cls, args):
633
+ return SpeechEncoderPrenet(args)
634
+
635
+ # Decoder Prenet
636
+ @classmethod
637
+ def build_text_decoder_prenet(cls, embed_tokens, args):
638
+ return TextDecoderPrenet(embed_tokens, args)
639
+
640
+ @classmethod
641
+ def build_speech_decoder_prenet(cls, odim, args):
642
+ return SpeechDecoderPrenet(odim, args)
643
+
644
+ # Decoder Postnet
645
+ @classmethod
646
+ def build_text_decoder_postnet(cls, embed_tokens, dictionary, args):
647
+ return TextDecoderPostnet(embed_tokens, dictionary, args)
648
+
649
+ @classmethod
650
+ def build_speaker_decoder_postnet(cls, embed_dim, class_num, args):
651
+ return SpeakerDecoderPostnet(embed_dim, class_num, args)
652
+
653
+ @classmethod
654
+ def build_speech_decoder_postnet(cls, odim, args):
655
+ return SpeechDecoderPostnet(odim, args)
656
+
657
+ @classmethod
658
+ def build_speech_encoder_postnet(cls, dictionaries, args):
659
+ return SpeechEncoderPostnet(dictionaries, args)
660
+
661
+ @classmethod
662
+ def build_model(cls, args, task):
663
+ """Build a new model instance."""
664
+
665
+ # make sure all arguments are present in older models
666
+ base_architecture(args)
667
+
668
+ def build_embedding(dictionary, embed_dim, max_num_embeddings=None):
669
+ num_embeddings = len(dictionary)
670
+ if max_num_embeddings is not None and isinstance(max_num_embeddings, int):
671
+ num_embeddings = min(num_embeddings, max_num_embeddings)
672
+ padding_idx = dictionary.pad()
673
+ return Embedding(num_embeddings, embed_dim, padding_idx)
674
+
675
+ if hasattr(args, "sid_pad_prenet") and args.sid_pad_prenet:
676
+ max_num_embeddings = 3 # <pad> at index 2
677
+ else:
678
+ max_num_embeddings = None
679
+
680
+ text_decoder_embed_tokens = build_embedding(
681
+ task.dicts["text"], args.decoder_embed_dim, max_num_embeddings
682
+ )
683
+
684
+ if args.share_input_output_embed:
685
+ text_encoder_embed_tokens = text_decoder_embed_tokens
686
+ else:
687
+ text_encoder_embed_tokens = build_embedding(
688
+ task.dicts["text"], args.encoder_embed_dim
689
+ )
690
+
691
+ speech_odim = args.speech_odim
692
+ if "text" in task.dicts:
693
+ encoder = cls.build_encoder(args, task.dicts["text"], text_encoder_embed_tokens)
694
+ else:
695
+ encoder = cls.build_encoder(args)
696
+ decoder = cls.build_decoder(args)
697
+
698
+ text_encoder_prenet = cls.build_text_encoder_prenet(text_encoder_embed_tokens, args)
699
+ speech_encoder_prenet = cls.build_speech_encoder_prenet(args)
700
+
701
+ text_decoder_prenet = cls.build_text_decoder_prenet(text_decoder_embed_tokens, args)
702
+ if getattr(args, "sid_pooling_layer", None) == "decoder-las":
703
+ speech_decoder_prenet = cls.build_speech_encoder_prenet(args)
704
+ else:
705
+ speech_decoder_prenet = cls.build_speech_decoder_prenet(speech_odim, args)
706
+
707
+ text_decoder_postnet = cls.build_text_decoder_postnet(text_decoder_embed_tokens, task.dicts['text'], args)
708
+ speech_decoder_postnet = cls.build_speech_decoder_postnet(speech_odim, args)
709
+
710
+ if getattr(args, "sid_t5_postnet", False):
711
+ speaker_decoder_postnet = None
712
+ else:
713
+ if task.t5_task == "s2c":
714
+ speaker_decoder_postnet = cls.build_speaker_decoder_postnet(args.sid_embed_dim, len(task.dicts['text']), args)
715
+ else:
716
+ speaker_decoder_postnet = None
717
+
718
+ if "hubert" in task.dicts:
719
+ speech_encoder_postnet = cls.build_speech_encoder_postnet(task.dicts['hubert'], args)
720
+ else:
721
+ speech_encoder_postnet = None
722
+
723
+ return cls(
724
+ args,
725
+ encoder, decoder,
726
+ text_encoder_prenet, speech_encoder_prenet,
727
+ text_decoder_prenet, speech_decoder_prenet,
728
+ text_decoder_postnet, speech_decoder_postnet,
729
+ speaker_decoder_postnet, speech_encoder_postnet,
730
+ )
731
+
732
+ def get_normalized_probs(
733
+ self,
734
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
735
+ log_probs: bool,
736
+ sample: Optional[Dict[str, Tensor]] = None,
737
+ ):
738
+ # net_output['encoder_out'] is a (B, T, D) tensor
739
+ lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
740
+ lprobs.batch_first = True
741
+ return lprobs
742
+
743
+ def get_normalized_probs_for_ctc(self, net_output, log_probs):
744
+ """Get normalized probabilities (or log probs) from a net's output."""
745
+
746
+ logits = net_output["encoder_out_for_ctc"][0]
747
+ if log_probs:
748
+ return utils.log_softmax(logits.float(), dim=-1)
749
+ else:
750
+ return utils.softmax(logits.float(), dim=-1)
751
+
752
+ def get_logits(self, net_output, is_masked=True):
753
+ if is_masked:
754
+ logits_list = net_output["logit_m_list"]
755
+ else:
756
+ logits_list = net_output["logit_u_list"]
757
+ logits_list = [x.float() for x in logits_list if x is not None]
758
+ return logits_list
759
+
760
+ def get_targets(self, sample, net_output, is_masked=True):
761
+ if "logit_m_list" in net_output:
762
+ logits_list = self.get_logits(net_output, is_masked)
763
+ targets_list = [
764
+ x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list
765
+ ]
766
+ return targets_list
767
+ else:
768
+ return sample["target"]
769
+
770
+ def get_extra_losses(self, net_output):
771
+ extra_losses = []
772
+ names = []
773
+
774
+ if "features_pen" in net_output:
775
+ extra_losses.append(net_output["features_pen"])
776
+ names.append("features_pen")
777
+
778
+ if "prob_perplexity" in net_output:
779
+ extra_losses.append(
780
+ (net_output["num_vars"] - net_output["prob_perplexity"])
781
+ / net_output["num_vars"]
782
+ )
783
+ names.append("prob_perplexity")
784
+
785
+ return extra_losses, names
786
+
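# Worked example of the diversity term returned above: with num_vars = 320 * 2
# codebook entries (an illustrative figure) and a measured prob_perplexity of 400,
# the penalty is (640 - 400) / 640 = 0.375; it shrinks toward 0 as codebook usage
# spreads out. In practice num_vars and prob_perplexity come from the quantizer output.
num_vars, prob_perplexity = 640.0, 400.0
diversity_penalty = (num_vars - prob_perplexity) / num_vars
assert abs(diversity_penalty - 0.375) < 1e-9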
787
+ def forward(self, source=None, src_tokens=None, src_lengths=None, prev_output_tokens=None, tgt_lengths=None, spkembs=None, target_list=None, task_name=None, padding_mask=None, only_hubert=False, only_ctc=False, feature_only=False, tgt_enc_layer=None, mask=True):
788
+ """
789
+ The forward method inherited from the base class has a **kwargs
790
+ argument in its input, which is not supported in torchscript. This
791
+ method overwrites the forward method definition without **kwargs.
792
+ """
793
+ assert source is not None or src_tokens is not None
794
+ # padding_mask is not none only when input is waveform
795
+ if source is None and padding_mask is None and not feature_only:
796
+ input_type = 'text'
797
+ else:
798
+ input_type = 'speech'
799
+
800
+ if prev_output_tokens is not None and len(prev_output_tokens.size()) == 2:
801
+ output_type = 'text'
802
+ codebook_out = {}
803
+ else:
804
+ output_type = 'speech'
805
+
806
+ if task_name is not None and task_name == "s2c":
807
+ if target_list is not None and target_list.size(1) == 1 and not getattr(self.args, "sid_t5_postnet", False):
808
+ sid_target = F.one_hot(target_list.squeeze(1), num_classes=self.speaker_decoder_postnet.class_num)
809
+ else:
810
+ sid_target = None
811
+ target_list = None
812
+
813
+ # Encoder Prenet
814
+ if input_type == 'text':
815
+ encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens)
816
+ else:
817
+ if target_list is not None:
818
+ encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, require_feat_pen=True, target_list=target_list, padding_mask=padding_mask, mask=mask)
819
+ encoder_input, features_pen, mask_indices, target_list = encoder_input
820
+ else:
821
+ encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=self.training)
822
+ # shuffle a batch of inputs of encoder
823
+ if self.training and hasattr(self.args, "sid_shuffle_encoder_input") and getattr(self.args, "sid_shuffle_encoder_input", False):
824
+ shuffle_index = torch.randperm(encoder_padding_mask.size(1), device=encoder_padding_mask.device)
825
+ encoder_input = torch.index_select(encoder_input, 1, shuffle_index)
826
+ encoder_padding_mask = torch.index_select(encoder_padding_mask, 1, shuffle_index)
827
+ if getattr(self.args, "sid_encoder_cls", None) == "encoder":
828
+ prev_output_tokens = torch.zeros_like(prev_output_tokens)
829
+ encoder_input, encoder_padding_mask = self._integrate_with_speaker_cls(prev_output_tokens, encoder_input, encoder_padding_mask)
830
+
831
+ # Encoder: T x B x C
832
+ encoder_output = self.encoder(encoder_input, encoder_padding_mask, tgt_layer=tgt_enc_layer)
833
+
834
+ if task_name is not None and task_name == 'speech_pretrain' and feature_only:
835
+ return encoder_output["encoder_out"][0].transpose(0, 1)
836
+
837
+ if task_name is not None and task_name == 's2c':
838
+ if self.args.sid_pooling_layer == "encoder":
839
+ return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1).mean(1), sid_target), None
840
+ elif self.args.sid_pooling_layer == "encoder-cls":
841
+ return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1)[:,0], sid_target), None
842
+ elif self.args.sid_pooling_layer == "encoder-speaker" or getattr(self.args, "sid_decoder_speaker", False):
843
+ return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1), sid_target), None
844
+
845
+ if target_list is not None:
846
+ hubert_results = self.hubert_layer(
847
+ encoder_output["encoder_out"][0].transpose(0, 1),
848
+ encoder_padding_mask,
849
+ mask_indices,
850
+ target_list
851
+ )
852
+
853
+ hubert_results['features_pen'] = features_pen
854
+
855
+ if "decoder_input" in encoder_output and encoder_output["decoder_input"][0] is not None:
856
+ # Change the encoder output to decoder input once set unb-enc-layer
857
+ encoder_output["encoder_out"] = encoder_output["decoder_input"]
858
+
859
+ if self.use_codebook:
860
+ q = self.quantizer(encoder_output["encoder_out"][0].transpose(0, 1))
861
+
862
+ # q["x"]: B x T x C
863
+ # Sample indices according to the codebook prob
864
+ random_idx = torch.randperm(q["x"].size(1))[:int(q["x"].size(1) * self.codebook_prob)]
865
+ # Make weight for q
866
+ q_w = q["x"].new_zeros(q["x"].size(1))
867
+ q_w[random_idx] = 1.0
868
+ # Combine quantized codes and encoder output
869
+ encoder_output["encoder_out"][0] = (
870
+ q_w.view(-1, 1) * q["x"] + (- q_w + 1).view(-1, 1) * encoder_output["encoder_out"][0].transpose(0, 1)
871
+ ).transpose(0, 1)
872
+
873
+ # encoder_output["encoder_out"][0] = q["x"].transpose(0, 1)
874
+ if output_type == 'speech':
875
+ hubert_results["prob_perplexity"] = q["prob_perplexity"]
876
+ hubert_results["code_perplexity"] = q["code_perplexity"]
877
+ hubert_results["num_vars"] = q["num_vars"]
878
+ hubert_results["temp"] = q["temp"]
879
+ elif output_type == 'text':
880
+ codebook_out["prob_perplexity"] = q["prob_perplexity"]
881
+ codebook_out["code_perplexity"] = q["code_perplexity"]
882
+ codebook_out["num_vars"] = q["num_vars"]
883
+ codebook_out["temp"] = q["temp"]
884
+
885
+ if only_hubert and target_list is not None:
886
+ return hubert_results, None
887
+
888
+ if only_ctc and task_name is not None and task_name == "s2t":
889
+ return None, encoder_output
890
+ elif not self.training and prev_output_tokens is None and task_name == "s2t" and task_name is not None:
891
+ return encoder_output
892
+
893
+ # Decoder Prenet
894
+ if output_type == 'text':
895
+ # _ is the incremental state
896
+ prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens)
897
+ if task_name is not None and task_name == 's2c':
898
+ prev_output_tokens = torch.zeros_like(prev_output_tokens)
899
+ else:
900
+ # integrate speaker embedding
901
+ if self.spk_embed_integration_type == "pre" and self.spk_embed_dim is not None:
902
+ # Decoder Prenet
903
+ prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths, spkembs)
904
+ else:
905
+ if self.spk_embed_dim is not None:
906
+ encoder_output["encoder_out"] = [self._integrate_with_spk_embed(
907
+ encoder_output["encoder_out"][0].transpose(0, 1), spkembs
908
+ ).transpose(0, 1)]
909
+
910
+ prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths)
911
+
912
+ # BART Sequence Classification: cat <pad> + feature before decoder
913
+ if task_name is not None and task_name == 's2c' and self.args.sid_pooling_layer == "decoder-las":
914
+ decoder_feat_input, decoder_feat_mask = self.speech_decoder_prenet(src_tokens, src_lengths)
915
+ prev_output_tokens, tgt_mask = self._integrate_with_speaker_cls((prev_output_tokens, tgt_mask), decoder_feat_input, decoder_feat_mask, cls_first=False)
916
+
917
+ # SE predict masking to corresponding inputs and source speech replaces the prev_output_tokens as the input of decoder
918
+ if task_name is not None and task_name == "s2s" and getattr(self.args, "se_decoder_input", "previous_target") == "source":
919
+ prev_output_tokens, tgt_mask = self.speech_decoder_prenet(src_tokens, src_lengths)
920
+
921
+ # Decoder
922
+ decoder_output, extra = self.decoder(prev_output_tokens, tgt_mask, encoder_output,
923
+ full_context_alignment=getattr(self.args, "decoder_full_context_alignment", False),
924
+ alignment_layer=(-1 if target_list is None and output_type == 'speech' else None))
925
+ # Decoder Postnet
926
+ if task_name is not None and task_name == 's2c':
927
+ if not getattr(self.args, "sid_t5_postnet", False):
928
+ if self.args.sid_pooling_layer == "decoder":
929
+ return self.speaker_decoder_postnet(decoder_output.mean(1), sid_target), None
930
+ elif self.args.sid_pooling_layer == "decoder-las":
931
+ indices = (tgt_mask.eq(False).float().sum(1) - 1.0).type(torch.int64)
932
+ indices = indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, decoder_output.size(2))
933
+ return self.speaker_decoder_postnet(decoder_output.gather(1, indices), sid_target), None
934
+ else:
935
+ return (self.text_decoder_postnet(decoder_output), None), encoder_output
936
+
937
+ # SE predict: masking, target, delta. Ensure reduction factor 1
938
+ if task_name is not None and task_name == 's2s' and getattr(self.args, "se_predict", None) is not None:
939
+ assert self.reduction_factor == 1, f"{self.reduction_factor} != 1"
940
+ before_outs, after_outs, logits = self.speech_decoder_postnet(decoder_output)
941
+ se_predict = getattr(self.args, "se_predict")
942
+ if se_predict == "masking":
943
+ before_outs = torch.sigmoid(before_outs) * src_tokens
944
+ after_outs = torch.sigmoid(after_outs) * src_tokens
945
+ return before_outs, after_outs, logits, extra['attn'][0]
946
+ elif se_predict == "target":
947
+ return before_outs, after_outs, logits, extra['attn'][0]
948
+ elif se_predict == "delta":
949
+ before_outs = before_outs - src_tokens
950
+ after_outs = after_outs - src_tokens
951
+ return before_outs, after_outs, logits, extra['attn'][0]
952
+ else:
953
+ raise ValueError(f"{se_predict} not in [masking, target, delta]")
954
+
955
+ if task_name is not None and task_name == 's2t':
956
+ #return self.text_decoder_postnet(decoder_output), None
957
+ return (self.text_decoder_postnet(decoder_output), None), encoder_output
958
+ if output_type == 'text':
959
+ return (self.text_decoder_postnet(decoder_output), None), codebook_out, encoder_output
960
+ else:
961
+ if target_list is not None:
962
+ return hubert_results, (self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],))
963
+ else:
964
+ return self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],)
965
+
966
+ def _integrate_with_speaker_cls(self, pad_input, encoder_input, encoder_padding_mask=None, cls_first=True):
967
+ """
968
+ encoder_input: [B, T, C]
969
+ encoder_padding_mask: [B, T]
970
+ """
971
+ if hasattr(self, "text_decoder_prenet"):
972
+ if isinstance(pad_input, tuple):
973
+ repeat_cls_vector, repeat_cls_mask = pad_input
974
+ else:
975
+ repeat_cls_vector, repeat_cls_mask, _ = self.text_decoder_prenet(pad_input)
976
+
977
+ if encoder_padding_mask is not None:
978
+ bsz = encoder_input.size(0)
979
+ tsz = encoder_input.size(1)
980
+ encoder_padding_mask = encoder_input.new_zeros((bsz, tsz)) == 1.0
981
+ if repeat_cls_mask is None:
982
+ mask_size = (encoder_padding_mask.size(0), 1)
983
+ mask_type = encoder_padding_mask.dtype
984
+ repeat_cls_mask = encoder_padding_mask.new_zeros(mask_size) == 1.0
985
+ ret_encoder_padding_mask = torch.cat([repeat_cls_mask, encoder_padding_mask], dim=1)
986
+
987
+ if cls_first:
988
+ ret_encoder_input = torch.cat([repeat_cls_vector, encoder_input], dim=1)
989
+ else:
990
+ ret_encoder_input = torch.cat([encoder_input, encoder_input[:,-1:,:]], dim=1)
991
+ mask_size = (encoder_padding_mask.size(0), 1)
992
+ mask_type = encoder_padding_mask.dtype
993
+ repeat_cls_mask_ = encoder_padding_mask.new_ones(mask_size) == 1.0
994
+ encoder_padding_mask_ = torch.cat([encoder_padding_mask, repeat_cls_mask_], dim=1)
995
+ indices = encoder_padding_mask.eq(False).float().sum(1).type(torch.int64).unsqueeze(1)
996
+ indices_mask = torch.zeros_like(ret_encoder_padding_mask).scatter(1, indices, 1.0)
997
+ ret_encoder_input = ret_encoder_input * (1.0 - encoder_padding_mask_.type(ret_encoder_input.dtype).unsqueeze(2)) \
998
+ + repeat_cls_vector * indices_mask.type(repeat_cls_vector.dtype).unsqueeze(2)
999
+
1000
+ return ret_encoder_input, ret_encoder_padding_mask
1001
+
1002
+ def _integrate_with_spk_embed(self, hs, spembs):
1003
+ """Integrate speaker embedding with hidden states.
1004
+ Args:
1005
+ hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
1006
+ spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).
1007
+ Returns:
1008
+ Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)
1009
+ """
1010
+ if self.spk_embed_integration_type == "add":
1011
+ # apply projection and then add to hidden states
1012
+ spembs = self.projection(F.normalize(spembs))
1013
+ hs = hs + spembs.unsqueeze(1)
1014
+ elif self.spk_embed_integration_type == "concat":
1015
+ # concat hidden states with spk embeds and then apply projection
1016
+ spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
1017
+ hs = self.projection(torch.cat([hs, spembs], dim=-1))
1018
+ else:
1019
+ raise NotImplementedError("support only add or concat.")
1020
+
1021
+ return hs
1022
+
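# Shape sketch of the "concat" branch above (dimensions are hypothetical:
# decoder_embed_dim=768, spk_embed_dim=512; this is not code from the model).
import torch
import torch.nn.functional as F
hs = torch.randn(2, 50, 768)                     # (B, Tmax, adim)
spembs = torch.randn(2, 512)                     # (B, spk_embed_dim)
projection = torch.nn.Linear(768 + 512, 768)
sp = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
out = projection(torch.cat([hs, sp], dim=-1))    # back to (2, 50, 768)
assert out.shape == hs.shape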
1023
+ def load_state_dict(
1024
+ self,
1025
+ state_dict,
1026
+ strict=True,
1027
+ model_cfg=None,
1028
+ args=None,
1029
+ ):
1030
+ """NOT STRICT Copies parameters and buffers from *state_dict* into this module and
1031
+ its descendants.
1032
+
1033
+ Overrides the method in :class:`nn.Module`. Compared with that method
1034
+ this additionally "upgrades" *state_dicts* from old checkpoints.
1035
+ """
1036
+ # self.prune_modules(model_cfg.modules_filter)
1037
+ model_dict_size = self.text_decoder_postnet.output_projection.out_features
1038
+ ckpt_dict_size = state_dict["text_decoder_postnet.output_projection.weight"].size(0)
1039
+ if model_dict_size != ckpt_dict_size:
1040
+ # reset dictionary-related modules, such as embedding table and encoder ctc embed
1041
+ logger.warn(f"not equal dictionary between model and checkpoint: {model_dict_size} vs {ckpt_dict_size}")
1042
+ logger.info(f"reset model dictionary with size of {model_dict_size}")
1043
+ removed_keys = [
1044
+ key for key in state_dict.keys() if any(
1045
+ key.startswith(previ) for previ in [
1046
+ "encoder.proj", "text_encoder_prenet", "text_decoder_prenet", "text_decoder_postnet"
1047
+ ]
1048
+ )
1049
+ ]
1050
+ for key in removed_keys:
1051
+ state_dict.pop(key, None)
1052
+ logger.info(f"removed loaded checkpoint: {key}")
1053
+ for m in self._modules.keys():
1054
+ m_state_dict = {
1055
+ key.replace(f"{m}.", ""): value for key, value in state_dict.items() if key.startswith(f"{m}.")
1056
+ }
1057
+ if hasattr(self, m):
1058
+ self._modules[m].load_state_dict(m_state_dict, False)
1059
+ return self
1060
+
1061
+ def prune_modules(self, modules_filter=None):
1062
+ """Prune unused modules for specific tasks."""
1063
+ if modules_filter is None:
1064
+ return
1065
+ elif modules_filter == "s2c":
1066
+ if hasattr(self, "text_encoder_prenet"): del self.text_encoder_prenet
1067
+ if hasattr(self, "speech_decoder_prenet") and getattr(self.args, "sid_pooling_layer", None) != "decoder-las":
1068
+ del self.speech_decoder_prenet
1069
+ if hasattr(self, "speech_decoder_postnet"): del self.speech_decoder_postnet
1070
+ if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
1071
+ if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
1072
+ if hasattr(self.encoder, "proj"): self.encoder.proj = None
1073
+ if hasattr(self, "projection"): del self.projection
1074
+ if hasattr(self, "quantizer"): del self.quantizer
1075
+ if getattr(self.args, "sid_pooling_layer", "decoder").startswith("encoder") or getattr(self.args, "sid_decoder_speaker", False):
1076
+ if hasattr(self.decoder, "dropout_module"): del self.decoder.dropout_module
1077
+ if hasattr(self.decoder, "layers"): del self.decoder.layers
1078
+ if hasattr(self.decoder, "layer_norm"): del self.decoder.layer_norm
1079
+ if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
1080
+ elif modules_filter == "s2s":
1081
+ if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
1082
+ if hasattr(self, "text_encoder_prenet"): del self.text_encoder_prenet
1083
+ if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
1084
+ if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
1085
+ if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
1086
+ if hasattr(self.encoder, "proj"): self.encoder.proj = None
1087
+ if hasattr(self, "projection"): del self.projection
1088
+ if hasattr(self, "quantizer"): del self.quantizer
1089
+ elif modules_filter == "t2s":
1090
+ if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
1091
+ if hasattr(self, "speech_encoder_prenet"): del self.speech_encoder_prenet
1092
+ if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
1093
+ if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
1094
+ if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
1095
+ if hasattr(self.encoder, "proj"): self.encoder.proj = None
1096
+ if hasattr(self, "projection"): del self.projection
1097
+ if hasattr(self, "quantizer"): del self.quantizer
1098
+ elif modules_filter == "s3prl":
1099
+ # remain the encoder and the pre/post net
1100
+ if hasattr(self.decoder, "dropout_module"): del self.decoder.dropout_module
1101
+ if hasattr(self.decoder, "layers"): del self.decoder.layers
1102
+ if hasattr(self.decoder, "layer_norm"): del self.decoder.layer_norm
1103
+ if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
1104
+ if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
1105
+ if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
1106
+ if hasattr(self, "speech_decoder_prenet"): del self.speech_decoder_prenet
1107
+ if hasattr(self, "speech_decoder_postnet"): del self.speech_decoder_postnet
1108
+ if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
1109
+ if hasattr(self.encoder, "proj"): self.encoder.proj = None
1110
+ if hasattr(self, "projection"): del self.projection
1111
+ if hasattr(self, "quantizer"): del self.quantizer
1112
+
1113
+ def forward_encoder_torchscript(self, net_input: Dict[str, Tensor]):
1114
+ """A TorchScript-compatible version of forward.
1115
+
1116
+ Encoders which use additional arguments may want to override
1117
+ this method for TorchScript compatibility.
1118
+ """
1119
+ if torch.jit.is_scripting():
1120
+ return self.forward_encoder(
1121
+ source=net_input["source"],
1122
+ padding_mask=net_input["padding_mask"]
1123
+ )
1124
+ else:
1125
+ return self.forward_encoder_non_torchscript(net_input)
1126
+
1127
+ @torch.jit.unused
1128
+ def forward_encoder_non_torchscript(self, net_input: Dict[str, Tensor]):
1129
+ encoder_input = {
1130
+ k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
1131
+ }
1132
+ return self.forward_encoder(**encoder_input)
1133
+
1134
+ def forward_encoder(self, source, padding_mask=None):
1135
+ # Encoder Prenet
1136
+ encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=False)
1137
+
1138
+ # Encoder
1139
+ encoder_output = self.encoder(encoder_input, encoder_padding_mask)
1140
+
1141
+ return encoder_output
1142
+
1143
+ def forward_text_encoder(self, src_tokens):
1144
+ # Text Encoder Prenet
1145
+ encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens)
1146
+
1147
+ # Encoder
1148
+ encoder_output = self.encoder(encoder_input, encoder_padding_mask)
1149
+
1150
+ return encoder_output
1151
+
1152
+ def forward_decoder(self, tokens, encoder_out, incremental_state):
1153
+ # Decoder Prenet
1154
+ prev_output_tokens, tgt_mask, incremental_state = self.text_decoder_prenet(tokens, incremental_state)
1155
+
1156
+ # Decoder
1157
+ decoder_output, extra = self.decoder(
1158
+ prev_output_tokens,
1159
+ tgt_mask,
1160
+ encoder_out=encoder_out,
1161
+ incremental_state=incremental_state,
1162
+ )
1163
+
1164
+ # Decoder Postnet
1165
+ return self.text_decoder_postnet(decoder_output), extra
1166
+
1167
+ def set_num_updates(self, num_updates):
1168
+ """Set the number of parameters updates."""
1169
+ super().set_num_updates(num_updates)
1170
+ self.num_updates = num_updates
1171
+
1172
+ def generate_class(self, source, prev_output_tokens, **kwargs):
1173
+ encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])
1174
+
1175
+ prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens, {})
1176
+ prev_output_tokens = torch.zeros_like(prev_output_tokens) # s2c use zero vector as [CLS]
1177
+
1178
+ decoder_output, extra = self.decoder(
1179
+ prev_output_tokens,
1180
+ tgt_mask,
1181
+ encoder_out=encoder_out,
1182
+ )
1183
+
1184
+ decoder_out, embed = self.speaker_decoder_postnet(decoder_output.mean(1))
1185
+
1186
+ pred_class = decoder_out.argmax(1)
1187
+ return pred_class
1188
+
1189
+ def generate_speech(self, source=None, src_tokens=None, spkembs=None, **kwargs):
1190
+ assert source is not None or src_tokens is not None
1191
+
1192
+ threshold = kwargs.get("threshold", 0.5)
1193
+ minlenratio = kwargs.get("threshold", 0.0)
1194
+
1195
+ if source is None:
1196
+ assert src_tokens.size(0) == 1
1197
+ encoder_out = self.forward_text_encoder(src_tokens)
1198
+ maxlenratio = kwargs.get("threshold", 20.0)
1199
+ else:
1200
+ assert source.size(0) == 1
1201
+ encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])
1202
+ maxlenratio = kwargs.get("threshold", 10.0)
1203
+
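+ # When speaker-embedding integration is not "pre", fuse spkembs into the encoder output once here and clear it, so the decoder prenet does not apply it a second time.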
1204
+ if spkembs is not None and self.spk_embed_integration_type != "pre":
1205
+ encoder_out["encoder_out"] = [self._integrate_with_spk_embed(
1206
+ encoder_out["encoder_out"][0].transpose(0, 1), spkembs
1207
+ ).transpose(0, 1)]
1208
+ spkembs = None
1209
+
1210
+ maxlen = int(encoder_out["encoder_out"][0].size(0) * maxlenratio / self.reduction_factor)
1211
+ minlen = int(encoder_out["encoder_out"][0].size(0) * minlenratio / self.reduction_factor)
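+ # maxlen/minlen bound the number of autoregressive decoder steps relative to the encoder length; each step emits reduction_factor output frames.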
1212
+
1213
+ idx = 0
1214
+ ys = encoder_out["encoder_out"][0].new_zeros(1, 1, self.speech_decoder_postnet.odim)
1215
+ outs, probs = [], []
1216
+
1217
+ # forward decoder step-by-step
1218
+ if isinstance(self.decoder, FairseqIncrementalDecoder):
1219
+ incremental_states = {}
1220
+ else:
1221
+ incremental_states = None
1222
+ attns = []
1223
+ while True:
1224
+ # update index
1225
+ idx += 1
1226
+ # calculate output and stop prob at idx-th step
1227
+ decoder_in, _ = self.speech_decoder_prenet(ys, spkembs=spkembs)
1228
+ z, extra = self.decoder(decoder_in[:,-1:], None, encoder_out, incremental_states, alignment_layer=-1)
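+ # Only the newest prenet frame is fed per step; earlier key/value projections are cached in incremental_states.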
1229
+ outs += [self.speech_decoder_postnet.feat_out(z[0, -1]).view(self.reduction_factor, self.speech_decoder_postnet.odim)] # [(r, odim), ...]
1230
+ probs += [torch.sigmoid(self.speech_decoder_postnet.prob_out(z[0, -1]))] # [(r), ...]
1231
+
1232
+ # update next inputs
1233
+ ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.speech_decoder_postnet.odim)), dim=1) # (1, idx + 1, odim)
1234
+ attns.append(torch.stack([att_l[0] for att_l in extra['attn'][0]], dim=0))
1235
+ # check whether to finish generation
1236
+ if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
1237
+ # check minimum length
1238
+ if idx < minlen:
1239
+ continue
1240
+ outs = (torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2)) # (L, odim) -> (1, L, odim) -> (1, odim, L)
1241
+ if self.speech_decoder_postnet.postnet is not None:
1242
+ outs = outs + self.speech_decoder_postnet.postnet(outs) # (1, odim, L)
1243
+ outs = outs.transpose(2, 1).squeeze(0) # (L, odim)
1244
+ probs = torch.cat(probs, dim=0)
1245
+ attn = torch.cat(attns, dim=2)
1246
+ break
1247
+
1248
+ if idx >= maxlen:
1249
+ logging.warning("output length reaches maximum length")
1250
+ return outs, probs, attn
1251
+
1252
+
1253
+ @register_model_architecture(model_name="artst_transformer", arch_name="artst_transformer")
1254
+ def base_architecture(args):
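+ # Defaults are filled in with getattr, so any value already present on args (e.g. set on the command line) takes precedence.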
1255
+ # Transformer
1256
+ args.bert_init = getattr(args, "bert_init", False)
1257
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
1258
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768 * 4)
1259
+ args.encoder_layers = getattr(args, "encoder_layers", 12)
1260
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
1261
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
1262
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
1263
+ args.decoder_ffn_embed_dim = getattr(
1264
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
1265
+ )
1266
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
1267
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
1268
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
1269
+ args.dropout = getattr(args, "dropout", 0.1)
1270
+ args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
1271
+ args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
1272
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
1273
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
1274
+ args.decoder_output_dim = getattr(
1275
+ args, "decoder_output_dim", args.decoder_embed_dim
1276
+ )
1277
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
1278
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
1279
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
1280
+ args.max_text_positions = getattr(args, "max_text_positions", DEFAULT_MAX_TEXT_POSITIONS)
1281
+ args.max_speech_positions = getattr(args, "max_speech_positions", DEFAULT_MAX_SPEECH_POSITIONS)
1282
+
1283
+ # Espnet related, including prenet, postnet
1284
+ args.eprenet_conv_layers = getattr(args, "eprenet_conv_layers", 0)
1285
+ args.eprenet_conv_filts = getattr(args, "eprenet_conv_filts", 0)
1286
+ args.eprenet_conv_chans = getattr(args, "eprenet_conv_chans", 0)
1287
+ args.use_batch_norm = getattr(args, "use_batch_norm", True)
1288
+ args.eprenet_dropout_rate = getattr(args, "eprenet_dropout_rate", 0.0)
1289
+ args.enc_use_scaled_pos_enc = getattr(args, "enc_use_scaled_pos_enc", True)
1290
+ args.dec_use_scaled_pos_enc = getattr(args, "dec_use_scaled_pos_enc", True)
1291
+ args.postnet_layers = getattr(args, "postnet_layers", 5)
1292
+ args.postnet_chans = getattr(args, "postnet_chans", 256)
1293
+ args.postnet_filts = getattr(args, "postnet_filts", 5)
1294
+ args.postnet_dropout_rate = getattr(args, "postnet_dropout_rate", 0.5)
1295
+ args.dprenet_dropout_rate = getattr(args, "dprenet_dropout_rate", 0.5)
1296
+ args.dprenet_layers = getattr(args, "dprenet_layers", 2)
1297
+ args.dprenet_units = getattr(args, "dprenet_units", 256)
1298
+ args.initial_encoder_alpha = getattr(args, "initial_encoder_alpha", 1.0)
1299
+ args.initial_decoder_alpha = getattr(args, "initial_decoder_alpha", 1.0)
1300
+ args.spk_embed_integration_type = getattr(args, "spk_embed_integration_type", "pre")
1301
+ args.spk_embed_dim = getattr(args, "spk_embed_dim", 512)
1302
+ args.encoder_reduction_factor = getattr(args, "encoder_reduction_factor", 1)
1303
+ args.reduction_factor = getattr(args, "reduction_factor", 2)
1304
+ args.transformer_enc_positional_dropout_rate = getattr(args, "transformer_enc_positional_dropout_rate", 0.1)
1305
+ args.transformer_dec_positional_dropout_rate = getattr(args, "transformer_dec_positional_dropout_rate", 0.1)
1306
+ args.layer_norm_eps = getattr(args, "layer_norm_eps", 1e-5)
1307
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
1308
+ # Convolutional subsampler
1309
+ args.encoder_speech_prenet = getattr(args, "encoder_speech_prenet", "conv")
1310
+ args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
1311
+ args.conv_channels = getattr(args, "conv_channels", 1024)
1312
+ args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
1313
+
1314
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
1315
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
1316
+ args.no_token_positional_embeddings = getattr(
1317
+ args, "no_token_positional_embeddings", False
1318
+ )
1319
+ args.adaptive_input = getattr(args, "adaptive_input", False)
1320
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
1321
+ args.share_input_output_embed = getattr(args, "share_input_output_embed", False)
1322
+ args.share_ctc_embed = getattr(args, "share_ctc_embed", False)
1323
+ args.freeze_encoder_updates = getattr(args, "freeze_encoder_updates", 0)
1324
+ args.freeze_decoder_updates = getattr(args, "freeze_decoder_updates", 0)
1325
+ args.no_freeze_encoder_layer = getattr(args, "no_freeze_encoder_layer", None)
1326
+
1327
+ ## sid
1328
+ args.sid_embed_dim = getattr(args, "sid_embed_dim", 128)
1329
+ args.sid_pooling_layer = getattr(args, "sid_pooling_layer", "decoder")
1330
+ args.softmax_scale = getattr(args, "softmax_scale", 1)
1331
+ args.softmax_margin = getattr(args, "softmax_margin", 0)
1332
+ args.softmax_easy_margin = getattr(args, "softmax_easy_margin", False)
1333
+ args.modules_filter = getattr(args, "modules_filter", None)
1334
+
1335
+ ## Hubert
1336
+ args.conv_pos = getattr(args, "conv_pos", 128)
1337
+ args.conv_pos_groups = getattr(args, "conv_pos_groups", 16)
1338
+ args.target_glu = getattr(args, "target_glu", False)
1339
+ args.logit_temp = getattr(args, "logit_temp", 0.1)
1340
+ args.final_dim = getattr(args, "final_dim", 256)
1341
+ args.untie_final_proj = getattr(args, "untie_final_proj", True)
1342
+ args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.1)
1343
+ args.use_sent_enc_layer = getattr(args, "use_sent_enc_layer", True)
1344
+ # hubert feature extractor
1345
+ args.extractor_mode = getattr(args, "extractor_mode", "default")
1346
+ args.conv_feature_layers = getattr(args, "conv_feature_layers", "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2")
1347
+ args.conv_bias = getattr(args, "conv_bias", False)
1348
+ # mask
1349
+ args.hubert_mask_length = getattr(args, "hubert_mask_length", 10)
1350
+ args.mask_prob = getattr(args, "mask_prob", 0.0)
1351
+ args.mask_selection = getattr(args, "mask_selection", "static")
1352
+ args.mask_other = getattr(args, "mask_other", 0)
1353
+ args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
1354
+ args.mask_min_space = getattr(args, "mask_min_space", 1)
1355
+ # channel mask
1356
+ args.mask_channel_length = getattr(args, "mask_channel_length", 10)
1357
+ args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.0)
1358
+ args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
1359
+ args.mask_channel_other = getattr(args, "mask_channel_other", 0)
1360
+ args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
1361
+ args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1)
1362
+ # loss computation
1363
+ args.skip_masked = getattr(args, "skip_masked", False)
1364
+ args.skip_nomask = getattr(args, "skip_nomask", False)
1365
+ # conv Pos
1366
+ args.use_conv_pos = getattr(args, "use_conv_pos", False)
1367
+ args.use_sinc_pos = getattr(args, "use_sinc_pos", False)
1368
+
1369
+ # codebook
1370
+ args.use_codebook = getattr(args, "use_codebook", False)
1371
+ args.latent_vars = getattr(args, "latent_vars", 100)
1372
+ args.latent_groups = getattr(args, "latent_groups", 2)
1373
+ args.latent_dim = getattr(args, "latent_dim", 0)
1374
+ args.latent_temp = getattr(args, "latent_temp", (2, 0.5, 0.999995))
1375
+ args.quantizer_depth = getattr(args, "quantizer_depth", 1)
1376
+ args.quantizer_factor = getattr(args, "quantizer_factor", 3)
1377
+ args.codebook_prob = getattr(args, "codebook_prob", 0.5)
1378
+
1379
+ # Relative pos embed
1380
+ args.relative_position_embedding = getattr(args, "relative_position_embedding", False)
1381
+ args.num_buckets = getattr(args, "num_buckets", 320)
1382
+ args.max_distance = getattr(args, "max_distance", 1280)
1383
+ args.encoder_max_relative_position = getattr(args, "encoder_max_relative_position", 160)
1384
+ args.decoder_max_relative_position = getattr(args, "decoder_max_relative_position", 160)
1385
+
1386
+ @register_model_architecture("artst_transformer", "artst_transformer_base")
1387
+ def artst_transformer_base(args):
1388
+ args.use_conv_pos = getattr(args, "use_conv_pos", True)
1389
+ args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
1390
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
1391
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
1392
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
1393
+ args.layer_norm_first = getattr(args, "layer_norm_first", False)
1394
+ args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
1395
+ args.dropout = getattr(args, "dropout", 0.1)
1396
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
1397
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
1398
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.05)
1399
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.05)
1400
+ args.mask_prob = getattr(args, "mask_prob", 0.80)
1401
+ base_architecture(args)
1402
+
1403
+ @register_model_architecture("artst_transformer", "artst_transformer_large")
1404
+ def artst_transformer_large(args):
1405
+ args.use_conv_pos = getattr(args, "use_conv_pos", True)
1406
+ args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
1407
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
1408
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
1409
+ args.layer_norm_first = getattr(args, "layer_norm_first", True)
1410
+ args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
1411
+ args.dropout = getattr(args, "dropout", 0.0)
1412
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
1413
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
1414
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
1415
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
1416
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
1417
+ args.encoder_layers = getattr(args, "encoder_layers", 24)
1418
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
1419
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
1420
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
1421
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
1422
+ args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0)
1423
+ args.extractor_mode = getattr(args, "extractor_mode", "layer_norm")
1424
+ args.final_dim = getattr(args, "final_dim", 768)
1425
+ args.mask_prob = getattr(args, "mask_prob", 0.80)
1426
+ base_architecture(args)
1427
+
1428
+ @register_model_architecture("artst_transformer", "artst_transformer_base_asr")
1429
+ def artst_transformer_base_asr(args):
1430
+ args.use_conv_pos = getattr(args, "use_conv_pos", True)
1431
+ args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
1432
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
1433
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
1434
+ args.layer_norm_first = getattr(args, "layer_norm_first", False)
1435
+ args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
1436
+ args.dropout = getattr(args, "dropout", 0.1)
1437
+ args.activation_dropout = getattr(args, "activation_dropout", 0.1)
1438
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
1439
+ args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.0)
1440
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.1)
1441
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.1)
1442
+ args.mask_prob = getattr(args, "mask_prob", 0.75)
1443
+ args.mask_selection = getattr(args, "mask_selection", "static")
1444
+ args.mask_channel_length = getattr(args, "mask_channel_length", 64)
1445
+ args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5)
1446
+ args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
1447
+ args.max_text_positions = getattr(args, "max_text_positions", 600)
1448
+ base_architecture(args)
artst/models/modules/__init__.py ADDED
File without changes
artst/models/modules/decoder.py ADDED
@@ -0,0 +1,323 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ from typing import Any, Dict, List, Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from fairseq import utils
14
+ from fairseq.distributed import fsdp_wrap
15
+ from fairseq.models import (
16
+ FairseqIncrementalDecoder,
17
+ )
18
+ from fairseq.modules import (
19
+ FairseqDropout,
20
+ LayerDropModuleList,
21
+ LayerNorm,
22
+ )
23
+ from fairseq.modules.checkpoint_activations import checkpoint_wrapper
24
+ from torch import Tensor
25
+
26
+ from .encoder import RelativePositionalEncoding
27
+ from .transformer_layer import TransformerDecoderLayer
28
+
29
+ DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
30
+
31
+
32
+ class TransformerDecoder(FairseqIncrementalDecoder):
33
+ """
34
+ Transformer decoder consisting of *args.decoder_layers* layers. Each layer
35
+ is a :class:`TransformerDecoderLayer`.
36
+
37
+ Args:
38
+ args (argparse.Namespace): parsed command-line arguments
39
+ dictionary (~fairseq.data.Dictionary): decoding dictionary
40
+ embed_tokens (torch.nn.Embedding): output embedding
41
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
42
+ (default: False).
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ args,
48
+ no_encoder_attn=False,
49
+ ):
50
+ self.args = args
51
+ super().__init__(None)
52
+ self.register_buffer("version", torch.Tensor([3]))
53
+ self._future_mask = torch.empty(0)
54
+
55
+ self.dropout_module = FairseqDropout(
56
+ args.dropout, module_name=self.__class__.__name__
57
+ )
58
+ self.decoder_layerdrop = args.decoder_layerdrop
59
+ # self.max_s_positions = args.max_target_positions
60
+ export = getattr(args, "export", False)
61
+ self.cross_self_attention = getattr(args, "cross_self_attention", False)
62
+
63
+ if self.decoder_layerdrop > 0.0:
64
+ self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
65
+ else:
66
+ self.layers = nn.ModuleList([])
67
+ self.layers.extend(
68
+ [
69
+ self.build_decoder_layer(args, no_encoder_attn)
70
+ for _ in range(args.decoder_layers)
71
+ ]
72
+ )
73
+ self.num_layers = len(self.layers)
74
+
75
+ if args.decoder_normalize_before and not getattr(
76
+ args, "no_decoder_final_norm", False
77
+ ):
78
+ self.layer_norm = LayerNorm(args.decoder_embed_dim, eps=args.layer_norm_eps, export=export)
79
+ else:
80
+ self.layer_norm = None
81
+
82
+ if args.relative_position_embedding:
83
+ self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim//args.encoder_attention_heads, args.decoder_max_relative_position)
84
+
85
+ def build_decoder_layer(self, args, no_encoder_attn=False):
86
+ layer = TransformerDecoderLayer(args, no_encoder_attn=no_encoder_attn, has_relative_attention_bias=args.relative_position_embedding)
87
+ checkpoint = getattr(args, "checkpoint_activations", False)
88
+ if checkpoint:
89
+ offload_to_cpu = getattr(args, "offload_activations", False)
90
+ layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
91
+ # if we are checkpointing, enforce that FSDP always wraps the
92
+ # checkpointed layer, regardless of layer size
93
+ min_params_to_wrap = (
94
+ getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
95
+ if not checkpoint
96
+ else 0
97
+ )
98
+ layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
99
+ return layer
100
+
101
+ def forward(
102
+ self,
103
+ prev_output_tokens,
104
+ tgt_mask,
105
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
106
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
107
+ full_context_alignment: bool = False,
108
+ alignment_layer: Optional[int] = None,
109
+ alignment_heads: Optional[int] = None,
110
+ src_lengths: Optional[Any] = None,
111
+ return_all_hiddens: bool = False,
112
+ ):
113
+ """
114
+ Args:
115
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
116
+ `(batch, tgt_len)`, for teacher forcing
117
+ encoder_out (optional): output from the encoder, used for
118
+ encoder-side attention, should be of size T x B x C
119
+ incremental_state (dict): dictionary used for storing state during
120
+ :ref:`Incremental decoding`
121
+ features_only (bool, optional): only return features without
122
+ applying output layer (default: False).
123
+ full_context_alignment (bool, optional): don't apply
124
+ auto-regressive mask to self-attention (default: False).
125
+
126
+ Returns:
127
+ tuple:
128
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
129
+ - a dictionary with any model-specific outputs
130
+ """
131
+
132
+ x, extra = self.extract_features(
133
+ prev_output_tokens,
134
+ tgt_mask,
135
+ encoder_out=encoder_out,
136
+ incremental_state=incremental_state,
137
+ full_context_alignment=full_context_alignment,
138
+ alignment_layer=alignment_layer,
139
+ alignment_heads=alignment_heads,
140
+ )
141
+
142
+ return x, extra
143
+
144
+ def extract_features(
145
+ self,
146
+ prev_output_tokens,
147
+ tgt_mask,
148
+ encoder_out: Optional[Dict[str, List[Tensor]]],
149
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
150
+ full_context_alignment: bool = False,
151
+ alignment_layer: Optional[int] = None,
152
+ alignment_heads: Optional[int] = None,
153
+ ):
154
+ return self.extract_features_scriptable(
155
+ prev_output_tokens,
156
+ tgt_mask,
157
+ encoder_out,
158
+ incremental_state,
159
+ full_context_alignment,
160
+ alignment_layer,
161
+ alignment_heads,
162
+ )
163
+
164
+ """
165
+ A scriptable subclass of this class has an extract_features method and calls
166
+ super().extract_features, but super() is not supported in torchscript. A copy of
167
+ this function is made to be used in the subclass instead.
168
+ """
169
+
170
+ def extract_features_scriptable(
171
+ self,
172
+ prev_output_tokens,
173
+ tgt_mask,
174
+ encoder_out: Optional[Dict[str, List[Tensor]]],
175
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
176
+ full_context_alignment: bool = False,
177
+ alignment_layer: Optional[int] = None,
178
+ alignment_heads: Optional[int] = None,
179
+ ):
180
+ """
181
+ Similar to *forward* but only return features.
182
+
183
+ Includes several features from "Jointly Learning to Align and
184
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
185
+
186
+ Args:
187
+ full_context_alignment (bool, optional): don't apply
188
+ auto-regressive mask to self-attention (default: False).
189
+ alignment_layer (int, optional): return mean alignment over
190
+ heads at this layer (default: last layer).
191
+ alignment_heads (int, optional): only average alignment over
192
+ this many heads (default: all heads).
193
+
194
+ Returns:
195
+ tuple:
196
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
197
+ - a dictionary with any model-specific outputs
198
+ """
199
+ bs = prev_output_tokens.size(0)
200
+ if alignment_layer is None:
201
+ alignment_layer = self.num_layers - 1
202
+
203
+ enc: Optional[Tensor] = None
204
+ padding_mask: Optional[Tensor] = None
205
+ if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
206
+ enc = encoder_out["encoder_out"][0]
207
+ assert (
208
+ enc.size()[1] == bs
209
+ ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
210
+ if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
211
+ padding_mask = encoder_out["encoder_padding_mask"][0]
212
+
213
+ # B x T x C -> T x B x C
214
+ x = prev_output_tokens.transpose(0, 1)
215
+
216
+ self_attn_padding_mask: Optional[Tensor] = None
217
+ if self.cross_self_attention or tgt_mask is not None:
218
+ self_attn_padding_mask = tgt_mask
219
+
220
+ ## relative position embedding
221
+ if self.args.relative_position_embedding:
222
+ x_len = x.shape[0]
223
+ pos_seq = torch.arange(0, x_len).long().to(x.device)
224
+ pos_seq = pos_seq[:, None] - pos_seq[None, :]
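+ # (tgt_len, tgt_len) matrix of relative offsets i - j, mapped to per-head position biases by pos_emb.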
225
+ pos_k, pos_v = self.pos_emb(pos_seq)
226
+ else:
227
+ pos_k = None
228
+
229
+ # decoder layers
230
+ attn_list = []
231
+ attn: Optional[Tensor] = None
232
+ inner_states: List[Optional[Tensor]] = [x]
233
+ for idx, layer in enumerate(self.layers):
234
+ if incremental_state is None and not full_context_alignment:
235
+ self_attn_mask = self.buffered_future_mask(x)
236
+ else:
237
+ self_attn_mask = None
238
+
239
+ x, layer_attn, _ = layer(
240
+ x,
241
+ enc,
242
+ padding_mask,
243
+ incremental_state,
244
+ self_attn_mask=self_attn_mask,
245
+ self_attn_padding_mask=self_attn_padding_mask,
246
+ need_attn=bool((idx == alignment_layer or alignment_layer == -1)),
247
+ need_head_weights=bool((idx == alignment_layer or alignment_layer == -1)),
248
+ pos_bias=pos_k,
249
+ )
250
+ inner_states.append(x)
251
+ if layer_attn is not None and (idx == alignment_layer or alignment_layer == -1):
252
+ attn = layer_attn.float().to(x)
253
+ attn_list.append(attn.transpose(0, 1))
254
+
255
+ if attn is not None and len(attn_list) == 1:
256
+ if alignment_heads is not None:
257
+ attn = attn[:alignment_heads]
258
+
259
+ # average probabilities over heads
260
+ attn = attn.mean(dim=0)
261
+
262
+ if self.layer_norm is not None:
263
+ x = self.layer_norm(x)
264
+
265
+ # T x B x C -> B x T x C
266
+ x = x.transpose(0, 1)
267
+
268
+ return x, {"attn": [attn if len(attn_list) <= 1 else attn_list], "inner_states": inner_states}
269
+
270
+ # def max_positions(self):
271
+ # """Maximum output length supported by the decoder."""
272
+ # return self.max_target_positions
273
+
274
+ def buffered_future_mask(self, tensor):
275
+ dim = tensor.size(0)
276
+ # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
277
+ if (
278
+ self._future_mask.size(0) == 0
279
+ or (not self._future_mask.device == tensor.device)
280
+ or self._future_mask.size(0) < dim
281
+ ):
282
+ self._future_mask = torch.triu(
283
+ utils.fill_with_neg_inf(torch.zeros([dim, dim], device=tensor.device)), 1,
284
+ )
285
+ else:
286
+ self._future_mask = self._future_mask.to(tensor)
287
+ return self._future_mask[:dim, :dim]
288
+
289
+ def upgrade_state_dict_named(self, state_dict, name):
290
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
291
+ for i in range(self.num_layers):
292
+ # update layer norms
293
+ layer_norm_map = {
294
+ "0": "self_attn_layer_norm",
295
+ "1": "encoder_attn_layer_norm",
296
+ "2": "final_layer_norm",
297
+ }
298
+ for old, new in layer_norm_map.items():
299
+ for m in ("weight", "bias"):
300
+ k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
301
+ if k in state_dict:
302
+ state_dict[
303
+ "{}.layers.{}.{}.{}".format(name, i, new, m)
304
+ ] = state_dict[k]
305
+ del state_dict[k]
306
+
307
+ version_key = "{}.version".format(name)
308
+ if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
309
+ # earlier checkpoints did not normalize after the stack of layers
310
+ self.layer_norm = None
311
+ self.normalize = False
312
+ state_dict[version_key] = torch.Tensor([1])
313
+
314
+ return state_dict
315
+
316
+ def set_num_updates(self, num_updates):
317
+ """State from trainer to pass along to model at every update."""
318
+
319
+ def _apply(m):
320
+ if hasattr(m, "set_num_updates") and m != self:
321
+ m.set_num_updates(num_updates)
322
+
323
+ self.apply(_apply)
artst/models/modules/encoder.py ADDED
@@ -0,0 +1,380 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ from typing import Dict, List
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import contextlib
15
+ from fairseq import utils
16
+ from fairseq.models import (
17
+ FairseqEncoder,
18
+ )
19
+ from fairseq.modules import (
20
+ FairseqDropout,
21
+ LayerNorm,
22
+ TransformerEncoderLayer,
23
+ )
24
+ from torch import Tensor
25
+ from .transformer_layer import TransformerSentenceEncoderLayer
26
+
27
+
28
+
29
+ DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
30
+
31
+ def Linear(in_features, out_features, bias=True):
32
+ m = nn.Linear(in_features, out_features, bias)
33
+ nn.init.xavier_uniform_(m.weight)
34
+ if bias:
35
+ nn.init.constant_(m.bias, 0.0)
36
+ return m
37
+
38
+
39
+ class RelativePositionalEncoding(torch.nn.Module):
40
+ def __init__(self, d_model, maxlen=1000, embed_v=False):
41
+ super(RelativePositionalEncoding, self).__init__()
42
+
43
+ self.d_model = d_model
44
+ self.maxlen = maxlen
45
+ self.pe_k = torch.nn.Embedding(2*maxlen, d_model)
46
+ if embed_v:
47
+ self.pe_v = torch.nn.Embedding(2*maxlen, d_model)
48
+ self.embed_v = embed_v
49
+
50
+
51
+ def forward(self, pos_seq):
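+ # Clamp relative offsets to [-maxlen, maxlen - 1] and shift them into [0, 2*maxlen) before indexing the embedding table.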
52
+ pos_seq[pos_seq < -self.maxlen] = -self.maxlen
53
+ pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
54
+ pos_seq = pos_seq + self.maxlen
55
+ if self.embed_v:
56
+ return self.pe_k(pos_seq), self.pe_v(pos_seq)
57
+ else:
58
+ return self.pe_k(pos_seq), None
59
+
60
+ class TransformerEncoder(FairseqEncoder):
61
+ """
62
+ Transformer encoder consisting of *args.encoder_layers* layers. Each layer
63
+ is a :class:`TransformerEncoderLayer`.
64
+
65
+ Args:
66
+ args (argparse.Namespace): parsed command-line arguments
67
+ dictionary (~fairseq.data.Dictionary): encoding dictionary
68
+ embed_tokens (torch.nn.Embedding): input embedding
69
+ """
70
+
71
+ def __init__(self, args, tgt_dict=None, embed_tokens=None):
72
+ self.args = args
73
+ super().__init__(None)
74
+ self.register_buffer("version", torch.Tensor([3]))
75
+
76
+ self.dropout_module = FairseqDropout(
77
+ args.dropout, module_name=self.__class__.__name__
78
+ )
79
+ self.encoder_layerdrop = args.encoder_layerdrop
80
+ self.freeze_encoder_updates = args.freeze_encoder_updates
81
+ if args.no_freeze_encoder_layer is not None:
82
+ self.no_freeze_encoder_layer = eval(args.no_freeze_encoder_layer)
83
+ else:
84
+ self.no_freeze_encoder_layer = None
85
+ self.num_updates = 0
86
+ export = getattr(args, "export", False)
87
+
88
+ self.layers = nn.ModuleList([])
89
+ self.layers.extend(
90
+ [self.build_encoder_layer(args) for i in range(args.encoder_layers)]
91
+ )
92
+ self.num_layers = len(self.layers)
93
+
94
+ self.use_sent_enc_layer = args.use_sent_enc_layer
95
+ self.unb_enc_layer = getattr(args, "unb_enc_layer", -1)
96
+
97
+ self.layer_norm_first = args.layer_norm_first
98
+ self.layer_norm = LayerNorm(args.encoder_embed_dim, eps=args.layer_norm_eps, export=export)
99
+
100
+ if args.share_ctc_embed and embed_tokens is not None:
101
+ self.proj = nn.Linear(
102
+ embed_tokens.weight.shape[1],
103
+ embed_tokens.weight.shape[0],
104
+ bias=False,
105
+ )
106
+ self.proj.weight = embed_tokens.weight
107
+ elif tgt_dict is not None:
108
+ self.proj = Linear(args.encoder_embed_dim, len(tgt_dict))
109
+ else:
110
+ self.proj = None
111
+
112
+ if args.relative_position_embedding:
113
+ self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim//args.encoder_attention_heads, args.encoder_max_relative_position)
114
+
115
+
116
+ def build_encoder_layer(self, args):
117
+ if args.use_sent_enc_layer:
118
+ layer = TransformerSentenceEncoderLayer(
119
+ embedding_dim=args.encoder_embed_dim,
120
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
121
+ num_attention_heads=args.encoder_attention_heads,
122
+ dropout=args.dropout,
123
+ attention_dropout=args.attention_dropout,
124
+ activation_dropout=args.activation_dropout,
125
+ activation_fn=args.activation_fn,
126
+ layer_norm_first=args.layer_norm_first,
127
+ has_relative_attention_bias=args.relative_position_embedding,
128
+ )
129
+ else:
130
+ layer = TransformerEncoderLayer(args)
131
+ return layer
132
+
133
+ def forward(
134
+ self,
135
+ encoder_in,
136
+ encoder_padding_mask,
137
+ return_all_hiddens: bool = False,
138
+ tgt_layer=None,
139
+ ):
140
+ """
141
+ Args:
142
+ src_tokens (LongTensor): tokens in the source language of shape
143
+ `(batch, src_len)`
144
+ src_lengths (torch.LongTensor): lengths of each source sentence of
145
+ shape `(batch)`
146
+ return_all_hiddens (bool, optional): also return all of the
147
+ intermediate hidden states (default: False).
148
+ token_embeddings (torch.Tensor, optional): precomputed embeddings
149
+ default `None` will recompute embeddings
150
+
151
+ Returns:
152
+ dict:
153
+ - **encoder_out** (Tensor): the last encoder layer's output of
154
+ shape `(src_len, batch, embed_dim)`
155
+ - **encoder_padding_mask** (ByteTensor): the positions of
156
+ padding elements of shape `(batch, src_len)`
157
+ - **encoder_embedding** (Tensor): the (scaled) embedding lookup
158
+ of shape `(batch, src_len, embed_dim)`
159
+ - **encoder_states** (List[Tensor]): all intermediate
160
+ hidden states of shape `(src_len, batch, embed_dim)`.
161
+ Only populated if *return_all_hiddens* is True.
162
+ """
163
+ if self.no_freeze_encoder_layer is None:
164
+ ft = self.freeze_encoder_updates <= self.num_updates
165
+ else:
166
+ ft = True
167
+ with torch.no_grad() if not ft else contextlib.ExitStack():
168
+ encoder_out = self.forward_scriptable(
169
+ encoder_in, encoder_padding_mask, return_all_hiddens, tgt_layer=tgt_layer,
170
+ )
171
+
172
+ # CTC and bert
173
+ if self.proj:
174
+ x_for_ctc = self.proj(self.dropout_module(encoder_out["encoder_out"][0]))
175
+ else:
176
+ x_for_ctc = None
177
+
178
+ encoder_out["encoder_out_for_ctc"] = [x_for_ctc] # T x B x C
179
+
180
+ return encoder_out
181
+
182
+ # TorchScript doesn't support super(), so a scriptable subclass
183
+ # can't access this base-class implementation in TorchScript.
184
+ # The current workaround is to add a helper function with a different name and
185
+ # call that helper from the scriptable subclass.
186
+ def forward_scriptable(
187
+ self,
188
+ encoder_in,
189
+ encoder_padding_mask,
190
+ return_all_hiddens: bool = False,
191
+ tgt_layer=None,
192
+ ):
193
+ """
194
+ Args:
195
+ src_tokens (LongTensor): tokens in the source language of shape
196
+ `(batch, src_len)`
197
+ src_lengths (torch.LongTensor): lengths of each source sentence of
198
+ shape `(batch)`
199
+ return_all_hiddens (bool, optional): also return all of the
200
+ intermediate hidden states (default: False).
201
+ token_embeddings (torch.Tensor, optional): precomputed embeddings
202
+ default `None` will recompute embeddings
203
+
204
+ Returns:
205
+ dict:
206
+ - **encoder_out** (Tensor): the last encoder layer's output of
207
+ shape `(src_len, batch, embed_dim)`
208
+ - **encoder_padding_mask** (ByteTensor): the positions of
209
+ padding elements of shape `(batch, src_len)`
210
+ - **encoder_embedding** (Tensor): the (scaled) embedding lookup
211
+ of shape `(batch, src_len, embed_dim)`
212
+ - **encoder_states** (List[Tensor]): all intermediate
213
+ hidden states of shape `(src_len, batch, embed_dim)`.
214
+ Only populated if *return_all_hiddens* is True.
215
+ """
216
+ if self.no_freeze_encoder_layer is not None:
217
+ ft = self.freeze_encoder_updates <= self.num_updates
218
+ else:
219
+ ft = True
220
+ with torch.no_grad() if not ft else contextlib.ExitStack():
221
+ # compute padding mask
222
+ if not self.use_sent_enc_layer:
223
+ has_pads = encoder_in.device.type == "xla" or encoder_padding_mask.any()
224
+
225
+ if not self.layer_norm_first:
226
+ encoder_in = self.layer_norm(encoder_in)
227
+
228
+ encoder_in = self.dropout_module(encoder_in)
229
+
230
+ # B x T x C -> T x B x C
231
+ x = encoder_in.transpose(0, 1)
232
+
233
+ encoder_states = []
234
+
235
+ if return_all_hiddens:
236
+ encoder_states.append(x)
237
+
238
+ ## relative position embedding
239
+ if self.args.relative_position_embedding:
240
+ x_len = x.shape[0]
241
+ pos_seq = torch.arange(0, x_len).long().to(x.device)
242
+ pos_seq = pos_seq[:, None] - pos_seq[None, :]
243
+ pos_k, pos_v = self.pos_emb(pos_seq)
244
+ else:
245
+ pos_k = None
246
+
247
+ # encoder layers
248
+ r = None
249
+ d = None
250
+ for i, layer in enumerate(self.layers):
251
+ dropout_probability = np.random.random()
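+ # LayerDrop: during training each layer is skipped with probability encoder_layerdrop, except the unb_enc_layer.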
252
+
253
+ with torch.no_grad() if (not ft) and i not in self.no_freeze_encoder_layer else contextlib.ExitStack():
254
+ if not self.training or (dropout_probability > self.encoder_layerdrop) or i == self.unb_enc_layer:
255
+ if self.use_sent_enc_layer:
256
+ x, _ = layer(x, self_attn_padding_mask=encoder_padding_mask, self_attn_mask=None, need_weights=False, pos_bias=pos_k)
257
+ # x, _ = layer(x, self_attn_padding_mask=encoder_padding_mask, need_weights=False, pos_bias=pos_k)
258
+ else:
259
+ x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None, attn_mask=None)
260
+ # x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None)
261
+ if i == self.unb_enc_layer:
262
+ d = x
263
+
264
+ if i == tgt_layer:
265
+ r = x
266
+ break
267
+
268
+ if return_all_hiddens:
269
+ assert encoder_states is not None
270
+ encoder_states.append(x)
271
+
272
+ with torch.no_grad() if not ft else contextlib.ExitStack():
273
+ # Finally T x B x C
274
+ if self.layer_norm_first:
275
+ x = self.layer_norm(x.transpose(0, 1)).transpose(0, 1)
276
+
277
+ if r is not None:
278
+ x = r
279
+
280
+ # The PyTorch Mobile lite interpreter does not support returning NamedTuple in
281
+ # `forward` so we use a dictionary instead.
282
+ # TorchScript does not support mixed values so the values are all lists.
283
+ # The empty list is equivalent to None.
284
+ return {
285
+ "encoder_out": [x], # T x B x C
286
+ "encoder_padding_mask": [encoder_padding_mask], # B x T
287
+ "encoder_states": encoder_states, # List[T x B x C]
288
+ "src_tokens": [],
289
+ "decoder_input": [d],
290
+ }
291
+
292
+ @torch.jit.export
293
+ def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
294
+ """
295
+ Reorder encoder output according to *new_order*.
296
+
297
+ Args:
298
+ encoder_out: output from the ``forward()`` method
299
+ new_order (LongTensor): desired order
300
+
301
+ Returns:
302
+ *encoder_out* rearranged according to *new_order*
303
+ """
304
+ if len(encoder_out["encoder_out"]) == 0:
305
+ new_encoder_out = []
306
+ else:
307
+ new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
308
+
309
+ if len(encoder_out["encoder_out_for_ctc"]) == 0:
310
+ new_x_for_ctc = []
311
+ else:
312
+ new_x_for_ctc = [encoder_out["encoder_out_for_ctc"][0].index_select(1, new_order)]
313
+
314
+ if len(encoder_out["encoder_padding_mask"]) == 0:
315
+ new_encoder_padding_mask = []
316
+ else:
317
+ new_encoder_padding_mask = [
318
+ encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
319
+ ]
320
+
321
+ if len(encoder_out["src_tokens"]) == 0:
322
+ src_tokens = []
323
+ else:
324
+ src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
325
+
326
+ if len(encoder_out["decoder_input"]) == 0 or encoder_out["decoder_input"][0] is None:
327
+ new_decoder_input = []
328
+ else:
329
+ new_decoder_input = [
330
+ encoder_out["decoder_input"][0].index_select(0, new_order)
331
+ ]
332
+
333
+ encoder_states = encoder_out["encoder_states"]
334
+ if len(encoder_states) > 0:
335
+ for idx, state in enumerate(encoder_states):
336
+ encoder_states[idx] = state.index_select(1, new_order)
337
+
338
+ return {
339
+ "encoder_out": new_encoder_out, # T x B x C
340
+ "encoder_padding_mask": new_encoder_padding_mask, # B x T
341
+ "encoder_states": encoder_states, # List[T x B x C]
342
+ "src_tokens": src_tokens, # B x T
343
+ "encoder_out_for_ctc": new_x_for_ctc, # T x B x C
344
+ "decoder_input": new_decoder_input,
345
+ }
346
+
347
+ # def max_positions(self):
348
+ # """Maximum input length supported by the encoder."""
349
+ # return self.max_source_positions
350
+
351
+ def upgrade_state_dict_named(self, state_dict, name):
352
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
353
+ # if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
354
+ # weights_key = "{}.embed_positions.weights".format(name)
355
+ # if weights_key in state_dict:
356
+ # print("deleting {0}".format(weights_key))
357
+ # del state_dict[weights_key]
358
+ # state_dict[
359
+ # "{}.embed_positions._float_tensor".format(name)
360
+ # ] = torch.FloatTensor(1)
361
+ for i in range(self.num_layers):
362
+ # update layer norms
363
+ if not isinstance(self.layers[i], TransformerSentenceEncoderLayer):
364
+ self.layers[i].upgrade_state_dict_named(
365
+ state_dict, "{}.layers.{}".format(name, i)
366
+ )
367
+
368
+ version_key = "{}.version".format(name)
369
+ if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
370
+ # earlier checkpoints did not normalize after the stack of layers
371
+ self.layer_norm = None
372
+ self.normalize = False
373
+ state_dict[version_key] = torch.Tensor([1])
374
+ return state_dict
375
+
376
+ def set_num_updates(self, num_updates):
377
+ """Set the number of parameters updates."""
378
+ super().set_num_updates(num_updates)
379
+ self.num_updates = num_updates
380
+
artst/models/modules/multihead_attention.py ADDED
@@ -0,0 +1,525 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ import math
10
+ from typing import Dict, Optional, Tuple
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from fairseq import utils
15
+ from fairseq.incremental_decoding_utils import with_incremental_state
16
+ from fairseq.modules.fairseq_dropout import FairseqDropout
17
+ from fairseq.modules.quant_noise import quant_noise
18
+ from torch import Tensor, nn
19
+ from torch.nn import Parameter
20
+
21
+
22
+ @with_incremental_state
23
+ class MultiheadAttention(nn.Module):
24
+ """Multi-headed attention.
25
+
26
+ See "Attention Is All You Need" for more details.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ embed_dim,
32
+ num_heads,
33
+ kdim=None,
34
+ vdim=None,
35
+ dropout=0.0,
36
+ bias=True,
37
+ add_bias_kv=False,
38
+ add_zero_attn=False,
39
+ self_attention=False,
40
+ encoder_decoder_attention=False,
41
+ q_noise=0.0,
42
+ qn_block_size=8,
43
+ has_relative_attention_bias=False,
44
+ ):
45
+ super().__init__()
46
+ self.embed_dim = embed_dim
47
+ self.kdim = kdim if kdim is not None else embed_dim
48
+ self.vdim = vdim if vdim is not None else embed_dim
49
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
50
+
51
+ self.num_heads = num_heads
52
+ self.dropout_module = FairseqDropout(
53
+ dropout, module_name=self.__class__.__name__
54
+ )
55
+
56
+ self.has_relative_attention_bias = has_relative_attention_bias
57
+ self.head_dim = embed_dim // num_heads
58
+ assert (
59
+ self.head_dim * num_heads == self.embed_dim
60
+ ), "embed_dim must be divisible by num_heads"
61
+ self.scaling = self.head_dim ** -0.5
62
+
63
+ self.self_attention = self_attention
64
+ self.encoder_decoder_attention = encoder_decoder_attention
65
+
66
+ assert not self.self_attention or self.qkv_same_dim, (
67
+ "Self-attention requires query, key and " "value to be of the same size"
68
+ )
69
+
70
+ self.k_proj = quant_noise(
71
+ nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
72
+ )
73
+ self.v_proj = quant_noise(
74
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
75
+ )
76
+ self.q_proj = quant_noise(
77
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
78
+ )
79
+
80
+ self.out_proj = quant_noise(
81
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
82
+ )
83
+
84
+ if add_bias_kv:
85
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
86
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
87
+ else:
88
+ self.bias_k = self.bias_v = None
89
+
90
+ self.add_zero_attn = add_zero_attn
91
+
92
+ self.reset_parameters()
93
+
94
+ self.onnx_trace = False
95
+
96
+ def prepare_for_onnx_export_(self):
97
+ self.onnx_trace = True
98
+
99
+ def reset_parameters(self):
100
+ if self.qkv_same_dim:
101
+ # Empirically observed the convergence to be much better with
102
+ # the scaled initialization
103
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
104
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
105
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
106
+ else:
107
+ nn.init.xavier_uniform_(self.k_proj.weight)
108
+ nn.init.xavier_uniform_(self.v_proj.weight)
109
+ nn.init.xavier_uniform_(self.q_proj.weight)
110
+
111
+ nn.init.xavier_uniform_(self.out_proj.weight)
112
+ if self.out_proj.bias is not None:
113
+ nn.init.constant_(self.out_proj.bias, 0.0)
114
+ if self.bias_k is not None:
115
+ nn.init.xavier_normal_(self.bias_k)
116
+ if self.bias_v is not None:
117
+ nn.init.xavier_normal_(self.bias_v)
118
+
119
+ def forward(
120
+ self,
121
+ query,
122
+ key: Optional[Tensor],
123
+ value: Optional[Tensor],
124
+ key_padding_mask: Optional[Tensor] = None,
125
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
126
+ need_weights: bool = True,
127
+ static_kv: bool = False,
128
+ attn_mask: Optional[Tensor] = None,
129
+ before_softmax: bool = False,
130
+ need_head_weights: bool = False,
131
+ position_bias: Optional[Tensor] = None
132
+ ) -> Tuple[Tensor, Optional[Tensor]]:
133
+ """Input shape: Time x Batch x Channel
134
+
135
+ Args:
136
+ key_padding_mask (ByteTensor, optional): mask to exclude
137
+ keys that are pads, of shape `(batch, src_len)`, where
138
+ padding elements are indicated by 1s.
139
+ need_weights (bool, optional): return the attention weights,
140
+ averaged over heads (default: False).
141
+ attn_mask (ByteTensor, optional): typically used to
142
+ implement causal attention, where the mask prevents the
143
+ attention from looking forward in time (default: None).
144
+ before_softmax (bool, optional): return the raw attention
145
+ weights and values before the attention softmax.
146
+ need_head_weights (bool, optional): return the attention
147
+ weights for each head. Implies *need_weights*. Default:
148
+ return the average attention weights over all heads.
149
+ """
150
+ if need_head_weights:
151
+ need_weights = True
152
+
153
+ is_tpu = query.device.type == "xla"
154
+
155
+ tgt_len, bsz, embed_dim = query.size()
156
+ src_len = tgt_len
157
+ assert embed_dim == self.embed_dim
158
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
159
+ if key is not None:
160
+ src_len, key_bsz, _ = key.size()
161
+ if not torch.jit.is_scripting():
162
+ assert key_bsz == bsz
163
+ assert value is not None
164
+ assert src_len, bsz == value.shape[:2]
165
+
166
+ if (
167
+ not self.onnx_trace
168
+ and not is_tpu # don't use PyTorch version on TPUs
169
+ and incremental_state is None
170
+ and not static_kv
171
+ # A workaround for quantization to work. Otherwise JIT compilation
172
+ # treats bias in linear module as method.
173
+ and not torch.jit.is_scripting()
174
+ and not self.has_relative_attention_bias
175
+ ):
176
+ assert key is not None and value is not None
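+ # Fast path: delegate to PyTorch's fused multi_head_attention_forward when there is no incremental state, static kv, or relative position bias.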
177
+ # Hawau: cast the attention mask to the query dtype so the fused kernel accepts it.
178
+ if attn_mask is not None and query.dtype != attn_mask.dtype:
179
+ attn_mask = attn_mask.type(query.dtype)
180
+ # My code ends here
181
+ return F.multi_head_attention_forward(
182
+ query,
183
+ key,
184
+ value,
185
+ self.embed_dim,
186
+ self.num_heads,
187
+ torch.empty([0]),
188
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
189
+ self.bias_k,
190
+ self.bias_v,
191
+ self.add_zero_attn,
192
+ self.dropout_module.p,
193
+ self.out_proj.weight,
194
+ self.out_proj.bias,
195
+ self.training or self.dropout_module.apply_during_inference,
196
+ key_padding_mask,
197
+ need_weights,
198
+ attn_mask,
199
+ use_separate_proj_weight=True,
200
+ q_proj_weight=self.q_proj.weight,
201
+ k_proj_weight=self.k_proj.weight,
202
+ v_proj_weight=self.v_proj.weight,
203
+ )
204
+
205
+ if incremental_state is not None:
206
+ saved_state = self._get_input_buffer(incremental_state)
207
+ if saved_state is not None and "prev_key" in saved_state:
208
+ # previous time steps are cached - no need to recompute
209
+ # key and value if they are static
210
+ if static_kv:
211
+ assert self.encoder_decoder_attention and not self.self_attention
212
+ key = value = None
213
+ else:
214
+ saved_state = None
215
+
216
+ if self.self_attention:
217
+ q = self.q_proj(query)
218
+ k = self.k_proj(query)
219
+ v = self.v_proj(query)
220
+ elif self.encoder_decoder_attention:
221
+ # encoder-decoder attention
222
+ q = self.q_proj(query)
223
+ if key is None:
224
+ assert value is None
225
+ k = v = None
226
+ else:
227
+ k = self.k_proj(key)
228
+ v = self.v_proj(key)
229
+
230
+ else:
231
+ assert key is not None and value is not None
232
+ q = self.q_proj(query)
233
+ k = self.k_proj(key)
234
+ v = self.v_proj(value)
235
+ q *= self.scaling
236
+
237
+ if self.bias_k is not None:
238
+ assert self.bias_v is not None
239
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
240
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
241
+ if attn_mask is not None:
242
+ attn_mask = torch.cat(
243
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
244
+ )
245
+ if key_padding_mask is not None:
246
+ key_padding_mask = torch.cat(
247
+ [
248
+ key_padding_mask,
249
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
250
+ ],
251
+ dim=1,
252
+ )
253
+
254
+ q = (
255
+ q.contiguous()
256
+ .view(tgt_len, bsz * self.num_heads, self.head_dim)
257
+ .transpose(0, 1)
258
+ )
259
+ if k is not None:
260
+ k = (
261
+ k.contiguous()
262
+ .view(-1, bsz * self.num_heads, self.head_dim)
263
+ .transpose(0, 1)
264
+ )
265
+ if v is not None:
266
+ v = (
267
+ v.contiguous()
268
+ .view(-1, bsz * self.num_heads, self.head_dim)
269
+ .transpose(0, 1)
270
+ )
271
+
272
+ if saved_state is not None:
273
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
274
+ if "prev_key" in saved_state:
275
+ _prev_key = saved_state["prev_key"]
276
+ assert _prev_key is not None
277
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
278
+ if static_kv:
279
+ k = prev_key
280
+ else:
281
+ assert k is not None
282
+ k = torch.cat([prev_key, k], dim=1)
283
+ src_len = k.size(1)
284
+ if "prev_value" in saved_state:
285
+ _prev_value = saved_state["prev_value"]
286
+ assert _prev_value is not None
287
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
288
+ if static_kv:
289
+ v = prev_value
290
+ else:
291
+ assert v is not None
292
+ v = torch.cat([prev_value, v], dim=1)
293
+ prev_key_padding_mask: Optional[Tensor] = None
294
+ if "prev_key_padding_mask" in saved_state:
295
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
296
+ assert k is not None and v is not None
297
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
298
+ key_padding_mask=key_padding_mask,
299
+ prev_key_padding_mask=prev_key_padding_mask,
300
+ batch_size=bsz,
301
+ src_len=k.size(1),
302
+ static_kv=static_kv,
303
+ )
304
+
305
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
306
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
307
+ saved_state["prev_key_padding_mask"] = key_padding_mask
308
+ # In this branch incremental_state is never None
309
+ assert incremental_state is not None
310
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
311
+ assert k is not None
312
+ assert k.size(1) == src_len
313
+
314
+ # This is part of a workaround to get around fork/join parallelism
315
+ # not supporting Optional types.
316
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
317
+ key_padding_mask = None
318
+
319
+ if key_padding_mask is not None:
320
+ assert key_padding_mask.size(0) == bsz
321
+ assert key_padding_mask.size(1) == src_len
322
+
323
+ if self.add_zero_attn:
324
+ assert v is not None
325
+ src_len += 1
326
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
327
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
328
+ if attn_mask is not None:
329
+ attn_mask = torch.cat(
330
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
331
+ )
332
+ if key_padding_mask is not None:
333
+ key_padding_mask = torch.cat(
334
+ [
335
+ key_padding_mask,
336
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
337
+ key_padding_mask
338
+ ),
339
+ ],
340
+ dim=1,
341
+ )
342
+
343
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
344
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
345
+
346
+ if position_bias is not None and self.has_relative_attention_bias: ## first order
347
+ ## position_bias: [241, 241, 64]
348
+ #print ("attn_weights: ", attn_weights.size()) # [492, 241, 241]
349
+ reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0,1) #[241, 492, 64]
350
+ #print ("reshape_q: ", reshape_q.size())
351
+ B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
352
+ #print ("B: ", B.size()) ## [241, 492, 241]
353
+ #B = B.transpose(0, 1).view(bsz, self.num_heads, position_bias.size(0), position_bias.size(1))
354
+ B = B.transpose(0, 1).view(bsz*self.num_heads, position_bias.size(0), position_bias.size(1))
355
+ #print ("B 2: ", B.size())
356
+ attn_weights += B
357
+ else:
358
+ position_bias = None
359
+
360
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
361
+
362
+ if attn_mask is not None:
363
+ attn_mask = attn_mask.unsqueeze(0)
364
+ if self.onnx_trace:
365
+ attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
366
+ attn_weights += attn_mask
367
+
368
+ if key_padding_mask is not None:
369
+ # don't attend to padding symbols
370
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
371
+ if not is_tpu:
372
+ attn_weights = attn_weights.masked_fill(
373
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
374
+ float("-inf"),
375
+ )
376
+ else:
377
+ attn_weights = attn_weights.transpose(0, 2)
378
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
379
+ attn_weights = attn_weights.transpose(0, 2)
380
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
381
+
382
+ if before_softmax:
383
+ return attn_weights, v
384
+
385
+ attn_weights_float = utils.softmax(
386
+ attn_weights, dim=-1, onnx_trace=self.onnx_trace
387
+ )
388
+ attn_weights = attn_weights_float.type_as(attn_weights)
389
+ attn_probs = self.dropout_module(attn_weights)
390
+
391
+ assert v is not None
392
+ attn = torch.bmm(attn_probs, v)
393
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
394
+ if self.onnx_trace and attn.size(1) == 1:
395
+ # when ONNX tracing a single decoder step (sequence length == 1)
396
+ # the transpose is a no-op copy before view, thus unnecessary
397
+ attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
398
+ else:
399
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
400
+ attn = self.out_proj(attn)
401
+ attn_weights: Optional[Tensor] = None
402
+ if need_weights:
403
+ attn_weights = attn_weights_float.view(
404
+ bsz, self.num_heads, tgt_len, src_len
405
+ ).transpose(1, 0)
406
+ if not need_head_weights:
407
+ # average attention weights over heads
408
+ attn_weights = attn_weights.mean(dim=0)
409
+
410
+ return attn, attn_weights
411
+
412
+ @staticmethod
413
+ def _append_prev_key_padding_mask(
414
+ key_padding_mask: Optional[Tensor],
415
+ prev_key_padding_mask: Optional[Tensor],
416
+ batch_size: int,
417
+ src_len: int,
418
+ static_kv: bool,
419
+ ) -> Optional[Tensor]:
420
+ # saved key padding masks have shape (bsz, seq_len)
421
+ if prev_key_padding_mask is not None and static_kv:
422
+ new_key_padding_mask = prev_key_padding_mask
423
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
424
+ new_key_padding_mask = torch.cat(
425
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
426
+ )
427
+ # During incremental decoding, as the padding token enters and
428
+ # leaves the frame, there will be a time when prev or current
429
+ # is None
430
+ elif prev_key_padding_mask is not None:
431
+ if src_len > prev_key_padding_mask.size(1):
432
+ filler = torch.zeros(
433
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
434
+ device=prev_key_padding_mask.device,
435
+ )
436
+ new_key_padding_mask = torch.cat(
437
+ [prev_key_padding_mask.float(), filler.float()], dim=1
438
+ )
439
+ else:
440
+ new_key_padding_mask = prev_key_padding_mask.float()
441
+ elif key_padding_mask is not None:
442
+ if src_len > key_padding_mask.size(1):
443
+ filler = torch.zeros(
444
+ (batch_size, src_len - key_padding_mask.size(1)),
445
+ device=key_padding_mask.device,
446
+ )
447
+ new_key_padding_mask = torch.cat(
448
+ [filler.float(), key_padding_mask.float()], dim=1
449
+ )
450
+ else:
451
+ new_key_padding_mask = key_padding_mask.float()
452
+ else:
453
+ new_key_padding_mask = prev_key_padding_mask
454
+ return new_key_padding_mask
455
+
456
+ @torch.jit.export
457
+ def reorder_incremental_state(
458
+ self,
459
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
460
+ new_order: Tensor,
461
+ ):
462
+ """Reorder buffered internal state (for incremental generation)."""
463
+ input_buffer = self._get_input_buffer(incremental_state)
464
+ if input_buffer is not None:
465
+ for k in input_buffer.keys():
466
+ input_buffer_k = input_buffer[k]
467
+ if input_buffer_k is not None:
468
+ if self.encoder_decoder_attention and input_buffer_k.size(
469
+ 0
470
+ ) == new_order.size(0):
471
+ break
472
+ input_buffer[k] = input_buffer_k.index_select(0, new_order)
473
+ incremental_state = self._set_input_buffer(incremental_state, input_buffer)
474
+ return incremental_state
475
+
476
+ def _get_input_buffer(
477
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
478
+ ) -> Dict[str, Optional[Tensor]]:
479
+ result = self.get_incremental_state(incremental_state, "attn_state")
480
+ if result is not None:
481
+ return result
482
+ else:
483
+ empty_result: Dict[str, Optional[Tensor]] = {}
484
+ return empty_result
485
+
486
+ def _set_input_buffer(
487
+ self,
488
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
489
+ buffer: Dict[str, Optional[Tensor]],
490
+ ):
491
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
492
+
493
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
494
+ return attn_weights
495
+
496
+ def upgrade_state_dict_named(self, state_dict, name):
497
+ prefix = name + "." if name != "" else ""
498
+ items_to_add = {}
499
+ keys_to_remove = []
500
+ for k in state_dict.keys():
501
+ if k.endswith(prefix + "in_proj_weight"):
502
+ # in_proj_weight used to be q + k + v with same dimensions
503
+ dim = int(state_dict[k].shape[0] / 3)
504
+ items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
505
+ items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
506
+ items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
507
+
508
+ keys_to_remove.append(k)
509
+
510
+ k_bias = prefix + "in_proj_bias"
511
+ if k_bias in state_dict.keys():
512
+ dim = int(state_dict[k].shape[0] / 3)
513
+ items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
514
+ items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
515
+ dim : 2 * dim
516
+ ]
517
+ items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
518
+
519
+ keys_to_remove.append(prefix + "in_proj_bias")
520
+
521
+ for k in keys_to_remove:
522
+ del state_dict[k]
523
+
524
+ for key, value in items_to_add.items():
525
+ state_dict[key] = value
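For reference, a minimal standalone sketch of how the first-order relative position bias computed above is folded into the attention scores. The sizes are hypothetical toy values and the snippet is illustrative only, not code from this repository; it mirrors the reshape_q / torch.matmul bookkeeping in the block above.

import torch

# Toy sizes (hypothetical): batch 2, heads 4, head_dim 64, sequence length 5.
bsz, num_heads, head_dim, seq_len = 2, 4, 64, 5

q = torch.randn(bsz * num_heads, seq_len, head_dim)           # per-head queries
k = torch.randn(bsz * num_heads, seq_len, head_dim)           # per-head keys
position_bias = torch.randn(seq_len, seq_len, head_dim)       # one embedding per (query, key) position pair

attn_weights = torch.bmm(q, k.transpose(1, 2))                # (B*H, tgt_len, src_len) raw scores

# First-order bias: contract each query with the bias embeddings of its own row.
reshape_q = q.transpose(0, 1)                                  # (tgt_len, B*H, head_dim)
B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))   # (tgt_len, B*H, src_len)
B = B.transpose(0, 1)                                          # (B*H, tgt_len, src_len)

attn_weights = attn_weights + B                                # same shape as the raw scores
print(attn_weights.shape)                                      # torch.Size([8, 5, 5])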
artst/models/modules/speaker_decoder_postnet.py ADDED
@@ -0,0 +1,196 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ import torch.nn as nn
10
+ import math
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class AngularMargin(nn.Module):
16
+ """
17
+ An implementation of Angular Margin (AM) proposed in the following
18
+ paper: '''Margin Matters: Towards More Discriminative Deep Neural Network
19
+ Embeddings for Speaker Recognition''' (https://arxiv.org/abs/1906.07317)
20
+
21
+ Arguments
22
+ ---------
23
+ margin : float
24
+ The margin for cosine similarity
25
+ scale : float
26
+ The scale for cosine similarity
27
+
28
+ Return
29
+ ---------
30
+ predictions : torch.Tensor
31
+
32
+ Example
33
+ -------
34
+ >>> pred = AngularMargin()
35
+ >>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
36
+ >>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ])
37
+ >>> predictions = pred(outputs, targets)
38
+ >>> predictions[:,0] > predictions[:,1]
39
+ tensor([ True, False, True, False])
40
+ """
41
+
42
+ def __init__(self, margin=0.0, scale=1.0):
43
+ super(AngularMargin, self).__init__()
44
+ self.margin = margin
45
+ self.scale = scale
46
+
47
+ def forward(self, outputs, targets):
48
+ """Compute AM between two tensors
49
+
50
+ Arguments
51
+ ---------
52
+ outputs : torch.Tensor
53
+ The outputs of shape [N, C], cosine similarity is required.
54
+ targets : torch.Tensor
55
+ The targets of shape [N, C], where the margin is applied for.
56
+
57
+ Return
58
+ ---------
59
+ predictions : torch.Tensor
60
+ """
61
+ outputs = outputs - self.margin * targets
62
+ return self.scale * outputs
63
+
64
+
65
+ class AdditiveAngularMargin(AngularMargin):
66
+ """
67
+ An implementation of Additive Angular Margin (AAM) proposed
68
+ in the following paper: '''Margin Matters: Towards More Discriminative Deep
69
+ Neural Network Embeddings for Speaker Recognition'''
70
+ (https://arxiv.org/abs/1906.07317)
71
+
72
+ Arguments
73
+ ---------
74
+ margin : float
75
+ The margin for cosine similarity, usually 0.2.
76
+ scale: float
77
+ The scale for cosine similarity, usually 30.
78
+
79
+ Returns
80
+ -------
81
+ predictions : torch.Tensor
82
+ Tensor.
83
+ Example
84
+ -------
85
+ >>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
86
+ >>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ])
87
+ >>> pred = AdditiveAngularMargin()
88
+ >>> predictions = pred(outputs, targets)
89
+ >>> predictions[:,0] > predictions[:,1]
90
+ tensor([ True, False, True, False])
91
+ """
92
+
93
+ def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
94
+ super(AdditiveAngularMargin, self).__init__(margin, scale)
95
+ self.easy_margin = easy_margin
96
+
97
+ self.cos_m = math.cos(self.margin)
98
+ self.sin_m = math.sin(self.margin)
99
+ self.th = math.cos(math.pi - self.margin)
100
+ self.mm = math.sin(math.pi - self.margin) * self.margin
101
+
102
+ def forward(self, outputs, targets):
103
+ """
104
+ Compute AAM between two tensors
105
+
106
+ Arguments
107
+ ---------
108
+ outputs : torch.Tensor
109
+ The outputs of shape [N, C], cosine similarity is required.
110
+ targets : torch.Tensor
111
+ The targets of shape [N, C], where the margin is applied for.
112
+
113
+ Return
114
+ ---------
115
+ predictions : torch.Tensor
116
+ """
117
+ cosine = outputs.float()
118
+ sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
119
+ phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m)
120
+ if self.easy_margin:
121
+ phi = torch.where(cosine > 0, phi, cosine)
122
+ else:
123
+ phi = torch.where(cosine > self.th, phi, cosine - self.mm)
124
+ outputs = (targets * phi) + ((1.0 - targets) * cosine)
125
+ return self.scale * outputs
126
+
127
+
128
+ class SpeakerDecoderPostnet(nn.Module):
129
+ """Speaker Identification Postnet.
130
+
131
+ Arguments
132
+ ---------
133
+ embed_dim : int
134
+ The size of embedding.
135
+ class_num: int
136
+ The number of classes.
137
+ args : Namespace
138
+
139
+ Return
140
+ ---------
141
+ embed : torch.Tensor
142
+ output : torch.Tensor
143
+ """
144
+
145
+ def __init__(self, embed_dim, class_num, args):
146
+ super(SpeakerDecoderPostnet, self).__init__()
147
+ self.embed_dim = embed_dim
148
+ self.class_num = class_num
149
+ self.no_pooling_bn = getattr(args, "sid_no_pooling_bn", False)
150
+ self.no_embed_postnet = getattr(args, "sid_no_embed_postnet", False)
151
+ self.normalize_postnet = getattr(args, "sid_normalize_postnet", False)
152
+ self.softmax_head = getattr(args, "sid_softmax_type", "softmax")
153
+ if not self.no_pooling_bn:
154
+ self.bn_pooling = nn.BatchNorm1d(args.decoder_output_dim)
155
+ else:
156
+ self.bn_pooling = None
157
+ if not self.no_embed_postnet:
158
+ self.output_embedding = nn.Linear(args.decoder_output_dim, embed_dim, bias=False)
159
+ self.bn_embedding = nn.BatchNorm1d(embed_dim)
160
+ else:
161
+ self.output_embedding = None
162
+ self.bn_embedding = None
163
+ self.embed_dim = args.decoder_output_dim
164
+ self.output_projection = nn.Linear(self.embed_dim, class_num, bias=False)
165
+ if self.softmax_head == "amsoftmax":
166
+ self.output_layer = AngularMargin(args.softmax_margin, args.softmax_scale)
167
+ elif self.softmax_head == "aamsoftmax":
168
+ self.output_layer = AdditiveAngularMargin(args.softmax_margin, args.softmax_scale, args.softmax_easy_margin)
169
+ else:
170
+ self.output_layer = None
171
+ if self.output_embedding is not None:
172
+ nn.init.normal_(self.output_embedding.weight, mean=0, std=embed_dim ** -0.5)
173
+ nn.init.normal_(self.output_projection.weight, mean=0, std=class_num ** -0.5)
174
+
175
+ def forward(self, x, target=None):
176
+ """
177
+ Parameters
178
+ ----------
179
+ x : torch.Tensor of shape [batch, channel] or [batch, time, channel]
180
+ target : torch.Tensor of shape [batch, channel]
181
+ """
182
+ if self.bn_pooling is not None:
183
+ x = self.bn_pooling(x)
184
+ if self.output_embedding is not None and self.bn_embedding is not None:
185
+ embed = self.bn_embedding(self.output_embedding(x))
186
+ else:
187
+ embed = x
188
+ if self.output_layer is not None or self.normalize_postnet:
189
+ x_norm = F.normalize(embed, p=2, dim=1)
190
+ w_norm = F.normalize(self.output_projection.weight, p=2, dim=1) # [out_dim, in_dim]
191
+ output = F.linear(x_norm, w_norm)
192
+ if self.training and target is not None and self.output_layer is not None:
193
+ output = self.output_layer(output, target)
194
+ else:
195
+ output = self.output_projection(embed)
196
+ return output, embed
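For reference, a small self-contained sketch of the AAM-softmax margin that AdditiveAngularMargin applies during training: the target class keeps cos(theta + m) while all other classes keep cos(theta), and the result is scaled before the cross-entropy. The sizes and the 0.2/30 values below are hypothetical placeholders, not values read from this repository.

import math
import torch
import torch.nn.functional as F

# Hypothetical setup: 4 utterances, 3 speaker classes, 192-dim embeddings.
margin, scale = 0.2, 30.0
embeddings = F.normalize(torch.randn(4, 192), dim=1)        # speaker embeddings
weights = F.normalize(torch.randn(3, 192), dim=1)           # one prototype per speaker class
cosine = embeddings @ weights.t()                           # (4, 3) cosine similarities
targets = F.one_hot(torch.tensor([0, 2, 1, 0]), 3).float()  # one-hot speaker labels

# AAM-softmax: replace cos(theta) with cos(theta + m) for the target class only, then scale.
sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(0, 1))
phi = cosine * math.cos(margin) - sine * math.sin(margin)   # cos(theta + m)
logits = scale * (targets * phi + (1.0 - targets) * cosine)
loss = F.cross_entropy(logits, targets.argmax(dim=1))
print(loss.item())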
artst/models/modules/speech_decoder_postnet.py ADDED
@@ -0,0 +1,75 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ import contextlib
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet
14
+
15
+
16
+ class SpeechDecoderPostnet(nn.Module):
17
+ """
18
+
19
+ Args:
20
+ in_channels (int): the number of input channels
21
+ mid_channels (int): the number of intermediate channels
22
+ out_channels (int): the number of output channels
23
+ kernel_sizes (List[int]): the kernel size for each convolutional layer
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ odim,
29
+ args,
30
+ ):
31
+ super(SpeechDecoderPostnet, self).__init__()
32
+ # define decoder postnet
33
+ # define final projection
34
+ self.feat_out = torch.nn.Linear(args.decoder_embed_dim, odim * args.reduction_factor)
35
+ self.prob_out = torch.nn.Linear(args.decoder_embed_dim, args.reduction_factor)
36
+
37
+ # define postnet
38
+ self.postnet = (
39
+ None
40
+ if args.postnet_layers == 0
41
+ else Postnet(
42
+ idim=0,
43
+ odim=odim,
44
+ n_layers=args.postnet_layers,
45
+ n_chans=args.postnet_chans,
46
+ n_filts=args.postnet_filts,
47
+ use_batch_norm=args.use_batch_norm,
48
+ dropout_rate=args.postnet_dropout_rate,
49
+ )
50
+ )
51
+
52
+ self.odim = odim
53
+ self.num_updates = 0
54
+ self.freeze_decoder_updates = args.freeze_decoder_updates
55
+
56
+ def forward(self, zs):
57
+ ft = self.freeze_decoder_updates <= self.num_updates
58
+ with torch.no_grad() if not ft else contextlib.ExitStack():
59
+ # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
60
+ before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
61
+ # (B, Lmax//r, r) -> (B, Lmax//r * r)
62
+ logits = self.prob_out(zs).view(zs.size(0), -1)
63
+ # postnet -> (B, Lmax//r * r, odim)
64
+ if self.postnet is None:
65
+ after_outs = before_outs
66
+ else:
67
+ after_outs = before_outs + self.postnet(
68
+ before_outs.transpose(1, 2)
69
+ ).transpose(1, 2)
70
+
71
+ return before_outs, after_outs, logits
72
+
73
+ def set_num_updates(self, num_updates):
74
+ """Set the number of parameters updates."""
75
+ self.num_updates = num_updates
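For reference, a toy sketch of the reduction-factor reshape performed by feat_out and prob_out above: each decoder step predicts r feature frames and r stop logits, which are then unpacked along the time axis. All sizes are hypothetical and the snippet is illustrative only.

import torch

# Hypothetical sizes: batch 2, 10 decoder steps, 80-dim mel features, reduction_factor r = 2.
bsz, steps, odim, r = 2, 10, 80, 2
decoder_embed_dim = 768

zs = torch.randn(bsz, steps, decoder_embed_dim)                # decoder hidden states
feat_out = torch.nn.Linear(decoder_embed_dim, odim * r)        # predicts r frames per decoder step
prob_out = torch.nn.Linear(decoder_embed_dim, r)               # one stop logit per predicted frame

before_outs = feat_out(zs).view(bsz, -1, odim)                 # (2, 20, 80): frames unpacked along time
logits = prob_out(zs).view(bsz, -1)                            # (2, 20): stop-token logits
print(before_outs.shape, logits.shape)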
artst/models/modules/speech_decoder_prenet.py ADDED
@@ -0,0 +1,109 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ import contextlib
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ import torch.nn.functional as F
14
+ from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as TacotronDecoderPrenet
15
+ from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
16
+ from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding
17
+ from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
18
+
19
+
20
+ class SpeechDecoderPrenet(nn.Module):
21
+ """
22
+
23
+ Args:
24
+ in_channels (int): the number of input channels
25
+ mid_channels (int): the number of intermediate channels
26
+ out_channels (int): the number of output channels
27
+ kernel_sizes (List[int]): the kernel size for each convolutional layer
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ odim,
33
+ args,
34
+ ):
35
+ super(SpeechDecoderPrenet, self).__init__()
36
+ # define decoder prenet
37
+ if args.dprenet_layers != 0:
38
+ # decoder prenet
39
+ decoder_input_layer = torch.nn.Sequential(
40
+ TacotronDecoderPrenet(
41
+ idim=odim,
42
+ n_layers=args.dprenet_layers,
43
+ n_units=args.dprenet_units,
44
+ dropout_rate=args.dprenet_dropout_rate,
45
+ ),
46
+ torch.nn.Linear(args.dprenet_units, args.decoder_embed_dim),
47
+ )
48
+ else:
49
+ decoder_input_layer = "linear"
50
+
51
+ pos_enc_class = (
52
+ ScaledPositionalEncoding if args.dec_use_scaled_pos_enc else PositionalEncoding
53
+ )
54
+
55
+ if decoder_input_layer == "linear":
56
+ self.decoder_prenet = torch.nn.Sequential(
57
+ torch.nn.Linear(odim, args.decoder_embed_dim),
58
+ torch.nn.LayerNorm(args.decoder_embed_dim),
59
+ torch.nn.Dropout(args.transformer_dec_dropout_rate),
60
+ torch.nn.ReLU(),
61
+ pos_enc_class(args.decoder_embed_dim, args.transformer_dec_positional_dropout_rate),
62
+ )
63
+ elif isinstance(decoder_input_layer, torch.nn.Module):
64
+ self.decoder_prenet = torch.nn.Sequential(
65
+ decoder_input_layer, pos_enc_class(args.decoder_embed_dim, args.transformer_dec_positional_dropout_rate, max_len=args.max_speech_positions)
66
+ )
67
+
68
+ if args.spk_embed_integration_type == 'pre':
69
+ self.spkembs_layer = torch.nn.Sequential(
70
+ torch.nn.Linear(args.spk_embed_dim + args.decoder_embed_dim, args.decoder_embed_dim), torch.nn.ReLU()
71
+ )
72
+ self.num_updates = 0
73
+ self.freeze_decoder_updates = args.freeze_decoder_updates
74
+
75
+ def forward(self, prev_output_tokens, tgt_lengths_in=None, spkembs=None):
76
+ ft = self.freeze_decoder_updates <= self.num_updates
77
+ with torch.no_grad() if not ft else contextlib.ExitStack():
78
+ prev_output_tokens = self.decoder_prenet(prev_output_tokens)
79
+
80
+ if spkembs is not None:
81
+ spkembs = F.normalize(spkembs).unsqueeze(1).expand(-1, prev_output_tokens.size(1), -1)
82
+ prev_output_tokens = self.spkembs_layer(torch.cat([prev_output_tokens, spkembs], dim=-1))
83
+
84
+ if tgt_lengths_in is not None:
85
+ tgt_frames_mask = ~(self._source_mask(tgt_lengths_in).squeeze(1))
86
+ else:
87
+ tgt_frames_mask = None
88
+ return prev_output_tokens, tgt_frames_mask
89
+
90
+ def _source_mask(self, ilens):
91
+ """Make masks for self-attention.
92
+ Args:
93
+ ilens (LongTensor or List): Batch of lengths (B,).
94
+ Returns:
95
+ Tensor: Mask tensor for self-attention.
96
+ dtype=torch.uint8 in PyTorch < 1.2
97
+ dtype=torch.bool in PyTorch >= 1.2
98
+ Examples:
99
+ >>> ilens = [5, 3]
100
+ >>> self._source_mask(ilens)
101
+ tensor([[[1, 1, 1, 1, 1]],
102
+ [[1, 1, 1, 0, 0]]], dtype=torch.uint8)
103
+ """
104
+ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
105
+ return x_masks.unsqueeze(-2)
106
+
107
+ def set_num_updates(self, num_updates):
108
+ """Set the number of parameters updates."""
109
+ self.num_updates = num_updates
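For reference, a standalone sketch of what _source_mask computes, with the espnet make_non_pad_mask helper re-derived from lengths so the example runs on its own; the lengths are hypothetical.

import torch

# Hypothetical lengths: two target sequences of 5 and 3 frames.
ilens = torch.tensor([5, 3])
max_len = int(ilens.max())

# Equivalent of make_non_pad_mask(ilens): True where a frame is real, False on padding.
non_pad = torch.arange(max_len).unsqueeze(0) < ilens.unsqueeze(1)  # (B, Lmax)
x_masks = non_pad.unsqueeze(-2)                                    # (B, 1, Lmax), as returned by _source_mask
tgt_frames_mask = ~x_masks.squeeze(1)                              # True on padded frames, as used in forward
print(non_pad)
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False]])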
artst/models/modules/speech_encoder_postnet.py ADDED
@@ -0,0 +1,123 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ import logging
10
+ import torch.nn as nn
11
+ import torch
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ class SpeechEncoderPostnet(nn.Module):
17
+ """
18
+
19
+ Args:
20
+ in_channels (int): the number of input channels
21
+ mid_channels (int): the number of intermediate channels
22
+ out_channels (int): the number of output channels
23
+ kernel_sizes (List[int]): the kernel size for each convolutional layer
24
+ """
25
+
26
+ def __init__(self, dictionaries, args):
27
+ super(SpeechEncoderPostnet, self).__init__()
28
+ # modules below are not needed during fine-tuning
29
+ self.target_glu = args.target_glu
30
+ self.skip_masked = args.skip_masked
31
+ self.skip_nomask = args.skip_nomask
32
+ self.logit_temp = args.logit_temp
33
+
34
+ final_dim = (
35
+ args.final_dim if args.final_dim > 0 else args.encoder_embed_dim
36
+ )
37
+ if any([d is None for d in dictionaries]):
38
+ logger.info(
39
+ "cannot find dictionary. assume will be used for fine-tuning"
40
+ )
41
+ else:
42
+ self.num_classes = [len(d) for d in dictionaries]
43
+ self.label_embs_concat = nn.Parameter(
44
+ torch.FloatTensor(sum(self.num_classes), final_dim)
45
+ )
46
+ nn.init.uniform_(self.label_embs_concat)
47
+ self.untie_final_proj = args.untie_final_proj
48
+ if self.untie_final_proj:
49
+ self.final_proj = nn.Linear(
50
+ args.encoder_embed_dim, final_dim * len(dictionaries)
51
+ )
52
+ else:
53
+ self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)
54
+
55
+ def compute_nce(self, x, pos, negs):
56
+ neg_is_pos = (pos == negs).all(-1)
57
+ pos = pos.unsqueeze(0)
58
+ targets = torch.cat([pos, negs], dim=0)
59
+
60
+ logits = torch.cosine_similarity(
61
+ x.float(), targets.float(), dim=-1
62
+ ).type_as(x)
63
+ logits /= self.logit_temp
64
+ if neg_is_pos.any():
65
+ logits[1:][neg_is_pos] = float("-inf")
66
+ logits = logits.transpose(0, 1) # (num_x, num_cls+1)
67
+ return logits
68
+
69
+ def forward(self, x, padding_mask, mask_indices, target_list):
70
+ def compute_pred(proj_x, target, label_embs):
71
+ # compute logits for the i-th label set
72
+ y = torch.index_select(label_embs, 0, target.long())
73
+ negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
74
+ if self.target_glu:
75
+ y = self.target_glu(y)
76
+ negs = self.target_glu(negs)
77
+ # proj_x: (S, D)
78
+ # y: (S, D)
79
+ # negs: (Neg, S, D)
80
+ return self.compute_nce(proj_x, y, negs)
81
+
82
+ label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
83
+
84
+ if not self.skip_masked:
85
+ masked_indices = torch.logical_and(~padding_mask, mask_indices)
86
+ proj_x_m = self.final_proj(x[masked_indices])
87
+ if self.untie_final_proj:
88
+ proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
89
+ else:
90
+ proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
91
+ logit_m_list = [
92
+ compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
93
+ for i, (proj_x_m, t) in enumerate(
94
+ zip(proj_x_m_list, target_list)
95
+ )
96
+ ]
97
+ else:
98
+ logit_m_list = [None for _ in target_list]
99
+
100
+ if not self.skip_nomask:
101
+ nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
102
+ proj_x_u = self.final_proj(x[nomask_indices])
103
+ if self.untie_final_proj:
104
+ proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
105
+ else:
106
+ proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
107
+
108
+ logit_u_list = [
109
+ compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
110
+ for i, (proj_x_u, t) in enumerate(
111
+ zip(proj_x_u_list, target_list)
112
+ )
113
+ ]
114
+ else:
115
+ logit_u_list = [None for _ in target_list]
116
+
117
+ result = {
118
+ "logit_m_list": logit_m_list,
119
+ "logit_u_list": logit_u_list,
120
+ "padding_mask": padding_mask,
121
+ }
122
+
123
+ return result
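For reference, a compact sketch of the masked-prediction NCE objective that compute_nce implements above: the positive label embedding is stacked in front of all label embeddings acting as negatives, and temperature-scaled cosine similarities become logits where class 0 is the correct one. Sizes are hypothetical and the snippet is illustrative only.

import torch
import torch.nn.functional as F

# Hypothetical sizes: 6 masked frames, 4 discrete target labels, 256-dim projections.
logit_temp = 0.1
proj_x = torch.randn(6, 256)              # projected encoder states at masked positions
label_embs = torch.randn(4, 256)          # one learned embedding per target label
targets = torch.randint(0, 4, (6,))       # frame-level target labels

pos = label_embs[targets]                                 # positive embedding per frame, (6, 256)
negs = label_embs.unsqueeze(1).expand(-1, 6, -1)          # every label as a negative, (4, 6, 256)
candidates = torch.cat([pos.unsqueeze(0), negs], dim=0)   # (5, 6, 256): positive stacked first

logits = torch.cosine_similarity(proj_x.unsqueeze(0), candidates, dim=-1) / logit_temp
logits = logits.transpose(0, 1)                           # (6, 5): class 0 is the positive
loss = F.cross_entropy(logits, torch.zeros(6, dtype=torch.long))
print(loss.item())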
artst/models/modules/speech_encoder_prenet.py ADDED
@@ -0,0 +1,373 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ import logging
10
+ import math
11
+ import torch
12
+ import contextlib
13
+ from typing import List, Tuple
14
+ import torch.nn as nn
15
+
16
+ from fairseq.data.data_utils import lengths_to_padding_mask
17
+ from fairseq.data.data_utils import compute_mask_indices
18
+ from fairseq.modules import (
19
+ PositionalEmbedding,
20
+ Fp32GroupNorm,
21
+ FairseqDropout,
22
+ SamePad,
23
+ GradMultiply,
24
+ LayerNorm,
25
+ Fp32LayerNorm,
26
+ TransposeLast,
27
+ )
28
+ import numpy as np
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ class LinearLayer(nn.Module):
34
+ def __init__(self, idim, odim, dropout=0):
35
+ super(LinearLayer, self).__init__()
36
+ self.linear = nn.Sequential(
37
+ nn.Linear(idim, odim),
38
+ nn.LayerNorm(odim),
39
+ nn.Dropout(dropout),
40
+ nn.ReLU(),
41
+ )
42
+
43
+ def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
44
+ out = in_seq_lens_tensor.clone()
45
+ return out
46
+
47
+ def forward(self, src_tokens, src_lengths):
48
+ """
49
+ src_tokens: [B, T, C]
50
+ src_lengths: [B]
51
+ """
52
+ x = self.linear(src_tokens)
53
+ x = x.transpose(0, 1).contiguous() # -> T x B x C
54
+ return x, src_lengths
55
+
56
+
57
+ class SpeechEncoderPrenet(nn.Module):
58
+ """
59
+
60
+ Args:
61
+ in_channels (int): the number of input channels
62
+ mid_channels (int): the number of intermediate channels
63
+ out_channels (int): the number of output channels
64
+ kernel_sizes (List[int]): the kernel size for each convolutional layer
65
+ """
66
+
67
+ def __init__(self, args):
68
+ super(SpeechEncoderPrenet, self).__init__()
69
+ self.dropout_module = FairseqDropout(
70
+ p=args.dropout, module_name=self.__class__.__name__
71
+ )
72
+ self.embed_scale = math.sqrt(args.encoder_embed_dim)
73
+ if args.no_scale_embedding:
74
+ self.embed_scale = 1.0
75
+ self.padding_idx = 1
76
+ self.freeze_encoder_updates = args.freeze_encoder_updates
77
+ self.num_updates = 0
78
+ assert args.encoder_speech_prenet in ["conv", "linear"], args.encoder_speech_prenet
79
+ feature_enc_layers = eval(args.conv_feature_layers) # noqa
80
+ self.embed = feature_enc_layers[-1][0]
81
+
82
+ self.feature_extractor = ConvFeatureExtractionModel(
83
+ conv_layers=feature_enc_layers,
84
+ dropout=0.0,
85
+ mode=args.extractor_mode,
86
+ conv_bias=args.conv_bias,
87
+ )
88
+ feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
89
+ self.feat2tar_ratio = (
90
+ args.label_rates * feature_ds_rate / args.sample_rate
91
+ )
92
+
93
+ self.post_extract_proj = (
94
+ nn.Linear(self.embed, args.encoder_embed_dim)
95
+ if self.embed != args.encoder_embed_dim
96
+ else None
97
+ )
98
+
99
+ self.use_conv_pos = args.use_conv_pos
100
+ self.use_sinc_pos = args.use_sinc_pos
101
+ self.use_abs_pos = getattr(args, "use_abs_pos", False)
102
+
103
+ self.feature_grad_mult = args.feature_grad_mult
104
+ if self.use_conv_pos:
105
+ self.layer_norm = LayerNorm(self.embed)
106
+ self.pos_conv = nn.Conv1d(
107
+ args.encoder_embed_dim,
108
+ args.encoder_embed_dim,
109
+ kernel_size=args.conv_pos,
110
+ padding=args.conv_pos // 2,
111
+ groups=args.conv_pos_groups,
112
+ )
113
+ dropout = 0
114
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * args.encoder_embed_dim))
115
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
116
+ nn.init.constant_(self.pos_conv.bias, 0)
117
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
118
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
119
+
120
+ assert not (self.use_sinc_pos and self.use_abs_pos), f"sinc pos: {self.use_sinc_pos} abs pos: {self.use_abs_pos}"
121
+ if self.use_sinc_pos:
122
+ self.embed_positions = PositionalEmbedding(
123
+ args.max_speech_positions, args.encoder_embed_dim, self.padding_idx
124
+ )
125
+ if self.use_abs_pos:
126
+ self.embed_positions = PositionalEmbedding(
127
+ args.max_speech_positions, args.encoder_embed_dim, self.padding_idx, learned=True
128
+ )
129
+
130
+ # Hubert
131
+ self.mask_prob = args.mask_prob
132
+ self.mask_selection = args.mask_selection
133
+ self.mask_other = args.mask_other
134
+ self.hubert_mask_length = args.hubert_mask_length
135
+ self.no_mask_overlap = args.no_mask_overlap
136
+ self.mask_min_space = args.mask_min_space
137
+
138
+ self.mask_channel_prob = args.mask_channel_prob
139
+ self.mask_channel_selection = args.mask_channel_selection
140
+ self.mask_channel_other = args.mask_channel_other
141
+ self.mask_channel_length = args.mask_channel_length
142
+ self.no_mask_channel_overlap = args.no_mask_channel_overlap
143
+ self.mask_channel_min_space = args.mask_channel_min_space
144
+
145
+ self.mask_emb = nn.Parameter(
146
+ torch.FloatTensor(args.encoder_embed_dim).uniform_()
147
+ )
148
+
149
+ def forward(self, src_tokens, require_feat_pen=False, target_list=None, padding_mask=None, mask=True):
150
+ ft = self.freeze_encoder_updates <= self.num_updates
151
+ with torch.no_grad() if not ft else contextlib.ExitStack():
152
+ return self._forward(src_tokens, require_feat_pen, target_list, padding_mask, mask)
153
+
154
+ def _forward(self, src_tokens, require_feat_pen=False, target_list=None, padding_mask=None, mask=True):
155
+ if self.feature_grad_mult > 0:
156
+ x = self.feature_extractor(src_tokens)
157
+ x = x.transpose(1, 2).transpose(0, 1) # [length, batch, hidden_size]
158
+ if self.feature_grad_mult != 1.0:
159
+ x = GradMultiply.apply(x, self.feature_grad_mult)
160
+ else:
161
+ with torch.no_grad():
162
+ x = self.feature_extractor(src_tokens)
163
+ x = x.transpose(1, 2).transpose(0, 1) # [length, batch, hidden_size]
164
+ x = x.transpose(0, 1) # [batch, length, hidden_size]
165
+
166
+ encoder_padding_mask = padding_mask
167
+
168
+ x = x.transpose(1, 2) # [batch, hidden_size, length]
169
+ if target_list is not None:
170
+ x, target_list = self.forward_targets(x, target_list)
171
+ features_pen = x.float().pow(2).mean()
172
+ x = x.transpose(1, 2) # [batch, length, hidden_size]
173
+ x = self.layer_norm(x)
174
+ encoder_padding_mask = self.forward_padding_mask(x, encoder_padding_mask)
175
+ if self.post_extract_proj is not None:
176
+ x = self.post_extract_proj(x)
177
+ x = self.dropout_module(x)
178
+ if mask:
179
+ x, mask_indices = self.apply_hubert_mask(
180
+ x, encoder_padding_mask
181
+ )
182
+ else:
183
+ x = x
184
+ mask_indices = None
185
+
186
+ if self.use_conv_pos:
187
+ positions = self.pos_conv(x.transpose(1, 2))
188
+ positions = positions.transpose(1, 2)
189
+ #else:
190
+ # positions = self.embed_positions(encoder_padding_mask)
191
+ x = x + positions
192
+
193
+ if self.use_sinc_pos:
194
+ positions = self.embed_positions(encoder_padding_mask)
195
+ x = x + positions
196
+
197
+ # x = self.dropout_module(x)
198
+
199
+ if require_feat_pen:
200
+ return (x, features_pen, mask_indices, target_list), encoder_padding_mask
201
+ else:
202
+ # For consistency with the encoder
203
+ return x, encoder_padding_mask
204
+
205
+ def forward_targets(
206
+ self, features: torch.Tensor, target_list: List[torch.Tensor],
207
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
208
+ # Trim features to ensure labels exist and then get aligned labels
209
+ feat_tsz = features.size(2)
210
+ targ_tsz = min([t.size(1) for t in target_list])
211
+ if self.feat2tar_ratio * feat_tsz > targ_tsz:
212
+ feat_tsz = int(targ_tsz / self.feat2tar_ratio)
213
+ features = features[..., :feat_tsz]
214
+ target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
215
+ target_list = [t[:, target_inds.long()] for t in target_list]
216
+ return features, target_list
217
+
218
+ def forward_padding_mask(
219
+ self, features: torch.Tensor, padding_mask: torch.Tensor,
220
+ ) -> torch.Tensor:
221
+ extra = padding_mask.size(1) % features.size(1)
222
+ if extra > 0:
223
+ padding_mask = padding_mask[:, :-extra]
224
+ padding_mask = padding_mask.view(
225
+ padding_mask.size(0), features.size(1), -1
226
+ )
227
+ padding_mask = padding_mask.all(-1)
228
+ return padding_mask
229
+
230
+ def get_src_lengths(self, src_lengths):
231
+ return self.feature_extractor.get_out_seq_lens_tensor(src_lengths)
232
+
233
+ def apply_hubert_mask(self, x, padding_mask):
234
+ B, T, C = x.shape
235
+ if self.mask_prob > 0:
236
+ mask_indices = compute_mask_indices(
237
+ (B, T),
238
+ padding_mask,
239
+ self.mask_prob,
240
+ self.hubert_mask_length,
241
+ self.mask_selection,
242
+ self.mask_other,
243
+ min_masks=2,
244
+ no_overlap=self.no_mask_overlap,
245
+ min_space=self.mask_min_space,
246
+ )
247
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
248
+ x[mask_indices] = self.mask_emb
249
+ else:
250
+ mask_indices = None
251
+
252
+ if self.mask_channel_prob > 0:
253
+ mask_channel_indices = compute_mask_indices(
254
+ (B, C),
255
+ None,
256
+ self.mask_channel_prob,
257
+ self.mask_channel_length,
258
+ self.mask_channel_selection,
259
+ self.mask_channel_other,
260
+ no_overlap=self.no_mask_channel_overlap,
261
+ min_space=self.mask_channel_min_space,
262
+ )
263
+ mask_channel_indices = (
264
+ torch.from_numpy(mask_channel_indices)
265
+ .to(x.device)
266
+ .unsqueeze(1)
267
+ .expand(-1, T, -1)
268
+ )
269
+ x[mask_channel_indices] = 0
270
+
271
+ return x, mask_indices
272
+
273
+ def set_num_updates(self, num_updates):
274
+ """Set the number of parameters updates."""
275
+ self.num_updates = num_updates
276
+
277
+ class ConvFeatureExtractionModel(nn.Module):
278
+ def __init__(
279
+ self,
280
+ conv_layers: List[Tuple[int, int, int]],
281
+ dropout: float = 0.0,
282
+ mode: str = "default",
283
+ conv_bias: bool = False,
284
+ ):
285
+ super().__init__()
286
+
287
+ assert mode in {"default", "layer_norm"}
288
+
289
+ def block(
290
+ n_in,
291
+ n_out,
292
+ k,
293
+ stride,
294
+ is_layer_norm=False,
295
+ is_group_norm=False,
296
+ conv_bias=False,
297
+ ):
298
+ def make_conv():
299
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
300
+ nn.init.kaiming_normal_(conv.weight)
301
+ return conv
302
+
303
+ assert (
304
+ is_layer_norm and is_group_norm
305
+ ) == False, "layer norm and group norm are exclusive"
306
+
307
+ if is_layer_norm:
308
+ return nn.Sequential(
309
+ make_conv(),
310
+ nn.Dropout(p=dropout),
311
+ nn.Sequential(
312
+ TransposeLast(),
313
+ Fp32LayerNorm(dim, elementwise_affine=True),
314
+ TransposeLast(),
315
+ ),
316
+ nn.GELU(),
317
+ )
318
+ elif is_group_norm:
319
+ return nn.Sequential(
320
+ make_conv(),
321
+ nn.Dropout(p=dropout),
322
+ Fp32GroupNorm(dim, dim, affine=True),
323
+ nn.GELU(),
324
+ )
325
+ else:
326
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
327
+
328
+ in_d = 1
329
+ self.conv_layers = nn.ModuleList()
330
+ self.conv_layers_infos = conv_layers
331
+ for i, cl in enumerate(conv_layers):
332
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
333
+ (dim, k, stride) = cl
334
+
335
+ self.conv_layers.append(
336
+ block(
337
+ in_d,
338
+ dim,
339
+ k,
340
+ stride,
341
+ is_layer_norm=mode == "layer_norm",
342
+ is_group_norm=mode == "default" and i == 0,
343
+ conv_bias=conv_bias,
344
+ )
345
+ )
346
+ in_d = dim
347
+
348
+ def forward(self, x):
349
+ # BxT -> BxCxT
350
+ x = x.unsqueeze(1)
351
+ for conv in self.conv_layers:
352
+ x = conv(x)
353
+ return x
354
+
355
+ def get_out_seq_lens_nonmask_after_a_layer(self, in_seq_lens_tensor, i):
356
+ """Returns the out_seq_lens_nonmask 0/1 tensor after a layer.
357
+
358
+ Args:
359
+ in_seq_lens_tensor (LongTensor): length
360
+
361
+ Returns:
362
+ LongTensor: length
363
+ """
364
+ out_lengths = in_seq_lens_tensor.clone()
365
+ out_lengths = ((out_lengths.float() - (self.conv_layers_infos[i][1] - 1) - 1) / self.conv_layers_infos[i][-1] + 1).floor().long()
366
+ out_nonmask = (~lengths_to_padding_mask(out_lengths)).float()
367
+ return out_nonmask, out_lengths
368
+
369
+ def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
370
+ out = in_seq_lens_tensor.clone()
371
+ for i in range(len(self.conv_layers)):
372
+ out = ((out.float() - (self.conv_layers_infos[i][1] - 1) - 1) / self.conv_layers_infos[i][-1] + 1).floor().long()
373
+ return out
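For reference, the length bookkeeping done by get_out_seq_lens_tensor in one place. The conv stack below is the common wav2vec2/HuBERT default "[(512,10,5)] + [(512,3,2)]*4 + [(512,2,2)]*2" and is assumed here for illustration rather than read from this Space's config.

import torch

conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

def out_lengths(in_lengths, layers):
    out = in_lengths.clone()
    for _, k, stride in layers:
        # Same formula as get_out_seq_lens_tensor: floor((L - (k - 1) - 1) / stride + 1)
        out = ((out.float() - (k - 1) - 1) / stride + 1).floor().long()
    return out

print(out_lengths(torch.tensor([16000, 48000]), conv_layers))
# tensor([ 49, 149]): roughly 50 frames per second of 16 kHz audio.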
artst/models/modules/text_decoder_postnet.py ADDED
@@ -0,0 +1,92 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ import torch.nn as nn
10
+ import torch
11
+ import contextlib
12
+
13
+ from fairseq import utils
14
+ from fairseq.modules import (
15
+ AdaptiveSoftmax,
16
+ )
17
+
18
+ class TextDecoderPostnet(nn.Module):
19
+ """
20
+
21
+ Args:
22
+ in_channels (int): the number of input channels
23
+ mid_channels (int): the number of intermediate channels
24
+ out_channels (int): the number of output channels
25
+ kernel_sizes (List[int]): the kernel size for each convolutional layer
26
+ """
27
+
28
+ def __init__(self, embed_tokens, dictionary, args, output_projection=None,):
29
+ super(TextDecoderPostnet, self).__init__()
30
+ self.output_embed_dim = args.decoder_output_dim
31
+ self.output_projection = output_projection
32
+ self.adaptive_softmax = None
33
+ self.share_input_output_embed = args.share_input_output_embed
34
+ if self.output_projection is None:
35
+ self.build_output_projection(args, dictionary, embed_tokens)
36
+ self.freeze_decoder_updates = args.freeze_decoder_updates
37
+ self.num_updates = 0
38
+
39
+ def output_layer(self, features):
40
+ """Project features to the vocabulary size."""
41
+ if self.adaptive_softmax is None:
42
+ # project back to size of vocabulary
43
+ return self.output_projection(features)
44
+ else:
45
+ return features
46
+
47
+ def build_output_projection(self, args, dictionary, embed_tokens):
48
+ if args.adaptive_softmax_cutoff is not None:
49
+ self.adaptive_softmax = AdaptiveSoftmax(
50
+ len(dictionary),
51
+ self.output_embed_dim,
52
+ utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
53
+ dropout=args.adaptive_softmax_dropout,
54
+ adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
55
+ factor=args.adaptive_softmax_factor,
56
+ tie_proj=args.tie_adaptive_proj,
57
+ )
58
+ elif self.share_input_output_embed:
59
+ self.output_projection = nn.Linear(
60
+ embed_tokens.weight.shape[1],
61
+ embed_tokens.weight.shape[0],
62
+ bias=False,
63
+ )
64
+ self.output_projection.weight = embed_tokens.weight
65
+ else:
66
+ self.output_projection = nn.Linear(
67
+ self.output_embed_dim, len(dictionary), bias=False
68
+ )
69
+ nn.init.normal_(
70
+ self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
71
+ )
72
+ # num_base_layers = getattr(args, "base_layers", 0)
73
+ # for i in range(num_base_layers):
74
+ # self.layers.insert(
75
+ # ((i + 1) * args.decoder_layers) // (num_base_layers + 1),
76
+ # BaseLayer(args),
77
+ # )
78
+
79
+ def forward(self, x):
80
+ ft = self.freeze_decoder_updates <= self.num_updates
81
+ with torch.no_grad() if not ft else contextlib.ExitStack():
82
+ return self._forward(x)
83
+
84
+ def _forward(self, x):
85
+ # embed positions
86
+ x = self.output_layer(x)
87
+
88
+ return x
89
+
90
+ def set_num_updates(self, num_updates):
91
+ """Set the number of parameters updates."""
92
+ self.num_updates = num_updates
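For reference, a toy sketch of the share_input_output_embed path in build_output_projection above: the output projection reuses the input embedding matrix instead of learning a separate one. Sizes are hypothetical and the snippet is illustrative only.

import torch
import torch.nn as nn

# Hypothetical sizes: vocabulary of 100 tokens, 16-dim embeddings.
vocab, dim = 100, 16
embed_tokens = nn.Embedding(vocab, dim, padding_idx=1)

output_projection = nn.Linear(dim, vocab, bias=False)
output_projection.weight = embed_tokens.weight        # same Parameter object, not a copy

decoder_states = torch.randn(2, 7, dim)                # (batch, tgt_len, dim)
logits = output_projection(decoder_states)             # (2, 7, 100) scores over the vocabulary
print(logits.shape, output_projection.weight is embed_tokens.weight)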
artst/models/modules/text_decoder_prenet.py ADDED
@@ -0,0 +1,128 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ import math
10
+ import torch.nn as nn
11
+ import torch
12
+ import contextlib
13
+
14
+ from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
15
+ from fairseq.models.transformer import Linear #,LayerNorm
16
+ from fairseq.modules import (
17
+ PositionalEmbedding,
18
+ FairseqDropout,
19
+ LayerNorm
20
+ )
21
+
22
+
23
+ class TextDecoderPrenet(nn.Module):
24
+ """
25
+
26
+ Args:
27
+ in_channels (int): the number of input channels
28
+ mid_channels (int): the number of intermediate channels
29
+ out_channels (int): the number of output channels
30
+ kernel_sizes (List[int]): the kernel size for each convolutional layer
31
+ """
32
+
33
+ def __init__(self, embed_tokens, args):
34
+ super(TextDecoderPrenet, self).__init__()
35
+ self.dropout_module = FairseqDropout(
36
+ args.dropout, module_name=self.__class__.__name__
37
+ )
38
+ self.decoder_layerdrop = args.decoder_layerdrop
39
+ self.num_updates = 0
40
+
41
+ input_embed_dim = embed_tokens.embedding_dim
42
+ embed_dim = args.decoder_embed_dim
43
+ self.embed_dim = embed_dim
44
+ self.output_embed_dim = args.decoder_output_dim
45
+
46
+ self.padding_idx = embed_tokens.padding_idx
47
+
48
+ self.embed_tokens = embed_tokens
49
+
50
+ self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
51
+
52
+ if not args.adaptive_input and args.quant_noise_pq > 0:
53
+ self.quant_noise = apply_quant_noise_(
54
+ nn.Linear(embed_dim, embed_dim, bias=False),
55
+ args.quant_noise_pq,
56
+ args.quant_noise_pq_block_size,
57
+ )
58
+ else:
59
+ self.quant_noise = None
60
+
61
+ self.project_in_dim = (
62
+ Linear(input_embed_dim, embed_dim, bias=False)
63
+ if embed_dim != input_embed_dim
64
+ else None
65
+ )
66
+ self.embed_positions = (
67
+ PositionalEmbedding(
68
+ args.max_text_positions,
69
+ embed_dim,
70
+ self.padding_idx,
71
+ learned=args.decoder_learned_pos,
72
+ )
73
+ if not args.no_token_positional_embeddings
74
+ else None
75
+ )
76
+ export = getattr(args, "export", False)
77
+ if getattr(args, "layernorm_embedding", False):
78
+ self.layernorm_embedding = LayerNorm(embed_dim, export=export)
79
+ else:
80
+ self.layernorm_embedding = None
81
+
82
+ self.freeze_decoder_updates = args.freeze_decoder_updates
83
+
84
+ def forward(self, prev_output_tokens, incremental_state=None):
85
+ ft = self.freeze_decoder_updates <= self.num_updates
86
+ with torch.no_grad() if not ft else contextlib.ExitStack():
87
+ return self._forward(prev_output_tokens, incremental_state)
88
+
89
+ def _forward(self, prev_output_tokens, incremental_state=None):
90
+ if prev_output_tokens.eq(self.padding_idx).any():
91
+ x_mask = prev_output_tokens.eq(self.padding_idx)
92
+ else:
93
+ x_mask = None
94
+
95
+ # embed positions
96
+ positions = None
97
+ if self.embed_positions is not None:
98
+ positions = self.embed_positions(
99
+ prev_output_tokens, incremental_state=incremental_state
100
+ )
101
+
102
+ if incremental_state is not None:
103
+ prev_output_tokens = prev_output_tokens[:, -1:]
104
+ if positions is not None:
105
+ positions = positions[:, -1:]
106
+
107
+ # embed tokens and positions
108
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
109
+
110
+ if self.quant_noise is not None:
111
+ x = self.quant_noise(x)
112
+
113
+ if self.project_in_dim is not None:
114
+ x = self.project_in_dim(x)
115
+
116
+ if positions is not None:
117
+ x += positions
118
+
119
+ if self.layernorm_embedding is not None:
120
+ x = self.layernorm_embedding(x)
121
+
122
+ x = self.dropout_module(x)
123
+
124
+ return x, x_mask, incremental_state
125
+
126
+ def set_num_updates(self, num_updates):
127
+ """Set the number of parameters updates."""
128
+ self.num_updates = num_updates
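For reference, a tiny sketch of the incremental-decoding trim in _forward above: when an incremental_state cache is passed, only the newest token is embedded at each step, because earlier positions were already processed and cached. The token ids below are hypothetical.

import torch

prev_output_tokens = torch.tensor([[2, 14, 57, 9]])     # (batch=1, tgt_len=4); 2 plays the role of BOS

incremental_state = {}                                   # any non-None cache triggers the trim
if incremental_state is not None:
    step_tokens = prev_output_tokens[:, -1:]             # (1, 1): only the newest token is embedded
print(step_tokens)                                       # tensor([[9]])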
artst/models/modules/text_encoder_prenet.py ADDED
@@ -0,0 +1,44 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ import torch.nn as nn
10
+
11
+ from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
12
+ from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding
13
+
14
+
15
+ class TextEncoderPrenet(nn.Module):
16
+ """
17
+
18
+ Args:
19
+ in_channels (int): the number of input channels
20
+ mid_channels (int): the number of intermediate channels
21
+ out_channels (int): the number of output channels
22
+ kernel_sizes (List[int]): the kernel size for each convolutional layer
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ embed_tokens,
28
+ args,
29
+ ):
30
+ super(TextEncoderPrenet, self).__init__()
31
+ self.padding_idx = embed_tokens.padding_idx
32
+ # define encoder prenet
33
+ # get positional encoding class
34
+ pos_enc_class = (
35
+ ScaledPositionalEncoding if args.enc_use_scaled_pos_enc else PositionalEncoding
36
+ )
37
+
38
+ self.encoder_prenet = nn.Sequential(
39
+ embed_tokens,
40
+ pos_enc_class(args.encoder_embed_dim, args.transformer_enc_positional_dropout_rate, max_len=args.max_text_positions),
41
+ )
42
+
43
+ def forward(self, src_tokens):
44
+ return self.encoder_prenet(src_tokens), src_tokens.eq(self.padding_idx)
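For reference, a toy sketch of what this prenet returns: the embedded tokens plus a boolean padding mask computed with src_tokens.eq(padding_idx). The vocabulary size and token ids below are hypothetical.

import torch
import torch.nn as nn

padding_idx = 1
embed_tokens = nn.Embedding(10, 8, padding_idx=padding_idx)
src_tokens = torch.tensor([[4, 5, 6, 1, 1],
                           [7, 8, 9, 3, 2]])             # 1 marks padding

embedded = embed_tokens(src_tokens)                      # (2, 5, 8) token embeddings
padding_mask = src_tokens.eq(padding_idx)                # True on padded positions
print(embedded.shape, padding_mask)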
artst/models/modules/transformer_layer.py ADDED
@@ -0,0 +1,410 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ from typing import Dict, List, Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import contextlib
14
+ from fairseq import utils
15
+ from fairseq.modules import LayerNorm
16
+ from .multihead_attention import MultiheadAttention
17
+ from fairseq.modules.fairseq_dropout import FairseqDropout
18
+ from fairseq.modules.quant_noise import quant_noise
19
+ from torch import Tensor
20
+
21
+
22
+ class TransformerSentenceEncoderLayer(nn.Module):
23
+ """
24
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
25
+ models.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ embedding_dim: float = 768,
31
+ ffn_embedding_dim: float = 3072,
32
+ num_attention_heads: float = 8,
33
+ dropout: float = 0.1,
34
+ attention_dropout: float = 0.1,
35
+ activation_dropout: float = 0.1,
36
+ activation_fn: str = "relu",
37
+ layer_norm_first: bool = False,
38
+ has_relative_attention_bias: bool = False,
39
+ ) -> None:
40
+
41
+ super().__init__()
42
+ # Initialize parameters
43
+ self.embedding_dim = embedding_dim
44
+ self.dropout = dropout
45
+ self.activation_dropout = activation_dropout
46
+
47
+ # Initialize blocks
48
+ self.activation_fn = utils.get_activation_fn(activation_fn)
49
+ self.self_attn = MultiheadAttention(
50
+ self.embedding_dim,
51
+ num_attention_heads,
52
+ dropout=attention_dropout,
53
+ self_attention=True,
54
+ has_relative_attention_bias=has_relative_attention_bias,
55
+ )
56
+
57
+ self.dropout1 = nn.Dropout(dropout)
58
+ self.dropout2 = nn.Dropout(self.activation_dropout)
59
+ self.dropout3 = nn.Dropout(dropout)
60
+
61
+ self.layer_norm_first = layer_norm_first
62
+
63
+ # layer norm associated with the self attention layer
64
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
65
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
66
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
67
+
68
+ # layer norm associated with the position wise feed-forward NN
69
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
70
+
71
+ if has_relative_attention_bias:
72
+ self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads)
73
+
74
+ def forward(
75
+ self,
76
+ x: torch.Tensor,
77
+ self_attn_mask: torch.Tensor = None,
78
+ self_attn_padding_mask: torch.Tensor = None,
79
+ need_weights: bool = False,
80
+ att_args=None,
81
+ pos_bias=None,
82
+ ):
83
+ """
84
+ LayerNorm is applied either before or after the self-attention/ffn
85
+ modules similar to the original Transformer implementation.
86
+ """
87
+ residual = x
88
+
89
+ if self.layer_norm_first:
90
+ x = self.self_attn_layer_norm(x)
91
+ if pos_bias is not None:
92
+ pos_bias = self.norm_k(pos_bias)
93
+ x, attn = self.self_attn(
94
+ query=x,
95
+ key=x,
96
+ value=x,
97
+ key_padding_mask=self_attn_padding_mask,
98
+ attn_mask=self_attn_mask,
99
+ position_bias=pos_bias,
100
+ )
101
+ x = self.dropout1(x)
102
+ x = residual + x
103
+
104
+ residual = x
105
+ x = self.final_layer_norm(x)
106
+ x = self.activation_fn(self.fc1(x))
107
+ x = self.dropout2(x)
108
+ x = self.fc2(x)
109
+ x = self.dropout3(x)
110
+ x = residual + x
111
+ else:
112
+ x, attn = self.self_attn(
113
+ query=x,
114
+ key=x,
115
+ value=x,
116
+ key_padding_mask=self_attn_padding_mask,
117
+ position_bias=pos_bias,
118
+ )
119
+
120
+ x = self.dropout1(x)
121
+ x = residual + x
122
+
123
+ x = self.self_attn_layer_norm(x)
124
+
125
+ residual = x
126
+ x = self.activation_fn(self.fc1(x))
127
+ x = self.dropout2(x)
128
+ x = self.fc2(x)
129
+ x = self.dropout3(x)
130
+ x = residual + x
131
+ x = self.final_layer_norm(x)
132
+
133
+ return x, attn
134
+
135
+
136
+ class TransformerDecoderLayer(nn.Module):
137
+ """Decoder layer block.
138
+
139
+ In the original paper each operation (multi-head attention, encoder
140
+ attention or FFN) is postprocessed with: `dropout -> add residual ->
141
+ layernorm`. In the tensor2tensor code they suggest that learning is more
142
+ robust when preprocessing each layer with layernorm and postprocessing with:
143
+ `dropout -> add residual`. We default to the approach in the paper, but the
144
+ tensor2tensor approach can be enabled by setting
145
+ *args.decoder_normalize_before* to ``True``.
146
+
147
+ Args:
148
+ args (argparse.Namespace): parsed command-line arguments
149
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
150
+ (default: False).
151
+ """
152
+
153
+ def __init__(
154
+ self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, has_relative_attention_bias=False
155
+ ):
156
+ super().__init__()
157
+ self.embed_dim = args.decoder_embed_dim
158
+ self.num_updates = 0
159
+ self.dropout_module = FairseqDropout(
160
+ args.dropout, module_name=self.__class__.__name__
161
+ )
162
+ self.quant_noise = getattr(args, "quant_noise_pq", 0)
163
+ self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)
164
+
165
+ self.cross_self_attention = getattr(args, "cross_self_attention", False)
166
+
167
+ self.freeze_decoder_updates = getattr(args, "freeze_decoder_updates", 0)
168
+
169
+ self.self_attn = self.build_self_attention(
170
+ self.embed_dim,
171
+ args,
172
+ add_bias_kv=add_bias_kv,
173
+ add_zero_attn=add_zero_attn,
174
+ )
175
+
176
+ self.activation_fn = utils.get_activation_fn(
177
+ activation=str(args.activation_fn)
178
+ if getattr(args, "activation_fn", None) is not None
179
+ else "relu"
180
+ )
181
+ activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
182
+ if activation_dropout_p == 0:
183
+ # for backwards compatibility with models that use args.relu_dropout
184
+ activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
185
+ self.activation_dropout_module = FairseqDropout(
186
+ float(activation_dropout_p), module_name=self.__class__.__name__
187
+ )
188
+ self.normalize_before = args.decoder_normalize_before
189
+
190
+ export = getattr(args, "export", False)
191
+ self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
192
+
193
+ if no_encoder_attn:
194
+ self.encoder_attn = None
195
+ self.encoder_attn_layer_norm = None
196
+ else:
197
+ self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
198
+ self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
199
+
200
+ self.fc1 = self.build_fc1(
201
+ self.embed_dim,
202
+ args.decoder_ffn_embed_dim,
203
+ self.quant_noise,
204
+ self.quant_noise_block_size,
205
+ )
206
+ self.fc2 = self.build_fc2(
207
+ args.decoder_ffn_embed_dim,
208
+ self.embed_dim,
209
+ self.quant_noise,
210
+ self.quant_noise_block_size,
211
+ )
212
+
213
+ self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
214
+ self.need_attn = True
215
+
216
+ self.onnx_trace = False
217
+
218
+ self.has_relative_attention_bias = has_relative_attention_bias
219
+ if self.has_relative_attention_bias:
220
+ self.norm_k = LayerNorm(self.embed_dim//args.decoder_attention_heads)
221
+
222
+ def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
223
+ return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
224
+
225
+ def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
226
+ return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
227
+
228
+ def build_self_attention(
229
+ self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
230
+ ):
231
+ return MultiheadAttention(
232
+ embed_dim,
233
+ args.decoder_attention_heads,
234
+ dropout=args.attention_dropout,
235
+ add_bias_kv=add_bias_kv,
236
+ add_zero_attn=add_zero_attn,
237
+ self_attention=not getattr(args, "cross_self_attention", False),
238
+ q_noise=self.quant_noise,
239
+ qn_block_size=self.quant_noise_block_size,
240
+ #has_relative_attention_bias=args.has_relative_attention_bias,
241
+ )
242
+
243
+ def build_encoder_attention(self, embed_dim, args):
244
+ return MultiheadAttention(
245
+ embed_dim,
246
+ args.decoder_attention_heads,
247
+ kdim=getattr(args, "encoder_embed_dim", None),
248
+ vdim=getattr(args, "encoder_embed_dim", None),
249
+ dropout=args.attention_dropout,
250
+ encoder_decoder_attention=True,
251
+ q_noise=self.quant_noise,
252
+ qn_block_size=self.quant_noise_block_size,
253
+ )
254
+
255
+ def prepare_for_onnx_export_(self):
256
+ self.onnx_trace = True
257
+
258
+ def residual_connection(self, x, residual):
259
+ return residual + x
260
+
261
+ def forward(
262
+ self,
263
+ x,
264
+ encoder_out: Optional[torch.Tensor] = None,
265
+ encoder_padding_mask: Optional[torch.Tensor] = None,
266
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
267
+ prev_self_attn_state: Optional[List[torch.Tensor]] = None,
268
+ prev_attn_state: Optional[List[torch.Tensor]] = None,
269
+ self_attn_mask: Optional[torch.Tensor] = None,
270
+ self_attn_padding_mask: Optional[torch.Tensor] = None,
271
+ need_attn: bool = False,
272
+ need_head_weights: bool = False,
273
+ pos_bias=None,
274
+ ):
275
+ """
276
+ Args:
277
+ x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
278
+ encoder_padding_mask (ByteTensor, optional): binary
279
+ ByteTensor of shape `(batch, src_len)` where padding
280
+ elements are indicated by ``1``.
281
+ need_attn (bool, optional): return attention weights
282
+ need_head_weights (bool, optional): return attention weights
283
+ for each head (default: return average over heads).
284
+
285
+ Returns:
286
+ encoded output of shape `(seq_len, batch, embed_dim)`
287
+ """
288
+ ft = self.freeze_decoder_updates <= self.num_updates
289
+
290
+ with torch.no_grad() if not ft else contextlib.ExitStack():
291
+ if need_head_weights:
292
+ need_attn = True
293
+
294
+ residual = x
295
+ if self.normalize_before:
296
+ x = self.self_attn_layer_norm(x)
297
+ if pos_bias is not None:
298
+ pos_bias = self.norm_k(pos_bias)
299
+ if prev_self_attn_state is not None:
300
+ prev_key, prev_value = prev_self_attn_state[:2]
301
+ saved_state: Dict[str, Optional[Tensor]] = {
302
+ "prev_key": prev_key,
303
+ "prev_value": prev_value,
304
+ }
305
+ if len(prev_self_attn_state) >= 3:
306
+ saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
307
+ assert incremental_state is not None
308
+ self.self_attn._set_input_buffer(incremental_state, saved_state)
309
+ _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
310
+ if self.cross_self_attention and not (
311
+ incremental_state is not None
312
+ and _self_attn_input_buffer is not None
313
+ and "prev_key" in _self_attn_input_buffer
314
+ ):
315
+ if self_attn_mask is not None:
316
+ assert encoder_out is not None
317
+ self_attn_mask = torch.cat(
318
+ (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
319
+ )
320
+ if self_attn_padding_mask is not None:
321
+ if encoder_padding_mask is None:
322
+ assert encoder_out is not None
323
+ encoder_padding_mask = self_attn_padding_mask.new_zeros(
324
+ encoder_out.size(1), encoder_out.size(0)
325
+ )
326
+ self_attn_padding_mask = torch.cat(
327
+ (encoder_padding_mask, self_attn_padding_mask), dim=1
328
+ )
329
+ assert encoder_out is not None
330
+ y = torch.cat((encoder_out, x), dim=0)
331
+ else:
332
+ y = x
333
+
334
+ x, attn = self.self_attn(
335
+ query=x,
336
+ key=y,
337
+ value=y,
338
+ key_padding_mask=self_attn_padding_mask,
339
+ incremental_state=incremental_state,
340
+ need_weights=False,
341
+ attn_mask=self_attn_mask,
342
+ position_bias=pos_bias,
343
+ )
344
+ x = self.dropout_module(x)
345
+ x = self.residual_connection(x, residual)
346
+ if not self.normalize_before:
347
+ x = self.self_attn_layer_norm(x)
348
+
349
+ if self.encoder_attn is not None and encoder_out is not None:
350
+ residual = x
351
+ if self.normalize_before:
352
+ x = self.encoder_attn_layer_norm(x)
353
+ if prev_attn_state is not None:
354
+ prev_key, prev_value = prev_attn_state[:2]
355
+ saved_state: Dict[str, Optional[Tensor]] = {
356
+ "prev_key": prev_key,
357
+ "prev_value": prev_value,
358
+ }
359
+ if len(prev_attn_state) >= 3:
360
+ saved_state["prev_key_padding_mask"] = prev_attn_state[2]
361
+ assert incremental_state is not None
362
+ self.encoder_attn._set_input_buffer(incremental_state, saved_state)
363
+
364
+ x, attn = self.encoder_attn(
365
+ query=x,
366
+ key=encoder_out,
367
+ value=encoder_out,
368
+ key_padding_mask=encoder_padding_mask,
369
+ incremental_state=incremental_state,
370
+ static_kv=True,
371
+ need_weights=need_attn or (not self.training and self.need_attn),
372
+ need_head_weights=need_head_weights,
373
+ )
374
+ x = self.dropout_module(x)
375
+ x = self.residual_connection(x, residual)
376
+ if not self.normalize_before:
377
+ x = self.encoder_attn_layer_norm(x)
378
+
379
+ with torch.no_grad() if not ft else contextlib.ExitStack():
380
+ residual = x
381
+ if self.normalize_before:
382
+ x = self.final_layer_norm(x)
383
+
384
+ x = self.activation_fn(self.fc1(x))
385
+ x = self.activation_dropout_module(x)
386
+ x = self.fc2(x)
387
+ x = self.dropout_module(x)
388
+ x = self.residual_connection(x, residual)
389
+ if not self.normalize_before:
390
+ x = self.final_layer_norm(x)
391
+ if self.onnx_trace and incremental_state is not None:
392
+ saved_state = self.self_attn._get_input_buffer(incremental_state)
393
+ assert saved_state is not None
394
+ if self_attn_padding_mask is not None:
395
+ self_attn_state = [
396
+ saved_state["prev_key"],
397
+ saved_state["prev_value"],
398
+ saved_state["prev_key_padding_mask"],
399
+ ]
400
+ else:
401
+ self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
402
+ return x, attn, self_attn_state
403
+ return x, attn, None
404
+
405
+ def make_generation_fast_(self, need_attn: bool = False, **kwargs):
406
+ self.need_attn = need_attn
407
+
408
+ def set_num_updates(self, num_updates):
409
+ """Set the number of parameters updates."""
410
+ self.num_updates = num_updates
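
The two branches in the encoder layer above and the `normalize_before` switch in `TransformerDecoderLayer` differ only in where LayerNorm sits relative to the residual connection. Below is a minimal, self-contained sketch of the two orderings for the feed-forward sub-layer. It is illustrative only: the module names, the fixed dimensions, and the plain ReLU feed-forward are simplifications, not the classes defined in this file.

import torch
import torch.nn as nn

class PreNormFFNBlock(nn.Module):
    """Pre-norm ordering: x + dropout(ffn(norm(x)))."""
    def __init__(self, dim=768, ffn_dim=3072, p=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc1, self.fc2 = nn.Linear(dim, ffn_dim), nn.Linear(ffn_dim, dim)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        residual = x
        x = self.norm(x)                              # normalize before the sub-layer
        x = self.fc2(self.drop(torch.relu(self.fc1(x))))
        return residual + self.drop(x)                # residual added after dropout

class PostNormFFNBlock(nn.Module):
    """Post-norm ordering: norm(x + dropout(ffn(x)))."""
    def __init__(self, dim=768, ffn_dim=3072, p=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc1, self.fc2 = nn.Linear(dim, ffn_dim), nn.Linear(ffn_dim, dim)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        residual = x
        x = self.fc2(self.drop(torch.relu(self.fc1(x))))
        return self.norm(residual + self.drop(x))     # normalize after the residual

x = torch.randn(10, 2, 768)                           # (seq_len, batch, embed_dim)
print(PreNormFFNBlock()(x).shape, PostNormFFNBlock()(x).shape)

The pre-norm form is what `layer_norm_first=True` and `decoder_normalize_before=True` select above; the post-norm form matches the ordering in the original Transformer paper.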
artst/models/t5_transformer_lm.py ADDED
@@ -0,0 +1,23 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ from fairseq.models import (
9
+ register_model_architecture,
10
+ )
11
+ from fairseq.models.transformer_lm import base_lm_architecture
12
+
13
+
14
+ # @register_model_architecture(model_name="transformer_lm", arch_name="transformer_lm_t5")
15
+ def transformer_lm_t5(args):
16
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
17
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144)
18
+ args.decoder_layers = getattr(args, "decoder_layers", 20)
19
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
20
+ args.dropout = getattr(args, "dropout", 0.1)
21
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
22
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
23
+ base_lm_architecture(args)
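
`transformer_lm_t5` only pins a handful of decoder hyper-parameters before handing the namespace to fairseq's `base_lm_architecture`, which fills in the remaining fields with its own defaults. A rough usage sketch follows; it assumes fairseq is installed and the repository root is on `PYTHONPATH`, and the exact set of extra fields populated depends on the installed fairseq version.

from argparse import Namespace

from artst.models.t5_transformer_lm import transformer_lm_t5

args = Namespace()          # start from an empty namespace
transformer_lm_t5(args)     # applies the overrides above, then base_lm_architecture

print(args.decoder_embed_dim)        # 1280
print(args.decoder_ffn_embed_dim)    # 6144
print(args.decoder_layers)           # 20
print(args.decoder_attention_heads)  # 16
print(args.activation_fn)            # "gelu"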
artst/sequence_generator.py ADDED
@@ -0,0 +1,1080 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ from typing import Dict, List, Optional
10
+ import sys
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from fairseq import search, utils
15
+ from fairseq.data import data_utils
16
+ from fairseq.models import FairseqIncrementalDecoder
17
+ from torch import Tensor
18
+ from fairseq.ngram_repeat_block import NGramRepeatBlock
19
+ from espnet.nets.ctc_prefix_score import CTCPrefixScore
20
+ import numpy
21
+
22
+ CTC_SCORING_RATIO = 7.0
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+
25
+ class SequenceGenerator(nn.Module):
26
+ def __init__(
27
+ self,
28
+ models,
29
+ tgt_dict,
30
+ beam_size=1,
31
+ max_len_a=0,
32
+ max_len_b=200,
33
+ max_len=0,
34
+ min_len=1,
35
+ normalize_scores=True,
36
+ len_penalty=1.0,
37
+ unk_penalty=0.0,
38
+ temperature=1.0,
39
+ match_source_len=False,
40
+ no_repeat_ngram_size=0,
41
+ search_strategy=None,
42
+ eos=None,
43
+ symbols_to_strip_from_output=None,
44
+ lm_model=None,
45
+ lm_weight=1.0,
46
+ ctc_weight=0.0,
47
+ ):
48
+ """Generates translations of a given source sentence.
49
+
50
+ Args:
51
+ models (List[~fairseq.models.FairseqModel]): ensemble of models,
52
+ currently support fairseq.models.TransformerModel for scripting
53
+ beam_size (int, optional): beam width (default: 1)
54
+ max_len_a/b (int, optional): generate sequences of maximum length
55
+ ax + b, where x is the source length
56
+ max_len (int, optional): the maximum length of the generated output
57
+ (not including end-of-sentence)
58
+ min_len (int, optional): the minimum length of the generated output
59
+ (not including end-of-sentence)
60
+ normalize_scores (bool, optional): normalize scores by the length
61
+ of the output (default: True)
62
+ len_penalty (float, optional): length penalty, where <1.0 favors
63
+ shorter, >1.0 favors longer sentences (default: 1.0)
64
+ unk_penalty (float, optional): unknown word penalty, where <0
65
+ produces more unks, >0 produces fewer (default: 0.0)
66
+ temperature (float, optional): temperature, where values
67
+ >1.0 produce more uniform samples and values <1.0 produce
68
+ sharper samples (default: 1.0)
69
+ match_source_len (bool, optional): outputs should match the source
70
+ length (default: False)
71
+ """
72
+ super().__init__()
73
+ if isinstance(models, EnsembleModel):
74
+ self.model = models
75
+ else:
76
+ self.model = EnsembleModel(models)
77
+ self.tgt_dict = tgt_dict
78
+ self.pad = tgt_dict.pad()
79
+ self.unk = tgt_dict.unk()
80
+ self.eos = tgt_dict.eos() if eos is None else eos
81
+ self.blank = self.tgt_dict.index("<ctc_blank>")
82
+ self.mask = self.tgt_dict.index("<mask>")
83
+ self.mask_idxs = []
84
+ if self.tgt_dict.index("<mask>0") != self.unk:
85
+ count = 0
86
+ while self.tgt_dict.index("<mask>" + str(count)) != self.unk:
87
+ self.mask_idxs.append(self.tgt_dict.index("<mask>" + str(count)))
88
+ count += 1
89
+ self.mask_idxs = torch.tensor(self.mask_idxs)
90
+ self.symbols_to_strip_from_output = (
91
+ symbols_to_strip_from_output.union({self.eos})
92
+ if symbols_to_strip_from_output is not None
93
+ else {self.eos}
94
+ )
95
+ self.vocab_size = len(tgt_dict)
96
+ self.beam_size = beam_size
97
+ # the max beam size is the dictionary size - 1, since we never select pad
98
+ self.beam_size = min(beam_size, self.vocab_size - 1)
99
+ self.max_len_a = max_len_a
100
+ self.max_len_b = max_len_b
101
+ self.min_len = min_len
102
+ self.max_len = max_len or self.model.max_decoder_positions()
103
+
104
+ self.normalize_scores = normalize_scores
105
+ self.len_penalty = len_penalty
106
+ self.unk_penalty = unk_penalty
107
+ self.temperature = temperature
108
+ self.match_source_len = match_source_len
109
+
110
+ if no_repeat_ngram_size > 0:
111
+ self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
112
+ else:
113
+ self.repeat_ngram_blocker = None
114
+
115
+ assert temperature > 0, "--temperature must be greater than 0"
116
+
117
+ self.search = (
118
+ search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
119
+ )
120
+ # We only need to set src_lengths in LengthConstrainedBeamSearch.
121
+ # As a module attribute, setting it would break in multithread
122
+ # settings when the model is shared.
123
+ self.should_set_src_lengths = (
124
+ hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
125
+ )
126
+
127
+ self.model.eval()
128
+
129
+ self.lm_model = lm_model
130
+ self.lm_weight = lm_weight
131
+ self.ctc_weight = ctc_weight
132
+ if self.lm_model is not None:
133
+ self.lm_model.eval()
134
+
135
+ def cuda(self):
136
+ self.model.cuda()
137
+ return self
138
+
139
+ @torch.no_grad()
140
+ def forward(
141
+ self,
142
+ sample: Dict[str, Dict[str, Tensor]],
143
+ prefix_tokens: Optional[Tensor] = None,
144
+ bos_token: Optional[int] = None,
145
+ ):
146
+ """Generate a batch of translations.
147
+
148
+ Args:
149
+ sample (dict): batch
150
+ prefix_tokens (torch.LongTensor, optional): force decoder to begin
151
+ with these tokens
152
+ bos_token (int, optional): beginning of sentence token
153
+ (default: self.eos)
154
+ """
155
+ return self._generate(sample, prefix_tokens, bos_token=bos_token)
156
+
157
+ # TODO(myleott): unused, deprecate after pytorch-translate migration
158
+ def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
159
+ """Iterate over a batched dataset and yield individual translations.
160
+ Args:
161
+ cuda (bool, optional): use GPU for generation
162
+ timer (StopwatchMeter, optional): time generations
163
+ """
164
+ for sample in data_itr:
165
+ s = utils.move_to_cuda(sample) if cuda else sample
166
+ if "net_input" not in s:
167
+ continue
168
+ input = s["net_input"]
169
+ # model.forward normally channels prev_output_tokens into the decoder
170
+ # separately, but SequenceGenerator directly calls model.encoder
171
+ encoder_input = {
172
+ k: v for k, v in input.items() if k != "prev_output_tokens"
173
+ }
174
+ if timer is not None:
175
+ timer.start()
176
+ with torch.no_grad():
177
+ hypos = self.generate(encoder_input)
178
+ if timer is not None:
179
+ timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
180
+ for i, id in enumerate(s["id"].data):
181
+ # remove padding
182
+ src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
183
+ ref = (
184
+ utils.strip_pad(s["target"].data[i, :], self.pad)
185
+ if s["target"] is not None
186
+ else None
187
+ )
188
+ yield id, src, ref, hypos[i]
189
+
190
+ @torch.no_grad()
191
+ def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
192
+ """Generate translations. Match the api of other fairseq generators.
193
+
194
+ Args:
195
+ models (List[~fairseq.models.FairseqModel]): ensemble of models
196
+ sample (dict): batch
197
+ prefix_tokens (torch.LongTensor, optional): force decoder to begin
198
+ with these tokens
199
+ constraints (torch.LongTensor, optional): force decoder to include
200
+ the list of constraints
201
+ bos_token (int, optional): beginning of sentence token
202
+ (default: self.eos)
203
+ """
204
+ return self._generate(sample, **kwargs)
205
+
206
+ def _generate(
207
+ self,
208
+ sample: Dict[str, Dict[str, Tensor]],
209
+ prefix_tokens: Optional[Tensor] = None,
210
+ constraints: Optional[Tensor] = None,
211
+ bos_token: Optional[int] = None,
212
+ ):
213
+ incremental_states = torch.jit.annotate(
214
+ List[Dict[str, Dict[str, Optional[Tensor]]]],
215
+ [
216
+ torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
217
+ for i in range(self.model.models_size)
218
+ ],
219
+ )
220
+ net_input = sample["net_input"]
221
+
222
+ if "src_tokens" in net_input:
223
+ src_tokens = net_input["src_tokens"]
224
+ # length of the source text, i.e. the number of tokens excluding EOS and pad
225
+ src_lengths = (
226
+ (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
227
+ )
228
+ elif "source" in net_input:
229
+ src_tokens = net_input["source"]
230
+ src_lengths = (
231
+ net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
232
+ if net_input["padding_mask"] is not None
233
+ else torch.tensor(src_tokens.size(-1)).to(src_tokens)
234
+ )
235
+ elif "features" in net_input:
236
+ src_tokens = net_input["features"]
237
+ src_lengths = (
238
+ net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
239
+ if net_input["padding_mask"] is not None
240
+ else torch.tensor(src_tokens.size(-1)).to(src_tokens)
241
+ )
242
+ else:
243
+ raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))
244
+
245
+ # bsz: total number of sentences in beam
246
+ # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
247
+ bsz, src_len = src_tokens.size()[:2]
248
+ beam_size = self.beam_size
249
+
250
+ if constraints is not None and not self.search.supports_constraints:
251
+ raise NotImplementedError(
252
+ "Target-side constraints were provided, but search method doesn't support them"
253
+ )
254
+
255
+ # Initialize constraints, when active
256
+ self.search.init_constraints(constraints, beam_size)
257
+
258
+ max_len: int = -1
259
+ if self.match_source_len:
260
+ max_len = src_lengths.max().item()
261
+ else:
262
+ max_len = min(
263
+ int(self.max_len_a * src_len + self.max_len_b),
264
+ self.max_len - 1,
265
+ )
266
+ assert (
267
+ self.min_len <= max_len
268
+ ), "min_len cannot be larger than max_len, please adjust these!"
269
+ # compute the encoder output for each beam
270
+ encoder_outs = self.model.forward_encoder(net_input)
271
+
272
+ # Get CTC lprobs and prep ctc_scorer
273
+ if self.ctc_weight > 0:
274
+ ctc_lprobs = self.model.models[0].get_normalized_probs_for_ctc(
275
+ encoder_outs[0], log_probs=True
276
+ ).contiguous().transpose(0, 1) # (B, T, C) from the encoder
277
+
278
+ hyp = {}
279
+ ctc_prefix_score = CTCPrefixScore(ctc_lprobs[0].detach().cpu().numpy(), self.blank, self.eos, numpy)
280
+ hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
281
+ hyp["ctc_score_prev"] = 0.0
282
+ ctc_beam = min(ctc_lprobs.shape[-1] - self.mask_idxs.size(-1), int(beam_size * CTC_SCORING_RATIO))
283
+ ctc_hyps = {str(self.eos): hyp}
284
+
285
+ # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
286
+ new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
287
+ new_order = new_order.to(src_tokens.device).long()
288
+ encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
289
+ # ensure encoder_outs is a List.
290
+ assert encoder_outs is not None
291
+
292
+ # initialize buffers
293
+ scores = (
294
+ torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
295
+ ) # +1 for eos; pad is never chosen for scoring
296
+ tokens = (
297
+ torch.zeros(bsz * beam_size, max_len + 2)
298
+ .to(src_tokens)
299
+ .long()
300
+ .fill_(self.pad)
301
+ ) # +2 for eos and pad
302
+ tokens[:, 0] = self.eos if bos_token is None else bos_token
303
+ attn: Optional[Tensor] = None
304
+
305
+ # A list that indicates candidates that should be ignored.
306
+ # For example, suppose we're sampling and have already finalized 2/5
307
+ # samples. Then cands_to_ignore would mark 2 positions as being ignored,
308
+ # so that we only finalize the remaining 3 samples.
309
+ cands_to_ignore = (
310
+ torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
311
+ ) # forward and backward-compatible False mask
312
+
313
+ # list of completed sentences
314
+ finalized = torch.jit.annotate(
315
+ List[List[Dict[str, Tensor]]],
316
+ [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
317
+ ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step
318
+
319
+ # a boolean array indicating if the sentence at the index is finished or not
320
+ finished = [False for i in range(bsz)]
321
+ num_remaining_sent = bsz # number of sentences remaining
322
+
323
+ # number of candidate hypos per step
324
+ cand_size = 2 * beam_size # 2 x beam size in case half are EOS
325
+
326
+ # offset arrays for converting between different indexing schemes
327
+ bbsz_offsets = (
328
+ (torch.arange(0, bsz) * beam_size)
329
+ .unsqueeze(1)
330
+ .type_as(tokens)
331
+ .to(src_tokens.device)
332
+ )
333
+ cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
334
+
335
+ reorder_state: Optional[Tensor] = None
336
+ ctc_state = None
337
+ batch_idxs: Optional[Tensor] = None
338
+
339
+ original_batch_idxs: Optional[Tensor] = None
340
+ if "id" in sample and isinstance(sample["id"], Tensor):
341
+ original_batch_idxs = sample["id"]
342
+ else:
343
+ original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
344
+
345
+ for step in range(max_len + 1): # one extra step for EOS marker
346
+ # reorder decoder internal states based on the prev choice of beams
347
+ if reorder_state is not None:
348
+ if batch_idxs is not None:
349
+ # update beam indices to take into account removed sentences
350
+ corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
351
+ batch_idxs
352
+ )
353
+ reorder_state.view(-1, beam_size).add_(
354
+ corr.unsqueeze(-1) * beam_size
355
+ )
356
+ original_batch_idxs = original_batch_idxs[batch_idxs]
357
+ self.model.reorder_incremental_state(incremental_states, reorder_state)
358
+ encoder_outs = self.model.reorder_encoder_out(
359
+ encoder_outs, reorder_state
360
+ )
361
+
362
+ lprobs, avg_attn_scores = self.model.forward_decoder(
363
+ tokens[:, : step + 1],
364
+ encoder_outs,
365
+ incremental_states,
366
+ self.temperature,
367
+ )
368
+
369
+ if self.ctc_weight > 0 and step != 0:
370
+ # lprobs[:, self.blank] = -math.inf # never select blank
371
+ ctc_lprobs = lprobs.clone()
372
+ ctc_lprobs[:, self.blank] = -math.inf # never select blank
373
+ if self.mask != self.unk:
374
+ ctc_lprobs[:, self.mask] = -math.inf # never select mask
375
+ if self.mask_idxs.size(0) != 0:
376
+ ctc_lprobs[:, self.mask_idxs] = -math.inf # never select mask
377
+ local_best_scores, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)
378
+ for b in range(tokens.size(0)):
379
+ hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist())
380
+
381
+ ctc_scores, ctc_states = ctc_prefix_score(
382
+ tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"]
383
+ )
384
+ lprobs[b] = lprobs[b]
385
+ lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy(
386
+ ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"]
387
+ ).to(device=device)
388
+ for j in range(len(local_best_ids[b])):
389
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {}
390
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j]
391
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j]
392
+
393
+ # local_ctc_scores, ctc_state = ctc_scorer(
394
+ # tokens[:, : step + 1], ctc_state, part_ids
395
+ # )
396
+ # lprobs += local_ctc_scores * self.ctc_weight
397
+ elif self.ctc_weight > 0 and step == 0:
398
+ ctc_lprobs = lprobs.clone()
399
+ ctc_lprobs[:, self.blank] = -math.inf # never select blank
400
+ if self.mask != self.unk:
401
+ ctc_lprobs[:, self.mask] = -math.inf # never select mask
402
+ if self.mask_idxs.size(0) != 0:
403
+ ctc_lprobs[:, self.mask_idxs] = -math.inf # never select mask
404
+ local_best_scores, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)
405
+ for b in range(tokens.size(0)):
406
+ hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist())
407
+ ctc_scores, ctc_states = ctc_prefix_score(
408
+ tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"]
409
+ )
410
+ lprobs[b] = lprobs[b]
411
+ lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy(
412
+ ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"]
413
+ ).to(device=device)
414
+ for j in range(len(local_best_ids[b])):
415
+ if b == 0:
416
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {}
417
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j]
418
+ ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j]
419
+
420
+ if self.lm_model is not None:
421
+ lm_out = self.lm_model(tokens[:, : step + 1])
422
+ probs = self.lm_model.get_normalized_probs(
423
+ lm_out, log_probs=True, sample=None
424
+ )
425
+ probs = probs[:, -1, :] * self.lm_weight
426
+ lprobs[:, :probs.size(1)] += probs
427
+
428
+ # handle prefix tokens (possibly with different lengths)
429
+ if (
430
+ prefix_tokens is not None
431
+ and step < prefix_tokens.size(1)
432
+ and step < max_len
433
+ ):
434
+ lprobs, tokens, scores = self._prefix_tokens(
435
+ step, lprobs, scores, tokens, prefix_tokens, beam_size
436
+ )
437
+ elif step < self.min_len:
438
+ # minimum length constraint (does not apply if using prefix_tokens)
439
+ lprobs[:, self.eos] = -math.inf
440
+
441
+ lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
442
+
443
+ lprobs[:, self.pad] = -math.inf # never select pad
444
+ lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
445
+ lprobs[:, self.blank] = -math.inf # never select blank
446
+ if self.mask != self.unk:
447
+ lprobs[:, self.mask] = -math.inf # never select mask
448
+ if self.mask_idxs.size(0) != 0:
449
+ lprobs[:, self.mask_idxs] = -math.inf # never select mask
450
+
451
+ # handle max length constraint
452
+ if step >= max_len:
453
+ lprobs[:, : self.eos] = -math.inf
454
+ lprobs[:, self.eos + 1 :] = -math.inf
455
+
456
+ # Record attention scores, only support avg_attn_scores is a Tensor
457
+ if avg_attn_scores is not None:
458
+ if attn is None:
459
+ attn = torch.empty(
460
+ bsz * beam_size, avg_attn_scores.size(1), max_len + 2
461
+ ).to(scores)
462
+ attn[:, :, step + 1].copy_(avg_attn_scores)
463
+
464
+ scores = scores.type_as(lprobs)
465
+ eos_bbsz_idx = torch.empty(0).to(
466
+ tokens
467
+ ) # indices of hypothesis ending with eos (finished sentences)
468
+ eos_scores = torch.empty(0).to(
469
+ scores
470
+ ) # scores of hypothesis ending with eos (finished sentences)
471
+
472
+ if self.should_set_src_lengths:
473
+ self.search.set_src_lengths(src_lengths)
474
+
475
+ if self.repeat_ngram_blocker is not None:
476
+ lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
477
+
478
+ # Shape: (batch, cand_size)
479
+ cand_scores, cand_indices, cand_beams = self.search.step(
480
+ step,
481
+ lprobs.view(bsz, -1, self.vocab_size),
482
+ scores.view(bsz, beam_size, -1)[:, :, :step],
483
+ tokens[:, : step + 1],
484
+ original_batch_idxs,
485
+ )
486
+
487
+ # cand_bbsz_idx contains beam indices for the top candidate
488
+ # hypotheses, with a range of values: [0, bsz*beam_size),
489
+ # and dimensions: [bsz, cand_size]
490
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
491
+
492
+ # finalize hypotheses that end in eos
493
+ # Shape of eos_mask: (batch size, beam size)
494
+ eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
495
+ eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
496
+
497
+ # only consider eos when it's among the top beam_size indices
498
+ # Now we know what beam item(s) to finish
499
+ # Shape: 1d list of absolute-numbered
500
+ eos_bbsz_idx = torch.masked_select(
501
+ cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
502
+ )
503
+
504
+ finalized_sents: List[int] = []
505
+ if eos_bbsz_idx.numel() > 0:
506
+ eos_scores = torch.masked_select(
507
+ cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
508
+ )
509
+
510
+ finalized_sents = self.finalize_hypos(
511
+ step,
512
+ eos_bbsz_idx,
513
+ eos_scores,
514
+ tokens,
515
+ scores,
516
+ finalized,
517
+ finished,
518
+ beam_size,
519
+ attn,
520
+ src_lengths,
521
+ max_len,
522
+ )
523
+ num_remaining_sent -= len(finalized_sents)
524
+
525
+ assert num_remaining_sent >= 0
526
+ if num_remaining_sent == 0:
527
+ break
528
+ if self.search.stop_on_max_len and step >= max_len:
529
+ break
530
+ assert step < max_len, f"step {step} must be less than max_len {max_len}"
531
+
532
+ # Remove finalized sentences (ones for which {beam_size}
533
+ # finished hypotheses have been generated) from the batch.
534
+ if len(finalized_sents) > 0:
535
+ new_bsz = bsz - len(finalized_sents)
536
+
537
+ # construct batch_idxs which holds indices of batches to keep for the next pass
538
+ batch_mask = torch.ones(
539
+ bsz, dtype=torch.bool, device=cand_indices.device
540
+ )
541
+ batch_mask[finalized_sents] = False
542
+ # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
543
+ batch_idxs = torch.arange(
544
+ bsz, device=cand_indices.device
545
+ ).masked_select(batch_mask)
546
+
547
+ # Choose the subset of the hypothesized constraints that will continue
548
+ self.search.prune_sentences(batch_idxs)
549
+
550
+ eos_mask = eos_mask[batch_idxs]
551
+ cand_beams = cand_beams[batch_idxs]
552
+ bbsz_offsets.resize_(new_bsz, 1)
553
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
554
+ cand_scores = cand_scores[batch_idxs]
555
+ cand_indices = cand_indices[batch_idxs]
556
+
557
+ if prefix_tokens is not None:
558
+ prefix_tokens = prefix_tokens[batch_idxs]
559
+ src_lengths = src_lengths[batch_idxs]
560
+ cands_to_ignore = cands_to_ignore[batch_idxs]
561
+
562
+ scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
563
+ tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
564
+ if attn is not None:
565
+ attn = attn.view(bsz, -1)[batch_idxs].view(
566
+ new_bsz * beam_size, attn.size(1), -1
567
+ )
568
+ bsz = new_bsz
569
+ else:
570
+ batch_idxs = None
571
+
572
+ # Set active_mask so that values > cand_size indicate eos hypos
573
+ # and values < cand_size indicate candidate active hypos.
574
+ # After, the min values per row are the top candidate active hypos
575
+
576
+ # Rewrite the operator since the element wise or is not supported in torchscript.
577
+
578
+ eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
579
+ active_mask = torch.add(
580
+ eos_mask.type_as(cand_offsets) * cand_size,
581
+ cand_offsets[: eos_mask.size(1)],
582
+ )
583
+
584
+ # get the top beam_size active hypotheses, which are just
585
+ # the hypos with the smallest values in active_mask.
586
+ # {active_hypos} indicates which {beam_size} hypotheses
587
+ # from the list of {2 * beam_size} candidates were
588
+ # selected. Shapes: (batch size, beam size)
589
+ new_cands_to_ignore, active_hypos = torch.topk(
590
+ active_mask, k=beam_size, dim=1, largest=False
591
+ )
592
+
593
+ # update cands_to_ignore to ignore any finalized hypos.
594
+ cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
595
+ # Make sure there is at least one active item for each sentence in the batch.
596
+ assert (~cands_to_ignore).any(dim=1).all()
597
+
598
+ # update cands_to_ignore to ignore any finalized hypos
599
+
600
+ # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
601
+ # can be selected more than once).
602
+ active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
603
+ active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
604
+
605
+ active_bbsz_idx = active_bbsz_idx.view(-1)
606
+ active_scores = active_scores.view(-1)
607
+
608
+ # copy tokens and scores for active hypotheses
609
+
610
+ # Set the tokens for each beam (can select the same row more than once)
611
+ tokens[:, : step + 1] = torch.index_select(
612
+ tokens[:, : step + 1], dim=0, index=active_bbsz_idx
613
+ )
614
+ # Select the next token for each of them
615
+ tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
616
+ cand_indices, dim=1, index=active_hypos
617
+ )
618
+ if step > 0:
619
+ scores[:, :step] = torch.index_select(
620
+ scores[:, :step], dim=0, index=active_bbsz_idx
621
+ )
622
+ scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
623
+ cand_scores, dim=1, index=active_hypos
624
+ )
625
+
626
+ # Update constraints based on which candidates were selected for the next beam
627
+ self.search.update_constraints(active_hypos)
628
+
629
+ # copy attention for active hypotheses
630
+ if attn is not None:
631
+ attn[:, :, : step + 2] = torch.index_select(
632
+ attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
633
+ )
634
+
635
+ # reorder incremental state in decoder
636
+ reorder_state = active_bbsz_idx
637
+
638
+ # if self.ctc_weight > 0:
639
+ # accum_best_id = torch.gather(cand_indices, dim=1, index=active_hypos)
640
+ # ctc_state = ctc_scorer.index_select_state(
641
+ # ctc_state, accum_best_id
642
+ # )
643
+
644
+ # sort by score descending
645
+ for sent in range(len(finalized)):
646
+ scores = torch.tensor(
647
+ [float(elem["score"].item()) for elem in finalized[sent]]
648
+ )
649
+ _, sorted_scores_indices = torch.sort(scores, descending=True)
650
+ finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
651
+ finalized[sent] = torch.jit.annotate(
652
+ List[Dict[str, Tensor]], finalized[sent]
653
+ )
654
+ return finalized
655
+
656
+ def _prefix_tokens(
657
+ self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
658
+ ):
659
+ """Handle prefix tokens"""
660
+ prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
661
+ prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
662
+ prefix_mask = prefix_toks.ne(self.pad)
663
+ lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1
664
+ lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
665
+ -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
666
+ )
667
+ # if prefix includes eos, then we should make sure tokens and
668
+ # scores are the same across all beams
669
+ eos_mask = prefix_toks.eq(self.eos)
670
+ if eos_mask.any():
671
+ # validate that the first beam matches the prefix
672
+ first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
673
+ :, 0, 1 : step + 1
674
+ ]
675
+ eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
676
+ target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
677
+ assert (first_beam == target_prefix).all()
678
+
679
+ # copy tokens, scores and lprobs from the first beam to all beams
680
+ tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
681
+ scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
682
+ lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
683
+ return lprobs, tokens, scores
684
+
685
+ def replicate_first_beam(self, tensor, mask, beam_size: int):
686
+ tensor = tensor.view(-1, beam_size, tensor.size(-1))
687
+ tensor[mask] = tensor[mask][:, :1, :]
688
+ return tensor.view(-1, tensor.size(-1))
689
+
690
+ def finalize_hypos(
691
+ self,
692
+ step: int,
693
+ bbsz_idx,
694
+ eos_scores,
695
+ tokens,
696
+ scores,
697
+ finalized: List[List[Dict[str, Tensor]]],
698
+ finished: List[bool],
699
+ beam_size: int,
700
+ attn: Optional[Tensor],
701
+ src_lengths,
702
+ max_len: int,
703
+ ):
704
+ """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
705
+ A sentence is finalized when {beam_size} finished items have been collected for it.
706
+ Returns number of sentences (not beam items) being finalized.
707
+ These will be removed from the batch and not processed further.
708
+ Args:
709
+ bbsz_idx (Tensor):
710
+ """
711
+ assert bbsz_idx.numel() == eos_scores.numel()
712
+
713
+ # clone relevant token and attention tensors.
714
+ # tokens is (batch * beam, max_len). So the index_select
715
+ # gets the newly EOS rows, then selects cols 1..{step + 2}
716
+ tokens_clone = tokens.index_select(0, bbsz_idx)[
717
+ :, 1 : step + 2
718
+ ] # skip the first index, which is EOS
719
+
720
+ tokens_clone[:, step] = self.eos
721
+ attn_clone = (
722
+ attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
723
+ if attn is not None
724
+ else None
725
+ )
726
+
727
+ # compute scores per token position
728
+ pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
729
+ pos_scores[:, step] = eos_scores
730
+ # convert from cumulative to per-position scores
731
+ pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
732
+
733
+ # normalize sentence-level scores
734
+ if self.normalize_scores:
735
+ eos_scores /= (step + 1) ** self.len_penalty
736
+
737
+ # cum_unfin records which sentences in the batch are finished.
738
+ # It helps match indexing between (a) the original sentences
739
+ # in the batch and (b) the current, possibly-reduced set of
740
+ # sentences.
741
+ cum_unfin: List[int] = []
742
+ prev = 0
743
+ for f in finished:
744
+ if f:
745
+ prev += 1
746
+ else:
747
+ cum_unfin.append(prev)
748
+ cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)
749
+
750
+ unfin_idx = bbsz_idx // beam_size
751
+ sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)
752
+
753
+ # Create a set of "{sent}{unfin_idx}", where
754
+ # "unfin_idx" is the index in the current (possibly reduced)
755
+ # list of sentences, and "sent" is the index in the original,
756
+ # unreduced batch
757
+ # For every finished beam item
758
+ # sentence index in the current (possibly reduced) batch
759
+ seen = (sent << 32) + unfin_idx
760
+ unique_seen: List[int] = torch.unique(seen).tolist()
761
+
762
+ if self.match_source_len:
763
+ condition = step > torch.index_select(src_lengths, 0, unfin_idx)
764
+ eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores)
765
+ sent_list: List[int] = sent.tolist()
766
+ for i in range(bbsz_idx.size()[0]):
767
+ # An input sentence (among those in a batch) is finished when
768
+ # beam_size hypotheses have been collected for it
769
+ if len(finalized[sent_list[i]]) < beam_size:
770
+ if attn_clone is not None:
771
+ # remove padding tokens from attn scores
772
+ hypo_attn = attn_clone[i]
773
+ else:
774
+ hypo_attn = torch.empty(0)
775
+
776
+ finalized[sent_list[i]].append(
777
+ {
778
+ "tokens": tokens_clone[i],
779
+ "score": eos_scores[i],
780
+ "attention": hypo_attn, # src_len x tgt_len
781
+ "alignment": torch.empty(0),
782
+ "positional_scores": pos_scores[i],
783
+ }
784
+ )
785
+
786
+ newly_finished: List[int] = []
787
+ for unique_s in unique_seen:
788
+ # check termination conditions for this sentence
789
+ unique_sent: int = unique_s >> 32
790
+ unique_unfin_idx: int = unique_s - (unique_sent << 32)
791
+
792
+ if not finished[unique_sent] and self.is_finished(
793
+ step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size
794
+ ):
795
+ finished[unique_sent] = True
796
+ newly_finished.append(unique_unfin_idx)
797
+
798
+ return newly_finished
799
+
800
+ def is_finished(
801
+ self,
802
+ step: int,
803
+ unfin_idx: int,
804
+ max_len: int,
805
+ finalized_sent_len: int,
806
+ beam_size: int,
807
+ ):
808
+ """
809
+ Check whether decoding for a sentence is finished, which
810
+ occurs when the list of finalized sentences has reached the
811
+ beam size, or when we reach the maximum length.
812
+ """
813
+ assert finalized_sent_len <= beam_size
814
+ if finalized_sent_len == beam_size or step == max_len:
815
+ return True
816
+ return False
817
+
818
+
819
+ class EnsembleModel(nn.Module):
820
+ """A wrapper around an ensemble of models."""
821
+
822
+ def __init__(self, models):
823
+ super().__init__()
824
+ self.models_size = len(models)
825
+ # method '__len__' is not supported in ModuleList for torch script
826
+ self.single_model = models[0]
827
+ self.models = nn.ModuleList(models)
828
+
829
+ self.has_incremental: bool = False
830
+ if all(
831
+ hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
832
+ for m in models
833
+ ):
834
+ self.has_incremental = True
835
+
836
+ def forward(self):
837
+ pass
838
+
839
+ def has_encoder(self):
840
+ return hasattr(self.single_model, "encoder")
841
+
842
+ def is_t5_structure(self):
843
+ t5_structure = hasattr(self.single_model, "text_encoder_prenet") and hasattr(self.single_model, "speech_encoder_prenet") or \
844
+ hasattr(self.single_model, "encoder_prenet") and hasattr(self.single_model, "encoder_prenet")
845
+ return t5_structure
846
+
847
+ def has_incremental_states(self):
848
+ return self.has_incremental
849
+
850
+ def max_decoder_positions(self):
851
+ return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])
852
+
853
+ @torch.jit.export
854
+ def forward_encoder(self, net_input: Dict[str, Tensor]):
855
+ if not self.has_encoder():
856
+ return None
857
+ elif self.is_t5_structure():
858
+ return [model.forward_encoder_torchscript(net_input) for model in self.models]
859
+ else:
860
+ return [model.encoder.forward_torchscript(net_input) for model in self.models]
861
+
862
+ @torch.jit.export
863
+ def forward_decoder(
864
+ self,
865
+ tokens,
866
+ encoder_outs: List[Dict[str, List[Tensor]]],
867
+ incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
868
+ temperature: float = 1.0,
869
+ ):
870
+ log_probs = []
871
+ avg_attn: Optional[Tensor] = None
872
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None
873
+ for i, model in enumerate(self.models):
874
+ if self.has_encoder():
875
+ encoder_out = encoder_outs[i]
876
+ # decode each model
877
+ if self.has_incremental_states():
878
+ if self.is_t5_structure():
879
+ decoder_out = model.forward_decoder(
880
+ tokens,
881
+ encoder_out=encoder_out,
882
+ incremental_state=incremental_states[i]
883
+ )
884
+ else:
885
+ decoder_out = model.decoder.forward(
886
+ tokens,
887
+ encoder_out=encoder_out,
888
+ incremental_state=incremental_states[i],
889
+ )
890
+ else:
891
+ if hasattr(model, "decoder"):
892
+ decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)
893
+ else:
894
+ decoder_out = model.forward(tokens)
895
+
896
+ attn: Optional[Tensor] = None
897
+ decoder_len = len(decoder_out)
898
+ if decoder_len > 1 and decoder_out[1] is not None:
899
+ if isinstance(decoder_out[1], Tensor):
900
+ attn = decoder_out[1]
901
+ else:
902
+ attn_holder = decoder_out[1]["attn"]
903
+ if isinstance(attn_holder, Tensor):
904
+ attn = attn_holder
905
+ elif attn_holder is not None:
906
+ attn = attn_holder[0]
907
+ if attn is not None:
908
+ attn = attn[:, -1, :]
909
+
910
+ decoder_out_tuple = (
911
+ decoder_out[0][:, -1:, :].div_(temperature),
912
+ None if decoder_len <= 1 else decoder_out[1],
913
+ )
914
+ probs = model.get_normalized_probs(
915
+ decoder_out_tuple, log_probs=True, sample=None
916
+ )
917
+ probs = probs[:, -1, :]
918
+ if self.models_size == 1:
919
+ return probs, attn
920
+
921
+ log_probs.append(probs)
922
+ if attn is not None:
923
+ if avg_attn is None:
924
+ avg_attn = attn
925
+ else:
926
+ avg_attn.add_(attn)
927
+
928
+ avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
929
+ self.models_size
930
+ )
931
+
932
+ if avg_attn is not None:
933
+ avg_attn.div_(self.models_size)
934
+ return avg_probs, avg_attn
935
+
936
+ @torch.jit.export
937
+ def reorder_encoder_out(
938
+ self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order
939
+ ):
940
+ """
941
+ Reorder encoder output according to *new_order*.
942
+
943
+ Args:
944
+ encoder_out: output from the ``forward()`` method
945
+ new_order (LongTensor): desired order
946
+
947
+ Returns:
948
+ *encoder_out* rearranged according to *new_order*
949
+ """
950
+ new_outs: List[Dict[str, List[Tensor]]] = []
951
+ if not self.has_encoder():
952
+ return new_outs
953
+ for i, model in enumerate(self.models):
954
+ assert encoder_outs is not None
955
+ new_outs.append(
956
+ model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
957
+ )
958
+ return new_outs
959
+
960
+ @torch.jit.export
961
+ def reorder_incremental_state(
962
+ self,
963
+ incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
964
+ new_order,
965
+ ):
966
+ if not self.has_incremental_states():
967
+ return
968
+ for i, model in enumerate(self.models):
969
+ model.decoder.reorder_incremental_state_scripting(
970
+ incremental_states[i], new_order
971
+ )
972
+
973
+
974
+ class SequenceGeneratorWithAlignment(SequenceGenerator):
975
+ def __init__(
976
+ self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs
977
+ ):
978
+ """Generates translations of a given source sentence.
979
+
980
+ Produces alignments following "Jointly Learning to Align and
981
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
982
+
983
+ Args:
984
+ left_pad_target (bool, optional): Whether or not the
985
+ hypothesis should be left padded or not when they are
986
+ teacher forced for generating alignments.
987
+ """
988
+ super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
989
+ self.left_pad_target = left_pad_target
990
+
991
+ if print_alignment == "hard":
992
+ self.extract_alignment = utils.extract_hard_alignment
993
+ elif print_alignment == "soft":
994
+ self.extract_alignment = utils.extract_soft_alignment
995
+
996
+ @torch.no_grad()
997
+ def generate(self, models, sample, **kwargs):
998
+ finalized = super()._generate(sample, **kwargs)
999
+
1000
+ src_tokens = sample["net_input"]["src_tokens"]
1001
+ bsz = src_tokens.shape[0]
1002
+ beam_size = self.beam_size
1003
+ (
1004
+ src_tokens,
1005
+ src_lengths,
1006
+ prev_output_tokens,
1007
+ tgt_tokens,
1008
+ ) = self._prepare_batch_for_alignment(sample, finalized)
1009
+ if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
1010
+ attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
1011
+ else:
1012
+ attn = [
1013
+ finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
1014
+ for i in range(bsz * beam_size)
1015
+ ]
1016
+
1017
+ if src_tokens.device != "cpu":
1018
+ src_tokens = src_tokens.to("cpu")
1019
+ tgt_tokens = tgt_tokens.to("cpu")
1020
+ attn = [i.to("cpu") for i in attn]
1021
+
1022
+ # Process the attn matrix to extract hard alignments.
1023
+ for i in range(bsz * beam_size):
1024
+ alignment = self.extract_alignment(
1025
+ attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
1026
+ )
1027
+ finalized[i // beam_size][i % beam_size]["alignment"] = alignment
1028
+ return finalized
1029
+
1030
+ def _prepare_batch_for_alignment(self, sample, hypothesis):
1031
+ src_tokens = sample["net_input"]["src_tokens"]
1032
+ bsz = src_tokens.shape[0]
1033
+ src_tokens = (
1034
+ src_tokens[:, None, :]
1035
+ .expand(-1, self.beam_size, -1)
1036
+ .contiguous()
1037
+ .view(bsz * self.beam_size, -1)
1038
+ )
1039
+ src_lengths = sample["net_input"]["src_lengths"]
1040
+ src_lengths = (
1041
+ src_lengths[:, None]
1042
+ .expand(-1, self.beam_size)
1043
+ .contiguous()
1044
+ .view(bsz * self.beam_size)
1045
+ )
1046
+ prev_output_tokens = data_utils.collate_tokens(
1047
+ [beam["tokens"] for example in hypothesis for beam in example],
1048
+ self.pad,
1049
+ self.eos,
1050
+ self.left_pad_target,
1051
+ move_eos_to_beginning=True,
1052
+ )
1053
+ tgt_tokens = data_utils.collate_tokens(
1054
+ [beam["tokens"] for example in hypothesis for beam in example],
1055
+ self.pad,
1056
+ self.eos,
1057
+ self.left_pad_target,
1058
+ move_eos_to_beginning=False,
1059
+ )
1060
+ return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
1061
+
1062
+
1063
+ class EnsembleModelWithAlignment(EnsembleModel):
1064
+ """A wrapper around an ensemble of models."""
1065
+
1066
+ def __init__(self, models):
1067
+ super().__init__(models)
1068
+
1069
+ def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
1070
+ avg_attn = None
1071
+ for model in self.models:
1072
+ decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
1073
+ attn = decoder_out[1]["attn"][0]
1074
+ if avg_attn is None:
1075
+ avg_attn = attn
1076
+ else:
1077
+ avg_attn.add_(attn)
1078
+ if len(self.models) > 1:
1079
+ avg_attn.div_(len(self.models))
1080
+ return avg_attn
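
During decoding, `SequenceGenerator._generate` mixes up to three score sources per step: the decoder log-probabilities, CTC prefix scores (weighted by `ctc_weight` and applied only to the top CTC candidates), and an optional shallow-fusion language model (weighted by `lm_weight` and simply added). The toy sketch below reproduces that mixing rule on random tensors; it is illustrative only and does not call the generator, `CTCPrefixScore`, or any ArTST model.

import torch

vocab, ctc_weight, lm_weight = 8, 0.3, 0.2

dec_lprobs = torch.log_softmax(torch.randn(vocab), dim=-1)  # decoder output for one beam item
ctc_lprobs = torch.log_softmax(torch.randn(vocab), dim=-1)  # stand-in for CTC prefix-score deltas
lm_lprobs  = torch.log_softmax(torch.randn(vocab), dim=-1)  # stand-in for the fusion LM

# interpolation applied to the CTC-rescored candidates:
mixed = (1 - ctc_weight) * dec_lprobs + ctc_weight * ctc_lprobs
# shallow fusion is additive (in the code above, the LM probs arrive pre-scaled by lm_weight):
mixed = mixed + lm_weight * lm_lprobs

next_token = mixed.argmax().item()
print(next_token)

In practice the generator is constructed from an ensemble of loaded models plus the task's target dictionary, and `generate()` is fed collated batches containing a `net_input` entry, as the docstrings above describe.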
artst/tasks/__init__.py ADDED
File without changes
artst/tasks/artst.py ADDED
@@ -0,0 +1,711 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ import logging
10
+ import os.path as op
11
+ from argparse import Namespace
12
+ from collections import OrderedDict
13
+
14
+ import torch
15
+ from fairseq.data import (
16
+ Dictionary,
17
+ encoders,
18
+ PrependTokenDataset,
19
+ AppendTokenDataset,
20
+ data_utils,
21
+ StripTokenDataset,
22
+ TokenBlockDataset,
23
+ )
24
+ from fairseq.data.encoders.utils import get_whole_word_mask
25
+ from fairseq import utils
26
+ from artst.data.multitask_dataset import MultitaskDataset
27
+ from artst.data.speech_to_text_dataset import SpeechToTextDataset
28
+ from artst.data.text_to_speech_dataset import TextToSpeechDataset
29
+ from artst.data.speech_to_speech_dataset import SpeechToSpeechDataset
30
+ from artst.data.speech_to_class_dataset import SpeechToClassDataset
31
+ from artst.data.speech_dataset import SpeechPretrainDataset
32
+ from artst.data.text_dataset import TextPretrainDataset
33
+ from fairseq.data.shorten_dataset import maybe_shorten_dataset
34
+ from fairseq.tasks import LegacyFairseqTask, register_task
35
+ from fairseq.tasks.hubert_pretraining import LabelEncoder
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+ TASK_NAME = ["s2t", "t2s", "s2s", "s2c", "pretrain"]
40
+
41
+ @register_task("artst")
42
+ class ArTSTTask(LegacyFairseqTask):
43
+ @staticmethod
44
+ def add_args(parser):
45
+ parser.add_argument("data", help="manifest root path")
46
+ parser.add_argument(
47
+ "--config-yaml",
48
+ type=str,
49
+ default="config.yaml",
50
+ help="Configuration YAML filename (under manifest root)",
51
+ )
52
+ parser.add_argument(
53
+ "--max-speech-sample-size",
54
+ default=None,
55
+ type=int,
56
+ metavar="N",
57
+ help="max speech sample size",
58
+ )
59
+ parser.add_argument(
60
+ "--min-speech-sample-size",
61
+ default=None,
62
+ type=int,
63
+ metavar="N",
64
+ help="min speech sample size",
65
+ )
66
+ parser.add_argument(
67
+ "--max-speech-positions",
68
+ default=4000,
69
+ type=int,
70
+ metavar="N",
71
+ help="max number of tokens in the source sequence",
72
+ )
73
+ parser.add_argument(
74
+ "--max-text-positions",
75
+ default=450,
76
+ type=int,
77
+ metavar="N",
78
+ help="max number of tokens in the target sequence",
79
+ )
80
+ parser.add_argument(
81
+ '--t5-task',
82
+ choices=TASK_NAME,
83
+ help='task for training'
84
+ )
85
+ parser.add_argument(
86
+ "--bpe-tokenizer",
87
+ type=str,
88
+ default=None,
89
+ help="bpe tokenizer for s2t",
90
+ )
91
+ # Speaker Identification (SID)
92
+ parser.add_argument(
93
+ "--finetune-from-modules",
94
+ default=None,
95
+ # choices=[
96
+ # "encoder-decoder", "encoder", "decoder",
97
+ # "speech_encoder_prenet-encoder-decoder-text_decoder_prenet-text_decoder_postnet", # ASR, T5 SID
98
+ # "speech_encoder_prenet-encoder-decoder-text_decoder_prenet-speaker_decoder_postnet", # SID
99
+ # "speech_encoder_prenet-encoder-decoder-speech_decoder_prenet-speech_decoder_postnet", # VC, SE
100
+ # "text_encoder_prenet-encoder-decoder-speech_decoder_prenet-speech_decoder_postnet", # TTS
101
+ # ],
102
+ help="If set, only the listed modules are loaded from the fine-tuning checkpoint.",
103
+ )
104
+ parser.add_argument(
105
+ "--finetune-out-of-modules",
106
+ default=None,
107
+ # choices=[
108
+ # "speaker_decoder_postnet", # SID
109
+ # "speech_decoder_postnet", # SE with reduction factor 1
110
+ # ],
111
+ help="If set, the listed modules are excluded when loading the fine-tuning checkpoint.",
112
+ )
113
+ # BART
114
+ parser.add_argument(
115
+ "--shorten-method",
116
+ default="none",
117
+ choices=["none", "truncate", "random_crop"],
118
+ help="if not none, shorten sequences that exceed --tokens-per-sample",
119
+ )
120
+ parser.add_argument(
121
+ "--shorten-data-split-list",
122
+ default="",
123
+ help="comma-separated list of dataset splits to apply shortening to, "
124
+ 'e.g., "train,valid" (default: all dataset splits)',
125
+ )
126
+
127
+ parser.add_argument(
128
+ "--tokens-per-sample",
129
+ default=512,
130
+ type=int,
131
+ help="max number of total tokens over all segments"
132
+ " per sample for dataset",
133
+ )
134
+ parser.add_argument(
135
+ "--sample-break-mode",
136
+ default="eos",
137
+ type=str,
138
+ help="mode for breaking sentences",
139
+ )
140
+ parser.add_argument(
141
+ "--mask",
142
+ default=0.3,
143
+ type=float,
144
+ help="fraction of words/subwords that will be masked",
145
+ )
146
+ parser.add_argument(
147
+ "--mask-random",
148
+ default=0.1,
149
+ type=float,
150
+ help="instead of using [MASK], use random token this often",
151
+ )
152
+ parser.add_argument(
153
+ "--insert",
154
+ default=0.0,
155
+ type=float,
156
+ help="insert this percentage of additional random tokens",
157
+ )
158
+ parser.add_argument(
159
+ "--permute",
160
+ default=0.0,
161
+ type=float,
162
+ help="take this proportion of subwords and permute them",
163
+ )
164
+ parser.add_argument(
165
+ "--rotate",
166
+ default=0.0,
167
+ type=float,
168
+ help="rotate this proportion of inputs",
169
+ )
170
+ parser.add_argument(
171
+ "--poisson-lambda",
172
+ default=3.5,
173
+ type=float,
174
+ help="lambda of the Poisson distribution used to sample mask span lengths",
175
+ )
176
+ parser.add_argument(
177
+ "--permute-sentences",
178
+ default=0.0,
179
+ type=float,
180
+ help="shuffle this proportion of sentences in all inputs",
181
+ )
182
+ # parser.add_argument(
183
+ # "--mask-length",
184
+ # default="span-poisson",
185
+ # type=str,
186
+ # choices=["subword", "word", "span-poisson"],
187
+ # help="mask length to choose",
188
+ # )
189
+ parser.add_argument(
190
+ "--replace-length",
191
+ default=1,
192
+ type=int,
193
+ help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
194
+ )
195
+ parser.add_argument(
196
+ "--iid-noise-target",
197
+ action="store_true",
198
+ help="whether to use T5-style (i.i.d. noise) targets",
199
+ )
200
+ # Hubert
201
+ parser.add_argument(
202
+ "--hubert-labels",
203
+ nargs="*",
204
+ type=str,
205
+ default=['km'],
206
+ help="extensions of the label files to load: frame-level labels for pre-training, sequence-level labels for fine-tuning",
207
+ )
208
+ parser.add_argument(
209
+ "--hubert-label-dir",
210
+ type=str,
211
+ default=None,
212
+ help="if set, looks for labels in this directory instead",
213
+ )
214
+ parser.add_argument(
215
+ "--sample-rate",
216
+ default=100,
217
+ type=float,
218
+ help="target sample rate. audio files will be up/down sampled to this rate",
219
+ )
220
+ parser.add_argument(
221
+ "--label-rates",
222
+ default=-1,
223
+ type=float,
224
+ help="label frame rate; -1 for sequence-level labels",
225
+ )
226
+ parser.add_argument(
227
+ "--normalize",
228
+ action="store_true",
229
+ help="if set, normalizes input to have 0 mean and unit variance",
230
+ )
231
+ parser.add_argument(
232
+ "--enable-padding",
233
+ action="store_true",
234
+ help="pad shorter samples instead of cropping",
235
+ )
236
+ parser.add_argument(
237
+ "--pad-audio",
238
+ action="store_true",
239
+ help="pad audio to the longest one in the batch if true",
240
+ )
241
+ parser.add_argument(
242
+ "--random-crop",
243
+ action="store_true",
244
+ help="always crop from the beginning if false",
245
+ )
246
+ parser.add_argument(
247
+ "--single-target",
248
+ action="store_true",
249
+ help="if set, AddTargetDatasets outputs same keys "
250
+ "as AddTargetDataset",
251
+ )
252
+ parser.add_argument(
253
+ "--batch-ratio",
254
+ default=None,
255
+ type=str,
256
+ help="ratio of batch sizes for each dataset",
257
+ )
258
+ parser.add_argument(
259
+ "--sample-ratios",
260
+ default=None,
261
+ type=str,
262
+ help="sampling ratio for each dataset",
263
+ )
264
+ parser.add_argument(
265
+ "--ctc-weight",
266
+ type=float,
267
+ default=0.0,
268
+ help="ctc weight for inference",
269
+ )
270
+ parser.add_argument(
271
+ "--inference-speech",
272
+ type=bool,
273
+ default=False,
274
+ help="if true, run TTS inference (speech generation)",
275
+ )
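+ # NOTE: argparse's type=bool turns any non-empty string (even "False") into True;
+ # omit the flag entirely to keep the default of False.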
276
+
277
+ def __init__(self, args, dicts, config):
278
+ super().__init__(args)
279
+ self.dicts = dicts
280
+ self.config = config
281
+ self.t5_task = args.t5_task
282
+ # Used for filter size
283
+ if self.t5_task in ['s2t', 't2s', 's2s', 's2c']:
284
+ self.max_pos = [self.args.max_speech_positions * 256]
285
+ elif self.t5_task == 'pretrain':
286
+ self.max_pos = [self.args.max_speech_positions * 256, self.args.max_text_positions]
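+ # The factor of 256 converts the position budget into an approximate raw-waveform sample
+ # budget (roughly 256 samples per encoder position, carried over from the SpeechT5 recipe),
+ # so the size filtering in filter_indices_by_size compares against sample counts.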
287
+
288
+ self.mask_idx = self.dicts["text"].add_symbol("<mask>")
289
+ # add blank token for ctc
290
+ # if args.ctc_weight > 0:
291
+ self.blank_symbol_idx = self.dicts["text"].add_symbol("<ctc_blank>")
292
+ self.blank_symbol = "<ctc_blank>"
293
+
294
+ # add mask token
295
+ if hasattr(args, "iid_noise_target") and args.iid_noise_target:
296
+ self.uni_mask_idxs = []
297
+ for i in range(600):
298
+ self.uni_mask_idxs.append(self.dicts["text"].add_symbol("<mask>" + str(i)))
299
+ self.uni_mask_idxs = torch.tensor(self.uni_mask_idxs)
300
+
301
+ self.seed = args.seed
302
+
303
+ @classmethod
304
+ def setup_task(cls, args, **kwargs):
305
+ # load dictionaries and config
306
+ dicts = OrderedDict()
307
+ if args.t5_task == 'pretrain' and not hasattr(args, "shuffle_instance"):
308
+ args.shuffle_instance = False
309
+
310
+ # Prepare config
311
+ config = None
312
+ logger.info('No config file for ' + args.t5_task)
313
+
314
+ if args.t5_task == "pretrain":
315
+ dicts["hubert"] = [Dictionary.load(f"{args.hubert_label_dir}/dict.{label}.txt") for label in args.hubert_labels]
316
+ dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
317
+ else:
318
+ if config is None:
319
+ dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
320
+ else:
321
+ dicts["text"] = Dictionary.load(op.join(args.data, config.vocab_filename))
322
+
323
+ return cls(args, dicts, config)
324
+
325
+ def build_criterion(self, args):
326
+ from fairseq import criterions
327
+ return criterions.build_criterion(args, self)
328
+
329
+ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
330
+ sample_ratios = []
331
+ if self.t5_task == "s2t":
332
+ ## For speech to text task
333
+ bpe_tokenizer = self.build_bpe(self.args)
334
+ manifest = f"{self.args.data}/{split}.tsv"
335
+ procs = [LabelEncoder(self.dicts["text"])]
336
+ paths = [f"{self.args.hubert_label_dir}/{split}.txt"]
337
+ # Hawau: log the manifest being loaded (for inspecting the dataset)
338
+ logger.info(f"Manifest: {manifest}")
339
+ # logger.info(f"Paths: {paths}")
340
+ self.datasets[split] = SpeechToTextDataset(
341
+ manifest,
342
+ sample_rate=self.args.sample_rate,
343
+ label_paths=paths,
344
+ label_processors=procs,
345
+ max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size,
346
+ min_keep_sample_size=self.args.min_speech_sample_size,
347
+ normalize=self.args.normalize,
348
+ store_labels=False,
349
+ tgt_dict=self.dicts["text"],
350
+ tokenizer=bpe_tokenizer,
351
+ )
352
+ elif self.t5_task == "t2s":
353
+ ## For text to speech task
354
+ from fairseq.data import ConcatDataset
355
+ bpe_tokenizer = self.build_bpe(self.args)
356
+ procs = [LabelEncoder(self.dicts["text"])]
357
+ t2s_datasets = [
358
+ TextToSpeechDataset(
359
+ manifest_path=f"{self.args.data}/{name}.tsv",
360
+ sample_rate=self.args.sample_rate,
361
+ label_paths=[f"{self.args.hubert_label_dir}/{name}.txt"],
362
+ label_processors=procs,
363
+ max_keep_sample_size=self.max_pos[0],
364
+ normalize=self.args.normalize,
365
+ store_labels=False,
366
+ src_dict=self.dicts["text"],
367
+ tokenizer=bpe_tokenizer,
368
+ reduction_factor=self.args.reduction_factor,
369
+ inference=self.args.inference_speech,
370
+ )
371
+ for name in split.split(",")
372
+ ]
373
+ self.datasets[split] = ConcatDataset(t2s_datasets) if len(t2s_datasets) > 1 else t2s_datasets[0]
374
+ elif self.t5_task == "s2s":
375
+ manifest = f"{self.args.data}/{split}.tsv"
376
+ self.datasets[split] = SpeechToSpeechDataset(
377
+ manifest_path=manifest,
378
+ sample_rate=self.args.sample_rate,
379
+ max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size,
380
+ min_keep_sample_size=self.args.min_speech_sample_size,
381
+ normalize=self.args.normalize,
382
+ reduction_factor=self.args.reduction_factor,
383
+ )
384
+ elif self.t5_task == "s2c":
385
+ is_train_split = ("train" in split)
386
+ is_valid_split = ("valid" in split)
387
+ if is_train_split:
388
+ max_length = 51200
389
+ elif is_valid_split:
390
+ max_length = 76800
391
+ else:
392
+ max_length = 2560000
393
+ manifest = op.join(f"{self.args.data}", f"{split}.tsv")
394
+ procs = LabelEncoder(self.dicts["text"]) # map speaker to id
395
+ self.datasets[split] = SpeechToClassDataset(
396
+ manifest_path=manifest,
397
+ sample_rate=self.args.sample_rate,
398
+ label_processors=procs,
399
+ max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size,
400
+ min_keep_sample_size=self.args.min_speech_sample_size,
401
+ normalize=self.args.normalize,
402
+ tgt_dict=self.dicts["text"],
403
+ max_length=max_length
404
+ )
405
+ elif self.t5_task == "pretrain":
406
+ is_train_split = ("train" in split)
407
+ pretrain_datasets = []
408
+ speech_split, text_split = split.split('|')
409
+
410
+ ## Speech pre-train
411
+ manifest = f"{self.args.data}/{speech_split}.tsv"
412
+ dicts = self.dicts["hubert"]
413
+ pad_list = [dict.pad() for dict in dicts]
414
+ eos_list = [dict.eos() for dict in dicts]
415
+ procs = [LabelEncoder(dict) for dict in dicts]
416
+ paths = [
417
+ f"{self.args.hubert_label_dir}/{speech_split}.{l}" for l in self.args.hubert_labels
418
+ ]
419
+ # hubert v1: pad_audio=True, random_crop=False;
420
+ self.args.dec_weight = getattr(self.args, "dec_weight", 1.0)
421
+ pretrain_datasets.append(
422
+ SpeechPretrainDataset(
423
+ manifest,
424
+ sample_rate=self.args.sample_rate,
425
+ label_paths=paths,
426
+ label_rates=self.args.label_rates,
427
+ pad_list=pad_list,
428
+ eos_list=eos_list,
429
+ label_processors=procs,
430
+ max_keep_sample_size=None,
431
+ min_keep_sample_size=32000,
432
+ max_sample_size=self.args.max_speech_sample_size,
433
+ pad_audio=self.args.pad_audio,
434
+ normalize=self.args.normalize,
435
+ store_labels=False,
436
+ random_crop=self.args.random_crop,
437
+ single_target=self.args.single_target,
438
+ reduction_factor=self.args.reduction_factor,
439
+ )
440
+ )
441
+ sample_ratios.append(sum([pretrain_datasets[0].size(i) for i in range(len(pretrain_datasets[0]))]))
442
+
443
+ ## Text pre-train
444
+ paths = utils.split_paths(self.args.data)
445
+ assert len(paths) > 0
446
+ data_path = paths[(epoch - 1) % len(paths)]
447
+ logger.info(f"Loading {text_split} from data_path={data_path}")
448
+ split_path = op.join(data_path, text_split)
449
+ logger.info(f"split_path={split_path}")
450
+ bart_dataset = data_utils.load_indexed_dataset(
451
+ split_path,
452
+ self.dicts["text"],
453
+ self.args.dataset_impl,
454
+ combine=combine,
455
+ )
456
+ if bart_dataset is None:
457
+ raise FileNotFoundError(
458
+ "Dataset not found: {} ({})".format(text_split, split_path)
459
+ )
460
+ bart_dataset = StripTokenDataset(bart_dataset, self.dicts["text"].eos())
461
+ bart_dataset = maybe_shorten_dataset(
462
+ bart_dataset,
463
+ text_split,
464
+ self.args.shorten_data_split_list,
465
+ self.args.shorten_method,
466
+ self.args.tokens_per_sample,
467
+ self.args.seed,
468
+ )
469
+ # create continuous blocks of tokens
470
+ bart_dataset = TokenBlockDataset(
471
+ bart_dataset,
472
+ bart_dataset.sizes,
473
+ self.args.tokens_per_sample - 2, # one less for <s> and one for </s>
474
+ pad=self.dicts["text"].pad(),
475
+ eos=self.dicts["text"].eos(),
476
+ break_mode=self.args.sample_break_mode,
477
+ document_sep_len=0,
478
+ )
479
+ # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
480
+ bart_dataset = PrependTokenDataset(bart_dataset, self.dicts["text"].bos())
481
+ bart_dataset = AppendTokenDataset(bart_dataset, self.dicts["text"].eos())
482
+ mask_whole_words = (
483
+ get_whole_word_mask(self.args, self.dicts["text"])
484
+ if self.args.mask_length != "subword"
485
+ else None
486
+ )
487
+ self.args.bert_weight = getattr(self.args, "bert_weight", 0.0)
488
+ pretrain_datasets.append(
489
+ TextPretrainDataset(
490
+ bart_dataset,
491
+ bart_dataset.sizes,
492
+ self.dicts["text"],
493
+ self.mask_idx,
494
+ mask_whole_words,
495
+ shuffle=self.args.shuffle_instance,
496
+ seed=self.seed,
497
+ args=self.args,
498
+ iid_noise_target=self.args.iid_noise_target,
499
+ uni_mask_idxs=self.uni_mask_idxs if self.args.iid_noise_target else None,
500
+ )
501
+ )
502
+ sample_ratios.append(sum(pretrain_datasets[1].sizes))
503
+ logger.info(
504
+ "Task: {0}, Loaded {1} samples of denoising_dataset".format(
505
+ 'bart',
506
+ len(pretrain_datasets[1]),
507
+ )
508
+ )
509
+
510
+ logger.info('token ratio is ' + str(sample_ratios))
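+ # Convert per-dataset token counts into up-sampling ratios: by default each dataset is scaled
+ # so that one epoch draws a comparable number of tokens from every source; --batch-ratio and
+ # --sample-ratios can override this.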
511
+ if self.args.batch_ratio is not None:
512
+ batch_ratio = eval(self.args.batch_ratio)
513
+ assert len(batch_ratio) == len(sample_ratios)
514
+ sample_ratios = [sample_ratios[i] / batch_ratio[i] for i in range(len(sample_ratios))]
515
+ else:
516
+ batch_ratio = None
517
+ max_size = max(sample_ratios)
518
+ sample_ratios = [max_size / r for r in sample_ratios]
519
+ if hasattr(self.args, "sample_ratios") and self.args.sample_ratios is not None:
520
+ sample_ratios = eval(self.args.sample_ratios)
521
+ if is_train_split:
522
+ self.datasets[split] = MultitaskDataset(
523
+ pretrain_datasets, sample_ratios, batch_ratio
524
+ )
525
+ else:
526
+ self.datasets[split] = MultitaskDataset(
527
+ pretrain_datasets, batch_ratio=batch_ratio
528
+ )
529
+
530
+ def train_step(
531
+ self, sample, model, criterion, optimizer, update_num, ignore_grad=False
532
+ ):
533
+ model.train()
534
+ model.set_num_updates(update_num)
535
+
536
+ # Junyi: not use sample_size, but normalize the loss locally
537
+ agg_loss, agg_sample_size, agg_logging_output = 0.0, 1.0, {}
538
+ agg_logging_output['sample_size'] = 1
539
+
540
+ def forward_backward(model, samples, weight=1.0):
541
+ nonlocal agg_loss, agg_logging_output
542
+ if samples is None or len(samples) == 0:
543
+ return
544
+ loss, sample_size, logging_output = criterion(model, samples)
545
+ if ignore_grad:
546
+ loss *= 0
547
+ else:
548
+ loss *= weight
549
+ loss = loss / sample_size
550
+ optimizer.backward(loss)
551
+ agg_loss += loss.detach().item()
552
+ # # TODO make summing of the sample sizes configurable
553
+ for k in logging_output:
554
+ if k == 'ntokens' or k == 'nsentences':
555
+ if k not in agg_logging_output:
556
+ agg_logging_output[k] = 0
557
+ agg_logging_output[k] += logging_output[k]
558
+ # continue
559
+ # agg_logging_output[k] += logging_output[k]
560
+ # agg_logging_output[task_name] += logging_output[k]
561
+ agg_logging_output[samples['task_name']] = logging_output
562
+
563
+ forward_backward(model, sample)
564
+
565
+ agg_logging_output["loss"] = agg_loss
566
+
567
+ return agg_loss, agg_sample_size, agg_logging_output
568
+
569
+ def valid_step(self, sample, model, criterion):
570
+ model.eval()
571
+ with torch.no_grad():
572
+ from collections import defaultdict
573
+
574
+ agg_loss, agg_sample_size, agg_logging_output = 0.0, 1.0, defaultdict(float)
575
+ agg_logging_output['sample_size'] = 1
576
+ loss, sample_size, logging_output = criterion(model, sample)
577
+ loss = loss / sample_size
578
+ # agg_loss += loss.data.item() if isinstance(loss, torch.Tensor) else loss
579
+ agg_loss += loss.item() if isinstance(loss, torch.Tensor) else loss
580
+ agg_logging_output[sample['task_name']] = logging_output
581
+ agg_logging_output["loss"] = agg_loss
582
+ return agg_loss, agg_sample_size, agg_logging_output
583
+
584
+ @property
585
+ def target_dictionary(self):
586
+ return self.dicts["text"]
587
+
588
+ @property
589
+ def source_dictionary(self):
590
+ return None
591
+
592
+ def build_model(self, args):
593
+ try:
594
+ args.input_feat_per_channel = self.config.input_feat_per_channel
595
+ args.input_channels = self.config.input_channels
596
+ except Exception as e:
597
+ args.input_feat_per_channel = 80
598
+ args.input_channels = 1
599
+ logger.info("Cannot set input_feat_per_channel / input_channels from config:")
600
+ logger.warning(e)
601
+ logger.info(f"Set to: {args.input_feat_per_channel} and {args.input_channels}")
602
+
603
+ args.speech_odim = args.input_feat_per_channel * args.input_channels
604
+
605
+ args.label_rates = self.args.label_rates
606
+ args.sample_rate = self.args.sample_rate
607
+ self.args.reduction_factor = args.reduction_factor
608
+ return super(ArTSTTask, self).build_model(args)
609
+
610
+ def build_generator(
611
+ self,
612
+ models,
613
+ args,
614
+ seq_gen_cls=None,
615
+ extra_gen_cls_kwargs=None,
616
+ ):
617
+ from artst.sequence_generator import SequenceGenerator
618
+ extra_gen_cls_kwargs = {
619
+ "ctc_weight": self.args.ctc_weight,
620
+ **(extra_gen_cls_kwargs or {})  # guard against the default extra_gen_cls_kwargs=None
621
+ }
622
+ return super().build_generator(
623
+ models, args, seq_gen_cls=SequenceGenerator, extra_gen_cls_kwargs=extra_gen_cls_kwargs
624
+ )
625
+
626
+ def build_tokenizer(self, args):
627
+ if self.config is None:
628
+ logger.info(f"pre-tokenizer: None")
629
+ return encoders.build_tokenizer(Namespace(**{"tokenizer": None}))
630
+ else:
631
+ logger.info(f"pre-tokenizer: {self.config.pre_tokenizer}")
632
+ return encoders.build_tokenizer(Namespace(**self.config.pre_tokenizer))
633
+
634
+ def build_bpe(self, args):
635
+ if self.config is not None:
636
+ logger.info(f"tokenizer: {self.config.bpe_tokenizer}")
637
+ return encoders.build_bpe(Namespace(**self.config.bpe_tokenizer))
638
+ else:
639
+ logger.info(f"tokenizer: {self.args.bpe_tokenizer}")
640
+ return encoders.build_bpe(Namespace(**{"bpe": "sentencepiece", "sentencepiece_model": self.args.bpe_tokenizer}))
641
+
642
+ def generate_class(self, models, net_input, prefix_tokens, **kwargs):
643
+ with torch.no_grad():
644
+ encoder_input = {
645
+ k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
646
+ }
647
+ encoder_input.update(kwargs)
648
+ encoder_input.update({"prev_output_tokens": prefix_tokens})
649
+ return models[0].generate_class(**encoder_input)
650
+
651
+ def generate_speech(self, models, net_input, **kwargs):
652
+ with torch.no_grad():
653
+ encoder_input = {
654
+ k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
655
+ }
656
+ encoder_input.update(kwargs)
657
+ return models[0].generate_speech(**encoder_input)
658
+
659
+ def inference_t2s(
660
+ self, models, sample
661
+ ):
662
+ with torch.no_grad():
663
+ xs = sample['net_input']['src_tokens']
664
+ spkemb = sample['net_input']['spkembs']
665
+ return models[0].inference(xs, spkemb)
666
+
667
+ def inference_s2s(
668
+ self, models, sample, force_equal_length=False
669
+ ):
670
+ with torch.no_grad():
671
+ x = sample['net_input']['src_tokens']
672
+ xlen = sample['net_input']['src_lengths']
673
+ spkemb = sample['net_input']['spkembs']
674
+ prev_output_tokens = sample['net_input']['prev_output_tokens']
675
+ padding_mask = sample['net_input']['padding_mask']
676
+ tgt_lengths = sample['net_input']['tgt_lengths']
677
+ return models[0].inference_s2s(x, xlen, spkemb, prev_output_tokens, tgt_lengths, force_equal_length=force_equal_length, padding_mask=padding_mask)
678
+
679
+ def inference_s2c(
680
+ self, models, sample
681
+ ):
682
+ with torch.no_grad():
683
+ x = sample['net_input']['src_tokens']
684
+ xlen = sample['net_input']['src_lengths']
685
+ prev_output_tokens = sample['net_input']['prev_output_tokens']
686
+ padding_mask = sample['net_input']['padding_mask']
687
+ assert prev_output_tokens.size(1) == 1, prev_output_tokens.size()
688
+ return models[0].inference_s2c(x, xlen, prev_output_tokens, padding_mask=padding_mask)
689
+
690
+ def filter_indices_by_size(
691
+ self, indices, dataset, max_positions=None, ignore_invalid_inputs=False
692
+ ):
693
+ """
694
+ Filter examples that are too large
695
+
696
+ Args:
697
+ indices (np.array): original array of sample indices
698
+ dataset (~fairseq.data.FairseqDataset): dataset to batch
699
+ max_positions (optional): max sentence length supported by the
700
+ model (default: None).
701
+ ignore_invalid_inputs (bool, optional): don't raise Exception for
702
+ sentences that are too long (default: False).
703
+ Returns:
704
+ np.array: array of filtered sample indices
705
+ """
706
+
707
+ indices, ignored = dataset.filter_indices_by_size(
708
+ indices,
709
+ self.max_pos
710
+ )
711
+ return indices
ckpts/mgb2_asr.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9aaf1c09f146aff9361e999885c861f8ab0a6b0c997ceb34682ad3366849cc91
3
+ size 1847116641
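For reference, a minimal, hedged sketch of loading this checkpoint together with the artst task registered above, outside the Gradio app. The arg_overrides below (manifest root utils, SentencePiece model utils/arabic.model) are illustrative assumptions based on the files in this commit, not a verbatim excerpt of app.py:

import artst.tasks.artst  # importing the module runs @register_task("artst")
from fairseq import checkpoint_utils

models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    ["ckpts/mgb2_asr.pt"],
    arg_overrides={
        "data": "utils",                        # folder holding dict.txt and the audio manifest
        "bpe_tokenizer": "utils/arabic.model",  # SentencePiece model shipped with this Space
    },
)
model = models[0].eval()
print(type(task).__name__, len(task.target_dictionary))  # ArTSTTask and its text vocabulary size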
pre-requirements.txt ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cython==0.29.35
2
+ fairseq==0.12.2
3
+ datasets==2.12.0
4
+ editdistance==0.6.2
5
+ espnet==202304
6
+ espnet-tts-frontend==0.0.3
7
+ librosa==0.9.2
8
+ omegaconf==2.0.6
9
+ pandas==2.0.1
10
+ PyArabic==0.6.15
11
+ scipy
12
+ soundfile
13
+ tqdm==4.65.0
14
+ tweepy==4.14.0
15
+ tensorboard
16
+ kaldiio==2.18.0
17
+ numpy==1.23.5
18
+ cmake==3.26.4
19
+ pillow==10.0.0
20
+ nvidia-cublas-cu11==11.10.3.66
21
+ nvidia-cuda-cupti-cu11==11.7.101
22
+ nvidia-cuda-nvrtc-cu11==11.7.99
23
+ nvidia-cuda-runtime-cu11==11.7.99
24
+ nvidia-cudnn-cu11==8.5.0.96
25
+ nvidia-cufft-cu11==10.9.0.58
26
+ nvidia-curand-cu11==10.2.10.91
27
+ nvidia-cusolver-cu11==11.4.0.1
28
+ nvidia-cusparse-cu11==11.7.4.91
29
+ nvidia-nccl-cu11==2.14.3
30
+ nvidia-nvtx-cu11==11.7.91
31
+ tensorboardx==2.6
32
+ transformers
33
+ speechbrain
34
+ numpy==1.23.5
samples/sample_audio.wav ADDED
Binary file (97.9 kB). View file
 
utils/arabic.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17dd0feab129e71329f98fe9efd6143be4f6e1aede2408fe6c586f803f7d6cc0
3
+ size 539226
utils/audios.tsv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ./
2
+ /tmp/gradio/8f338949403b3fbd48ea5459ce1a148e9ebf96ee/sample_audio.wav/t30000
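utils/audios.tsv is a two-line speech manifest in the HuBERT-style layout used by the datasets under artst/data: the first line is the root directory, and each following line holds an audio path and its sample count (the second entry points at a Gradio temp file, so app.py presumably rewrites this manifest at runtime). Below is a small sketch of generating such a manifest for the bundled sample audio; the exact column semantics are assumed from the fairseq/SpeechT5 conventions this repo builds on.

import soundfile as sf

wav_path = "samples/sample_audio.wav"
info = sf.info(wav_path)  # frame count (and sample rate) of the recording

with open("utils/audios.tsv", "w") as f:
    f.write("./\n")                          # line 1: root directory
    f.write(f"{wav_path}\t{info.frames}\n")  # line 2: <path><TAB><num_samples>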
utils/dict.txt ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ▁ 7888394
2
+ ا 5240837
3
+ ل 4086510
4
+ ي 2961895
5
+ م 2193275
6
+ ن 2167596
7
+ و 1852812
8
+ ت 1552697
9
+ ر 1543221
10
+ ع 1224967
11
+ ب 1116473
12
+ ه 1088784
13
+ د 1060960
14
+ ة 1037790
15
+ أ 940870
16
+ س 935180
17
+ ف 834969
18
+ ك 831374
19
+ ق 807989
20
+ ح 704969
21
+ ج 436806
22
+ ذ 368254
23
+ ط 344915
24
+ إ 327964
25
+ ش 314812
26
+ ى 307987
27
+ ص 290514
28
+ خ 277022
29
+ ض 252120
30
+ ث 197381
31
+ ز 166103
32
+ ئ 137402
33
+ ً 122431
34
+ غ 113416
35
+ ء 111245
36
+ ظ 98346
37
+ ُ 75189
38
+ آ 58959
39
+ ؤ 57561
40
+ ّ 39925
41
+ 0 27260
42
+ ٍ 24728
43
+ َ 22922
44
+ ِ 21587
45
+ 1 18933
46
+ 2 15936
47
+ ٌ 10327
48
+ 5 6938
49
+ 9 6502
50
+ 6 4829
51
+ 7 4380
52
+ 8 4041
53
+ e 3578
54
+ a 3330
55
+ % 2892
56
+ ـ 2768
57
+ t 2746
58
+ i 2683
59
+ n 2462
60
+ o 2390
61
+ r 2341
62
+ s 2073
63
+ c 1561
64
+ l 1504
65
+ m 1141
66
+ d 1018
67
+ p 988
68
+ u 920
69
+ h 888
70
+ g 777
71
+ b 762
72
+ f 603
73
+ y 562
74
+ k 515
75
+ w 470
76
+ v 293
77
+ j 254
78
+ z 193
79
+ x 105
80
+ @ 97
81
+ 3 52
82
+ 4 48
83
+ q 30
84
+ ̇ 6
85
+ ٱ 2
86
+ ⁄ 1
87
+ madeupword0000 0
88
+ madeupword0001 0
89
+ madeupword0002 0
90
+ madeupword0003 0
91
+ madeupword0004 0
92
+ madeupword0005 0
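utils/dict.txt is a standard fairseq dictionary: one symbol and its corpus count per line, here a character-level Arabic (plus Latin and digit) vocabulary padded out with madeupwordNNNN placeholders. A hedged sketch of inspecting it the way ArTSTTask.setup_task and __init__ do; the sample string is arbitrary:

from fairseq.data import Dictionary

d = Dictionary.load("utils/dict.txt")   # <s>, <pad>, </s>, <unk> plus the 92 entries above
mask_idx = d.add_symbol("<mask>")       # mirrors ArTSTTask.__init__
blank_idx = d.add_symbol("<ctc_blank>")
ids = d.encode_line(" ".join("مرحبا"), add_if_not_exist=False)  # character-level encoding
print(len(d), ids.tolist())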