initial commit

- .gitignore +3 -0
- README.md +5 -5
- SpeechT5 +1 -0
- app.py +466 -0
- artst/__init__.py +1 -0
- artst/criterions/__init__.py +10 -0
- artst/criterions/artst_criterion.py +443 -0
- artst/criterions/speech_pretrain_criterion.py +265 -0
- artst/criterions/speech_to_text_loss.py +473 -0
- artst/criterions/text_pretrain_criterion.py +142 -0
- artst/criterions/text_to_speech_loss.py +425 -0
- artst/data/__init__.py +0 -0
- artst/data/multitask_dataset.py +263 -0
- artst/data/speech_dataset.py +475 -0
- artst/data/speech_to_class_dataset.py +260 -0
- artst/data/speech_to_speech_dataset.py +280 -0
- artst/data/speech_to_text_dataset.py +298 -0
- artst/data/text_dataset.py +474 -0
- artst/data/text_to_speech_dataset.py +344 -0
- artst/models/__init__.py +2 -0
- artst/models/artst.py +1448 -0
- artst/models/modules/__init__.py +0 -0
- artst/models/modules/decoder.py +323 -0
- artst/models/modules/encoder.py +380 -0
- artst/models/modules/multihead_attention.py +525 -0
- artst/models/modules/speaker_decoder_postnet.py +196 -0
- artst/models/modules/speech_decoder_postnet.py +75 -0
- artst/models/modules/speech_decoder_prenet.py +109 -0
- artst/models/modules/speech_encoder_postnet.py +123 -0
- artst/models/modules/speech_encoder_prenet.py +373 -0
- artst/models/modules/text_decoder_postnet.py +92 -0
- artst/models/modules/text_decoder_prenet.py +128 -0
- artst/models/modules/text_encoder_prenet.py +44 -0
- artst/models/modules/transformer_layer.py +410 -0
- artst/models/t5_transformer_lm.py +23 -0
- artst/sequence_generator.py +1080 -0
- artst/tasks/__init__.py +0 -0
- artst/tasks/artst.py +711 -0
- ckpts/mgb2_asr.pt +3 -0
- pre-requirements.txt +34 -0
- samples/sample_audio.wav +0 -0
- utils/arabic.model +3 -0
- utils/audios.tsv +2 -0
- utils/dict.txt +92 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+__pycache__
+*/__pycache__
+*/*/__pycache__

README.md CHANGED
@@ -1,13 +1,13 @@
 ---
-title:
-emoji:
+title: ArtstTTS
+emoji: 💭
 colorFrom: gray
-colorTo:
+colorTo: blue
 sdk: gradio
-sdk_version: 4.
+sdk_version: 4.7.1
 app_file: app.py
 pinned: false
-
+python_version: 3.8.2
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

SpeechT5 ADDED
@@ -0,0 +1 @@
+Subproject commit 8b5ade783571e63450aaa5507444150dcb08fa94

app.py ADDED
@@ -0,0 +1,466 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Translate pre-processed data with a trained model.
+"""
+
+import ast
+import logging
+import argparse
+import math
+import os
+import sys
+from argparse import Namespace
+from itertools import chain
+
+import numpy as np
+import torch
+from omegaconf import DictConfig
+
+from fairseq import checkpoint_utils, options, scoring, tasks, utils
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.logging import progress_bar
+from fairseq.logging.meters import StopwatchMeter, TimeMeter
+
+import os
+import torch
+import gradio as gr
+import numpy as np
+import os.path as op
+import pyarabic.araby as araby
+import subprocess
+
+import soundfile as sf
+
+
+from artst.tasks.artst import ArTSTTask
+from artst.models.artst import ArTSTTransformerModel
+from fairseq.tasks.hubert_pretraining import LabelEncoder
+
+from fairseq import checkpoint_utils, options, scoring, tasks, utils
+
+from loguru import logger
+from fairseq.logging.meters import StopwatchMeter, TimeMeter
+
+
+def postprocess(wav, cur_sample_rate):
+    if wav.dim() == 2:
+        wav = wav.mean(-1)
+    assert wav.dim() == 1, wav.dim()
+
+    if cur_sample_rate != 16000:
+        raise Exception(f"sr {cur_sample_rate} != {16000}")
+    return wav
+
+
+def main(cfg: DictConfig, audio_path):
+    print('config')
+    print(cfg)
+
+    if isinstance(cfg, Namespace):
+        cfg = convert_namespace_to_omegaconf(cfg)
+
+    assert cfg.common_eval.path is not None, "--path required for generation!"
+    assert (
+        not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
+    ), "--sampling requires --nbest to be equal to --beam"
+    assert (
+        cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
+    ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"
+
+    if cfg.common_eval.results_path is not None:
+        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
+        output_path = os.path.join(
+            cfg.common_eval.results_path,
+            "generate-{}.txt".format(cfg.dataset.gen_subset),
+        )
+        with open(output_path, "w", buffering=1, encoding="utf-8") as h:
+            return _main(cfg, h, audio_path)
+    else:
+        return _main(cfg, sys.stdout, audio_path)
+
+
+def get_symbols_to_strip_from_output(generator):
+    if hasattr(generator, "symbols_to_strip_from_output"):
+        return generator.symbols_to_strip_from_output
+    else:
+        return {generator.eos}
+
+
+def _main(cfg: DictConfig, output_file, audio_path):
+    logging.basicConfig(
+        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        level=os.environ.get("LOGLEVEL", "INFO").upper(),
+        stream=output_file,
+    )
+    logger = logging.getLogger("fairseq_cli.generate")
+
+    utils.import_user_module(cfg.common)
+
+    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
+        cfg.dataset.max_tokens = 12000
+    logger.info(cfg)
+
+    # Fix seed for stochastic decoding
+    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
+        np.random.seed(cfg.common.seed)
+        utils.set_torch_seed(cfg.common.seed)
+
+    use_cuda = torch.cuda.is_available() and not cfg.common.cpu
+
+    # Load dataset splits
+    task = tasks.setup_task(cfg.task)
+
+    # Set dictionaries
+    try:
+        src_dict = getattr(task, "source_dictionary", None)
+    except NotImplementedError:
+        src_dict = None
+    tgt_dict = task.target_dictionary
+
+    overrides = ast.literal_eval(cfg.common_eval.model_overrides)
+
+    # Load ensemble
+    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
+    models, saved_cfg = checkpoint_utils.load_model_ensemble(
+        utils.split_paths(cfg.common_eval.path),
+        arg_overrides=overrides,
+        task=task,
+        suffix=cfg.checkpoint.checkpoint_suffix,
+        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
+        num_shards=cfg.checkpoint.checkpoint_shard_count,
+    )
+
+    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
+    # task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
+
+    if cfg.generation.lm_path is not None:
+        overrides["data"] = cfg.task.data
+
+        try:
+            lms, _ = checkpoint_utils.load_model_ensemble(
+                [cfg.generation.lm_path], arg_overrides=overrides, task=None
+            )
+        except:
+            logger.warning(
+                f"Failed to load language model! Please make sure that the language model dict is the same "
+                f"as target dict and is located in the data dir ({cfg.task.data})"
+            )
+            raise
+
+        assert len(lms) == 1
+    else:
+        lms = [None]
+
+    # Optimize ensemble for generation
+    for model in chain(models, lms):
+        if model is None:
+            continue
+        if cfg.common.fp16:
+            model.half()
+        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
+            model.cuda()
+        model.prepare_for_inference_(cfg)
+
+    # Load alignment dictionary for unknown word replacement
+    # (None if no unknown word replacement, empty if no path to align dictionary)
+    align_dict = utils.load_align_dict(cfg.generation.replace_unk)
+
+    # Initialize generator
+    gen_timer = StopwatchMeter()
+
+    extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
+    generator = task.build_generator(
+        models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
+    )
+
+    # Handle tokenization and BPE
+    tokenizer = task.build_tokenizer(cfg.tokenizer)
+    bpe = task.build_bpe(cfg.bpe)
+
+    def decode_fn(x):
+        if bpe is not None:
+            x = bpe.decode(x)
+        if tokenizer is not None:
+            x = tokenizer.decode(x)
+        return x
+
+    scorer = scoring.build_scorer(cfg.scoring, tgt_dict)
+
+    num_sentences = 0
+    has_target = True
+    wps_meter = TimeMeter()
+
+
+    wav, cur_sample_rate = sf.read(audio_path)
+    wav = torch.from_numpy(wav).float()
+    wav = postprocess(wav, cur_sample_rate)
+    sample = {'index': 0, 'net_input': {'source': torch.tensor(wav).unsqueeze(dim=0), 'padding_mask':
+              torch.BoolTensor(wav.shape).fill_(False).unsqueeze(dim=0)}, 'id': [0], 'target': [[None], ]}
+
+    prefix_tokens = None
+    if cfg.generation.prefix_size > 0:
+        prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]
+
+    constraints = None
+    if "constraints" in sample:
+        constraints = sample["constraints"]
+
+    gen_timer.start()
+    hypos = task.inference_step(
+        generator,
+        models,
+        sample,
+        prefix_tokens=prefix_tokens,
+        constraints=constraints,
+    )
+    num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
+    gen_timer.stop(num_generated_tokens)
+
+    for i, sample_id in enumerate(sample["id"]):
+        has_target = False
+
+        # Remove padding
+        if "src_tokens" in sample["net_input"]:
+            src_tokens = utils.strip_pad(
+                sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
+            )
+        else:
+            src_tokens = None
+
+        target_tokens = None
+        if has_target:
+            target_tokens = (
+                utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
+            )
+
+        # Either retrieve the original sentences or regenerate them from tokens.
+        if align_dict is not None:
+            src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
+                sample_id
+            )
+            target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
+                sample_id
+            )
+        else:
+            if src_dict is not None:
+                src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
+            else:
+                src_str = ""
+            if has_target:
+                target_str = tgt_dict.string(
+                    target_tokens,
+                    cfg.common_eval.post_process,
+                    escape_unk=True,
+                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
+                        generator
+                    ),
+                )
+
+        src_str = decode_fn(src_str)
+        if has_target:
+            target_str = decode_fn(target_str)
+
+        if not cfg.common_eval.quiet:
+            if src_dict is not None:
+                print("S-{}\t{}".format(sample_id, src_str), file=output_file)
+            if has_target:
+                print("T-{}\t{}".format(sample_id, target_str), file=output_file)
+
+        # Process top predictions
+        for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
+            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
+                hypo_tokens=hypo["tokens"].int().cpu(),
+                src_str=src_str,
+                alignment=hypo["alignment"],
+                align_dict=align_dict,
+                tgt_dict=tgt_dict,
+                remove_bpe=cfg.common_eval.post_process,
+                extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
+            )
+            detok_hypo_str = decode_fn(hypo_str)
+            if not cfg.common_eval.quiet:
+                score = hypo["score"] / math.log(2)  # convert to base 2
+                # original hypothesis (after tokenization and BPE)
+                print(
+                    "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
+                    file=output_file,
+                )
+                # detokenized hypothesis
+                print(
+                    "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
+                    file=output_file,
+                )
+                print(
+                    "P-{}\t{}".format(
+                        sample_id,
+                        " ".join(
+                            map(
+                                lambda x: "{:.4f}".format(x),
+                                # convert from base e to base 2
+                                hypo["positional_scores"]
+                                .div_(math.log(2))
+                                .tolist(),
+                            )
+                        ),
+                    ),
+                    file=output_file,
+                )
+
+                if cfg.generation.print_alignment == "hard":
+                    print(
+                        "A-{}\t{}".format(
+                            sample_id,
+                            " ".join(
+                                [
+                                    "{}-{}".format(src_idx, tgt_idx)
+                                    for src_idx, tgt_idx in alignment
+                                ]
+                            ),
+                        ),
+                        file=output_file,
+                    )
+                if cfg.generation.print_alignment == "soft":
+                    print(
+                        "A-{}\t{}".format(
+                            sample_id,
+                            " ".join(
+                                [",".join(src_probs) for src_probs in alignment]
+                            ),
+                        ),
+                        file=output_file,
+                    )
+
+                if cfg.generation.print_step:
+                    print(
+                        "I-{}\t{}".format(sample_id, hypo["steps"]),
+                        file=output_file,
+                    )
+
+                if cfg.generation.retain_iter_history:
+                    for step, h in enumerate(hypo["history"]):
+                        _, h_str, _ = utils.post_process_prediction(
+                            hypo_tokens=h["tokens"].int().cpu(),
+                            src_str=src_str,
+                            alignment=None,
+                            align_dict=None,
+                            tgt_dict=tgt_dict,
+                            remove_bpe=None,
+                        )
+                        print(
+                            "E-{}_{}\t{}".format(sample_id, step, h_str),
+                            file=output_file,
+                        )
+
+            # Score only the top hypothesis
+            if has_target and j == 0:
+                if (
+                    align_dict is not None
+                    or cfg.common_eval.post_process is not None
+                ):
+                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
+                    target_tokens = tgt_dict.encode_line(
+                        target_str, add_if_not_exist=True
+                    )
+                    hypo_tokens = tgt_dict.encode_line(
+                        detok_hypo_str, add_if_not_exist=True
+                    )
+                if hasattr(scorer, "add_string"):
+                    scorer.add_string(target_str, detok_hypo_str)
+                else:
+                    scorer.add(target_tokens, hypo_tokens)
+
+    wps_meter.update(num_generated_tokens)
+    # progress.log({"wps": round(wps_meter.avg)})
+
+    logger.info("NOTE: hypothesis and token scores are output in base 2")
+    if has_target:
+        if cfg.bpe and not cfg.generation.sacrebleu:
+            if cfg.common_eval.post_process:
+                logger.warning(
+                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
+                )
+            else:
+                logger.warning(
+                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization"
+                )
+        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
+        print(
+            "Generate {} with beam={}: {}".format(
+                cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
+            ),
+            file=output_file,
+        )
+    return detok_hypo_str
+
+def inference(audio_path):
+    # parser = options.get_generation_parser()
+    # TODO: replace this workaround with refactoring of `AudioPretraining`
+    parser = argparse.ArgumentParser(description='Process some integers.')
+    parser.add_argument(
+        "--arch",
+        "-a",
+        metavar="ARCH",
+        default="wav2vec2",
+        help="Model architecture. For constructing tasks that rely on "
+        "model args (e.g. `AudioPretraining`)",
+    )
+    parser.add_argument('--data', type=str, default='./utils', metavar='data')
+    parser.add_argument('--bpe-tokenizer', type=str, default='./utils/arabic.model')
+    parser.add_argument('--user-dir', type=str, default='./SpeechT5/SpeechT5/speecht5')
+    parser.add_argument('--task', type=str, default='artst')
+    parser.add_argument('--t5-task', type=str, default='s2t')
+    parser.add_argument('--path', type=str, default='./ckpts/mgb2_asr.pt')
+    parser.add_argument('--ctc-weight', type=float, default=0.25)
+    parser.add_argument('--max-tokens', type=int, default=350000)
+    parser.add_argument('--beam', type=int, default=5)
+    parser.add_argument('--scoring', type=str, default='wer')
+    parser.add_argument('--max-len-a', type=float, default=0)
+    parser.add_argument('--max-len-b', type=int, default=1000)
+    parser.add_argument('--sample-rate', type=int, default=16000)
+    parser.add_argument('--batch-size', type=int, default=1)
+    # parser.add_argument('--num-workers', type=int, default=4)
+    parser.add_argument('--seed', type=int, default=4)
+    parser.add_argument('--normalize', type=bool, default=True)
+
+    args = parser.parse_args()
+    return main(args, audio_path=audio_path)
+
+
+text_box = gr.Textbox(label="Arabic Text")
+input_audio = gr.Audio(label="Upload Audio", type="filepath", sources="upload")
+title = "ArTST: Arabic Speech Recognition"
+description = "ArTST: Arabic text and speech transformer based on the T5 transformer. This space demonstrates the ASR checkpoint finetuned on \
+    the MGB-2 dataset. The model is pre-trained on the MGB-2 dataset."
+
+examples = ["/l/users/amirbek.djanibekov/artst-tts-demo/samples/sample_audio.wav"]
+
+article = """
+<div style='margin:20px auto;'>
+<p>References: <a href="https://arxiv.org/abs/2310.16621">ArTST paper</a> |
+<a href="https://github.com/mbzuai-nlp/ArTST">GitHub</a> |
+<a href="https://huggingface.co/MBZUAI/ArTST">Weights and Tokenizer</a></p>
+<pre>
+@misc{toyin2023artst,
+      title={ArTST: Arabic Text and Speech Transformer},
+      author={Hawau Olamide Toyin and Amirbek Djanibekov and Ajinkya Kulkarni and Hanan Aldarmaki},
+      year={2023},
+      eprint={2310.16621},
+      archivePrefix={arXiv},
+      primaryClass={cs.CL}
+}
+</pre>
+<p>Speaker embeddings were generated from <a href="http://www.festvox.org/cmu_arctic/">CMU ARCTIC</a>.</p>
+<p>ArTST is based on <a href="https://arxiv.org/abs/2110.07205">SpeechT5 architecture</a>.</p>
+</div>
+"""
+
+demo = gr.Interface(inference,
+    inputs=input_audio, outputs=text_box, title=title, description=description, examples=examples, article=article)
+
+if __name__ == "__main__":
+    demo.launch(share=True)

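The app above wires a single-utterance fairseq generation pass into a Gradio interface. For quick testing outside the web UI, the `inference` entry point can be called directly on a 16 kHz mono WAV file. A minimal sketch, assuming the repository root is the working directory so the hard-coded argparse defaults (`./ckpts/mgb2_asr.pt`, `./utils`) resolve:

```python
# Minimal local-usage sketch; assumes the repo root as working directory
# so the default paths baked into inference() resolve. Not committed code.
from app import inference

if __name__ == "__main__":
    # postprocess() in app.py raises unless the audio is sampled at 16 kHz.
    transcript = inference("samples/sample_audio.wav")
    print(transcript)
```

Because `inference` builds its configuration from hard-coded argparse defaults, pointing the demo at a different checkpoint or dictionary means editing `app.py` itself.
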
artst/__init__.py ADDED
@@ -0,0 +1 @@
+from . import data, tasks, criterions, models  # noqa

artst/criterions/__init__.py ADDED
@@ -0,0 +1,10 @@
+import importlib
+import os
+
+
+for file in os.listdir(os.path.dirname(__file__)):
+    if file.endswith(".py") and not file.startswith("_"):
+        criterion_name = file[: file.find(".py")]
+        importlib.import_module(
+            "artst.criterions." + criterion_name
+        )

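The loop in `artst/criterions/__init__.py` imports every criterion module as a side effect of importing the package, which is what lets decorator-based registration such as `@register_criterion("artst")` in `artst_criterion.py` run without an explicit import list. A standalone sketch of the same pattern, with a hypothetical package name used only for illustration:

```python
# Generic form of the auto-import pattern above; "myplugins" is a
# hypothetical package name, not something shipped in this repo.
import importlib
import os

def import_submodules(package_dir: str, package_name: str) -> None:
    """Import every non-private .py file in package_dir so that any
    registration decorators in those modules execute at import time."""
    for file_name in os.listdir(package_dir):
        if file_name.endswith(".py") and not file_name.startswith("_"):
            module_name = file_name[: -len(".py")]
            importlib.import_module(f"{package_name}.{module_name}")

# e.g. import_submodules(os.path.dirname(myplugins.__file__), "myplugins")
```
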
artst/criterions/artst_criterion.py ADDED
@@ -0,0 +1,443 @@
+# --------------------------------------------------------
+# ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621)
+# Github source: https://github.com/mbzuai-nlp/ArTST
+# Based on speecht5, fairseq and espnet code bases
+# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
+# --------------------------------------------------------
+
+import re
+from dataclasses import dataclass
+
+import math
+from fairseq import metrics, utils
+from fairseq.criterions import FairseqCriterion, register_criterion
+from artst.criterions.text_to_speech_loss import TexttoSpeechLoss
+from artst.criterions.text_pretrain_criterion import TextPretrainCriterion, TextPretrainCriterionConfig
+from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterionConfig
+from artst.criterions.speech_pretrain_criterion import SpeechPretrainCriterion, SpeechPretrainCriterionConfig
+from artst.criterions.speech_to_text_loss import SpeechtoTextLoss, SpeechtoTextLossConfig
+from fairseq.logging.meters import safe_round
+
+@dataclass
+class ArTSTCriterionConfig(
+    LabelSmoothedCrossEntropyCriterionConfig,
+    TextPretrainCriterionConfig,
+    SpeechPretrainCriterionConfig,
+    SpeechtoTextLossConfig
+):
+    pass
+
+@register_criterion(
+    "artst", dataclass=ArTSTCriterionConfig
+)
+class ArTSTCriterion(FairseqCriterion):
+    def __init__(
+        self,
+        task,
+        sentence_avg,
+        label_smoothing,
+        pred_masked_weight,
+        pred_nomask_weight,
+        loss_weights=None,
+        log_keys=None,
+        ignore_prefix_size=0,
+        report_accuracy=False,
+        use_masking=True,
+        use_weighted_masking=False,
+        loss_type="L1",
+        bce_pos_weight=5.0,
+        bce_loss_lambda=1.0,
+        use_guided_attn_loss=False,
+        num_heads_applied_guided_attn=2,
+        ce_weight=1.0,
+        ctc_weight=0.0,
+        hubert_weight=1.0,
+        dec_weight=1.0,
+        bart_weight=1.0,
+    ):
+        super().__init__(task)
+        self.speech_criterion = TexttoSpeechLoss(
+            task,
+            sentence_avg,
+            use_masking,
+            use_weighted_masking,
+            loss_type,
+            bce_pos_weight,
+            bce_loss_lambda,
+            use_guided_attn_loss,
+            num_heads_applied_guided_attn=num_heads_applied_guided_attn,
+        )
+        self.text_criterion = SpeechtoTextLoss(
+            SpeechtoTextLossConfig,
+            task,
+            sentence_avg,
+            label_smoothing,
+            ignore_prefix_size,
+            report_accuracy,
+            ce_weight,
+            ctc_weight
+        )
+        self.text_pretrain_criterion = TextPretrainCriterion(
+            task,
+            sentence_avg,
+            bart_weight,
+            loss_weights,
+        )
+        self.speech_pretrain_criterion = SpeechPretrainCriterion(
+            task,
+            sentence_avg,
+            pred_masked_weight,
+            pred_nomask_weight,
+            loss_weights,
+            log_keys,
+            use_masking,
+            use_weighted_masking,
+            loss_type,
+            bce_pos_weight,
+            hubert_weight,
+            dec_weight
+        )
+
+    def forward(self, model, sample, reduce=True):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+
+        task_name = sample['task_name']
+        if task_name == 's2t' or task_name == 's2c':
+            return self.text_criterion(model, sample, reduce)
+        elif task_name == 't2s' or task_name == 's2s':
+            return self.speech_criterion(model, sample)
+        elif task_name == 'text_pretrain':
+            return self.text_pretrain_criterion(model, sample, reduce)
+        elif task_name == 'speech_pretrain':
+            return self.speech_pretrain_criterion(model, sample, reduce)
+
+    @classmethod
+    def reduce_metrics(cls, logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        logging_outputs_dict = {}
+        for logging_output in logging_outputs:
+            for task_name in logging_output:
+                if task_name not in ['s2t', 't2s', 's2c', 's2s', 'text_pretrain', 'speech_pretrain']:
+                    continue
+
+                if task_name not in logging_outputs_dict:
+                    logging_outputs_dict[task_name] = []
+                logging_outputs_dict[task_name].append(logging_output[task_name])
+
+        for task_name in logging_outputs_dict:
+            if task_name == 's2t':
+                # LabelSmoothedCrossEntropyCriterion.reduce_metrics([logging_output['s2t'] for logging_output in logging_outputs])
+                s2t_logging_output = logging_outputs_dict[task_name]
+                # s2t_sum = sum(log.get("ce_loss", 0) for log in logging_outputs)
+                loss_sum = sum(log.get("loss", 0) for log in s2t_logging_output)
+                nll_loss_sum = sum(log.get("nll_loss", 0) for log in s2t_logging_output)
+                ntokens = sum(log.get("ntokens", 0) for log in s2t_logging_output)
+                ce_loss_sum = sum(log.get("ce_loss", 0) for log in s2t_logging_output)
+                ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in s2t_logging_output)
+
+                sample_size = max(1, sum(log.get("sample_size", 0) for log in s2t_logging_output))
+                metrics.log_scalar(
+                    "s2t_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3
+                )
+
+                metrics.log_scalar(
+                    "s2t_nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
+                )
+                metrics.log_derived(
+                    "s2t_ppl", lambda meters: utils.get_perplexity(meters["s2t_nll_loss"].avg, 2)
+                )
+                metrics.log_scalar(
+                    "ctc_loss", ctc_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
+                )
+                metrics.log_scalar(
+                    "ce_loss", ce_loss_sum / ntokens, ntokens, 2, round=3
+                )
+
+                total = utils.item(sum(log.get("total", 0) for log in s2t_logging_output))
+                if total > 0:
+                    metrics.log_scalar("s2t_total", total)
+                    n_correct = utils.item(
+                        sum(log.get("n_correct", 0) for log in s2t_logging_output)
+                    )
+                    metrics.log_scalar("s2t_n_correct", n_correct)
+                    metrics.log_derived(
+                        "s2t_accuracy",
+                        lambda meters: round(
+                            meters["s2t_n_correct"].sum * 100.0 / meters["s2t_total"].sum, 3
+                        )
+                        if meters["s2t_total"].sum > 0
+                        else float("nan"),
+                        2
+                    )
+                c_errors = sum(log.get("c_errors", 0) for log in s2t_logging_output)
+                metrics.log_scalar("_c_errors", c_errors)
+                c_total = sum(log.get("c_total", 0) for log in s2t_logging_output)
+                metrics.log_scalar("_c_total", c_total)
+                w_errors = sum(log.get("w_errors", 0) for log in s2t_logging_output)
+                metrics.log_scalar("_w_errors", w_errors)
+                wv_errors = sum(log.get("wv_errors", 0) for log in s2t_logging_output)
+                metrics.log_scalar("_wv_errors", wv_errors)
+                w_total = sum(log.get("w_total", 0) for log in s2t_logging_output)
+                metrics.log_scalar("_w_total", w_total)
+                if c_total > 0:
+                    metrics.log_derived(
+                        "uer",
+                        lambda meters: safe_round(
+                            meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
+                        )
+                        if meters["_c_total"].sum > 0
+                        else float("nan"),
+                    )
+                if w_total > 0:
+                    metrics.log_derived(
+                        "wer",
+                        lambda meters: safe_round(
+                            meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
+                        )
+                        if meters["_w_total"].sum > 0
+                        else float("nan"),
+                    )
+                    metrics.log_derived(
+                        "raw_wer",
+                        lambda meters: safe_round(
+                            meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
+                        )
+                        if meters["_w_total"].sum > 0
+                        else float("nan"),
+                    )
+
+            if task_name == 't2s':
+                # TTSLossCriterion.reduce_metrics([logging_output['t2s'] for logging_output in logging_outputs])
+                # t2s_sum = sum(log.get("speech_loss", 0) for log in logging_outputs)
+                t2s_logging_output = logging_outputs_dict[task_name]
+                loss_sum = sum(log.get("loss", 0) for log in t2s_logging_output)
+                l1_loss_sum = sum(log.get("l1_loss", 0) for log in t2s_logging_output)
+                l2_loss_sum = sum(log.get("l2_loss", 0) for log in t2s_logging_output)
+                bce_loss_sum = sum(log.get("bce_loss", 0) for log in t2s_logging_output)
+                sample_size = max(1, sum(log.get("sample_size", 0) for log in t2s_logging_output))
+                metrics.log_scalar(
+                    "t2s_loss", loss_sum / sample_size, sample_size, 1, round=5
+                )
+                encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in t2s_logging_output)
+                decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in t2s_logging_output)
+                ngpu = sum(log.get("ngpu", 0) for log in t2s_logging_output)
+
+                metrics.log_scalar(
+                    "t2s_l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
+                )
+                metrics.log_scalar(
+                    "t2s_l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
+                )
+                metrics.log_scalar(
+                    "t2s_bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
+                )
+                metrics.log_scalar(
+                    "t2s_encoder_alpha", encoder_alpha_sum / sample_size, sample_size, round=5
+                )
+                metrics.log_scalar(
+                    "t2s_decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
+                )
+
+                if "enc_dec_attn_loss" in t2s_logging_output[0]:
+                    enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in t2s_logging_output)
+                    metrics.log_scalar(
+                        "t2s_enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
+                    )
+
+            if task_name == 's2c':
+                s2c_logging_output = logging_outputs_dict[task_name]
+                loss_sum = sum(log.get("loss", 0) for log in s2c_logging_output)
+                nll_loss_sum = sum(log.get("nll_loss", 0) for log in s2c_logging_output)
+                ntokens = sum(log.get("ntokens", 0) for log in s2c_logging_output)
+
+                sample_size = max(1, sum(log.get("sample_size", 0) for log in s2c_logging_output))
+                metrics.log_scalar(
+                    "s2c_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3
+                )
+
+                metrics.log_scalar(
+                    "s2c_nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
+                )
+
+                total = utils.item(sum(log.get("total", 0) for log in s2c_logging_output))
+                if total > 0:
+                    metrics.log_scalar("s2c_total", total)
+                    n_correct = utils.item(sum(log.get("n_correct", 0) for log in s2c_logging_output))
+                    metrics.log_scalar("s2c_n_correct", n_correct)
+                    metrics.log_derived(
+                        "s2c_accuracy",
+                        lambda meters: round(
+                            meters["s2c_n_correct"].sum * 100.0 / meters["s2c_total"].sum, 3
+                        )
+                        if meters["s2c_total"].sum > 0
+                        else float("nan"),
+                        2
+                    )
+
+            if task_name == 's2s':
+                s2s_logging_output = logging_outputs_dict[task_name]
+                loss_sum = sum(log.get("loss", 0) for log in s2s_logging_output)
+                l1_loss_sum = sum(log.get("l1_loss", 0) for log in s2s_logging_output)
+                l2_loss_sum = sum(log.get("l2_loss", 0) for log in s2s_logging_output)
+                bce_loss_sum = sum(log.get("bce_loss", 0) for log in s2s_logging_output)
+                sample_size = max(1, sum(log.get("sample_size", 0) for log in s2s_logging_output))
+                metrics.log_scalar(
+                    "s2s_loss", loss_sum / sample_size, sample_size, 1, round=5
+                )
+                encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in s2s_logging_output)
+                decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in s2s_logging_output)
+                ngpu = sum(log.get("ngpu", 0) for log in s2s_logging_output)
+
+                metrics.log_scalar(
+                    "s2s_l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
+                )
+                metrics.log_scalar(
+                    "s2s_l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
+                )
+                metrics.log_scalar(
+                    "s2s_bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
+                )
+                metrics.log_scalar(
+                    "s2s_decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
+                )
+
+                if "enc_dec_attn_loss" in s2s_logging_output[0]:
+                    enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in s2s_logging_output)
+                    metrics.log_scalar(
+                        "s2s_enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
+                    )
+
+            if task_name == 'text_pretrain':
+                bart_logging_output = logging_outputs_dict[task_name]
+                loss_sum = sum(log.get("loss", 0) for log in bart_logging_output)
+                ntokens = sum(log.get("ntokens", 0) for log in bart_logging_output)
+                sample_size = max(1, sum(log.get("sample_size", 0) for log in bart_logging_output))
+                bart_loss_sum = sum(log.get("bart_loss", 0) for log in bart_logging_output)
+
+                # we divide by log(2) to convert the loss from base e to base 2
+                metrics.log_scalar(
+                    "text_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+                )
+                metrics.log_scalar(
+                    "bart_loss", bart_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
+                )
+                if sample_size != ntokens:
+                    metrics.log_scalar(
+                        "bart_nll_loss", bart_loss_sum / ntokens / math.log(2), ntokens, round=3
+                    )
+                    metrics.log_derived(
+                        "bart_ppl", lambda meters: utils.get_perplexity(meters["bart_nll_loss"].avg)
+                    )
+                else:
+                    metrics.log_derived(
+                        "bart_ppl", lambda meters: utils.get_perplexity(meters["bart_loss"].avg)
+                    )
+                metrics.log_scalar("bart_wpb", ntokens, priority=180, round=1)
+
+                val_prob_perplexity = 0
+                val_code_perplexity = 0
+                sample_size_pp = 0
+                count_log_cp = 0
+                for log in bart_logging_output:
+                    if "loss_prob_perplexity" in log:
+                        val_prob_perplexity = val_prob_perplexity + log["loss_prob_perplexity"]
+                        sample_size_pp = sample_size_pp + log["sample_size"]
+                    if "code_perplexity" in log:
+                        val_code_perplexity = val_code_perplexity + log["code_perplexity"]
+                        count_log_cp = count_log_cp + 1
+                if val_prob_perplexity > 0:
+                    metrics.log_scalar("text_loss_prob_perplexity", val_prob_perplexity / sample_size_pp / math.log(2), round=3)
+                if val_code_perplexity > 0:
+                    metrics.log_scalar("text_code_perplexity", val_code_perplexity / count_log_cp, round=3)
+
+            if task_name == 'speech_pretrain':
+                hubert_logging_output = logging_outputs_dict[task_name]
+                loss_sum = sum(log.get("loss", 0) for log in hubert_logging_output)
+                ntokens = sum(log.get("ntokens", 0) for log in hubert_logging_output)
+                sample_size = max(1, sum(log.get("sample_size", 0) for log in hubert_logging_output))
+                dec_loss_sum = sum(log.get("dec_loss", 0) for log in hubert_logging_output)
+                l1_loss_sum = sum(log.get("l1_loss", 0) for log in hubert_logging_output)
+                l2_loss_sum = sum(log.get("l2_loss", 0) for log in hubert_logging_output)
+                bce_loss_sum = sum(log.get("bce_loss", 0) for log in hubert_logging_output)
+                ngpu = sum(log.get("ngpu", 0) for log in hubert_logging_output)
+
+                metrics.log_scalar("hubert_loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
+                if sample_size != ntokens:
+                    metrics.log_scalar("hubert_nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
+                    metrics.log_derived("hubert_ppl", lambda meters: utils.get_perplexity(meters["hubert_nll_loss"].avg))
+                else:
+                    metrics.log_derived("hubert_ppl", lambda meters: utils.get_perplexity(meters["hubert_loss"].avg))
+
+                counts = {}
+                for lk in hubert_logging_output[0].keys():
+                    if lk.startswith("count_"):
+                        val = sum(log[lk] for log in hubert_logging_output)
+                        metrics.log_scalar("hubert_" + lk, val)
+                        counts[lk] = val
+
+                for lk in hubert_logging_output[0].keys():
+                    if lk.startswith("loss_") and lk != 'loss_prob_perplexity':
+                        val = sum(log[lk] for log in hubert_logging_output)
+                        metrics.log_scalar("hubert_" + lk, val / sample_size / math.log(2), round=3)
+                    elif lk.startswith("correct_"):
+                        val = sum(log[lk] for log in hubert_logging_output)
+                        metrics.log_scalar("hubert_" + lk, val / counts[re.sub("correct", "count", lk)])
+                    # elif lk == 'code_perplexity':
+                    #     val = sum(log[lk] for log in hubert_logging_output)
+                    #     metrics.log_scalar("hubert_" + lk, val / len(hubert_logging_output), round=3)
+
+                val_prob_perplexity = 0
+                val_code_perplexity = 0
+                sample_size_pp = 0
+                count_log_cp = 0
+                for log in hubert_logging_output:
+                    if "loss_prob_perplexity" in log:
+                        val_prob_perplexity = val_prob_perplexity + log["loss_prob_perplexity"]
+                        sample_size_pp = sample_size_pp + log["sample_size"]
+                    if "code_perplexity" in log:
+                        val_code_perplexity = val_code_perplexity + log["code_perplexity"]
+                        count_log_cp = count_log_cp + 1
+                if val_prob_perplexity > 0:
+                    metrics.log_scalar("hubert_loss_prob_perplexity", val_prob_perplexity / sample_size_pp / math.log(2), round=3)
+                if val_code_perplexity > 0:
+                    metrics.log_scalar("hubert_code_perplexity", val_code_perplexity / count_log_cp, round=3)
+
+                metrics.log_scalar(
+                    "hubert_dec_loss", dec_loss_sum / ngpu, sample_size, 2, round=5
+                )
+                metrics.log_scalar(
+                    "hubert_l1_loss", l1_loss_sum / ngpu, sample_size, 2, round=5
+                )
+                metrics.log_scalar(
+                    "hubert_l2_loss", l2_loss_sum / ngpu, sample_size, 2, round=5
+                )
+                metrics.log_scalar(
+                    "hubert_bce_loss", bce_loss_sum / ngpu, sample_size, 2, round=5
+                )
+                if "enc_dec_attn_loss" in hubert_logging_output[0]:
+                    enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in hubert_logging_output)
+                    metrics.log_scalar(
+                        "hubert_enc_dec_attn_loss", enc_dec_attn_loss_sum / ngpu, sample_size, round=8
+                    )
+                metrics.log_scalar("hubert_wpb", ntokens, priority=180, round=1)
+
+        loss = sum(log.get("loss", 0) for log in logging_outputs)
+        sample_size = max(1, sum(log.get("sample_size", 0) for log in logging_outputs))
+        metrics.log_scalar(
+            "loss", loss / sample_size, sample_size, 1, round=5
+        )
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True will improve distributed training speed.
+        """
+        return False

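`ArTSTCriterion` itself is mostly a dispatcher: `forward` routes each batch to one of four sub-criteria based on `sample['task_name']`, and `reduce_metrics` regroups the logging outputs per task before aggregating them. A toy, dependency-free sketch of that dispatch pattern (the classes below are stand-ins, not the fairseq criteria used above):

```python
# Toy illustration of per-task criterion dispatch; stand-in classes only.
class ToyLoss:
    def __init__(self, name: str):
        self.name = name

    def __call__(self, model, sample, reduce=True):
        # A real criterion returns (loss, sample_size, logging_output).
        return f"{self.name} handled task {sample['task_name']}"


class ToyMultiTaskCriterion:
    def __init__(self):
        self.text_criterion = ToyLoss("SpeechtoTextLoss")    # s2t / s2c
        self.speech_criterion = ToyLoss("TexttoSpeechLoss")  # t2s / s2s

    def forward(self, model, sample, reduce=True):
        task_name = sample["task_name"]
        if task_name in ("s2t", "s2c"):
            return self.text_criterion(model, sample, reduce)
        if task_name in ("t2s", "s2s"):
            return self.speech_criterion(model, sample)
        raise ValueError(f"unexpected task_name: {task_name}")


print(ToyMultiTaskCriterion().forward(None, {"task_name": "s2t"}))
```

Note that the committed `forward` has no final `else` branch, so an unrecognized `task_name` silently returns `None`; raising, as in the sketch, makes a misconfigured batch fail loudly.
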
artst/criterions/speech_pretrain_criterion.py ADDED
@@ -0,0 +1,265 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import math
|
9 |
+
import re
|
10 |
+
from dataclasses import dataclass, field
|
11 |
+
from typing import List, Optional
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from fairseq import metrics, utils
|
16 |
+
from fairseq.criterions import FairseqCriterion
|
17 |
+
from artst.criterions.text_to_speech_loss import TexttoSpeechLoss, TexttoSpeechLossConfig
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class SpeechPretrainCriterionConfig(TexttoSpeechLossConfig):
|
22 |
+
pred_masked_weight: float = field(
|
23 |
+
default=1.0,
|
24 |
+
metadata={"help": "weight for predictive loss for masked frames"},
|
25 |
+
)
|
26 |
+
pred_nomask_weight: float = field(
|
27 |
+
default=0.0,
|
28 |
+
metadata={"help": "weight for predictive loss for unmasked frames"},
|
29 |
+
)
|
30 |
+
loss_weights: Optional[List[float]] = field(
|
31 |
+
default_factory=lambda: [10,],
|
32 |
+
metadata={"help": "weights for additional loss terms (not first one)"},
|
33 |
+
)
|
34 |
+
log_keys: List[str] = field(
|
35 |
+
default_factory=lambda: [],
|
36 |
+
metadata={"help": "output keys to log"},
|
37 |
+
)
|
38 |
+
hubert_weight: float = field(
|
39 |
+
default=1.0,
|
40 |
+
metadata={"help": "weight of hubert loss"},
|
41 |
+
)
|
42 |
+
dec_weight: float = field(
|
43 |
+
default=1.0,
|
44 |
+
metadata={"help": "weight of decoder loss"},
|
45 |
+
)
|
46 |
+
|
47 |
+
|
48 |
+
class SpeechPretrainCriterion(FairseqCriterion):
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
task,
|
52 |
+
sentence_avg,
|
53 |
+
pred_masked_weight,
|
54 |
+
pred_nomask_weight,
|
55 |
+
loss_weights=None,
|
56 |
+
log_keys=None,
|
57 |
+
use_masking=True,
|
58 |
+
use_weighted_masking=False,
|
59 |
+
loss_type="L1",
|
60 |
+
bce_pos_weight=5.0,
|
61 |
+
hubert_weight=1.0,
|
62 |
+
dec_weight=1.0,
|
63 |
+
):
|
64 |
+
super().__init__(task)
|
65 |
+
self.pred_masked_weight = pred_masked_weight
|
66 |
+
self.pred_nomask_weight = pred_nomask_weight
|
67 |
+
self.loss_weights = loss_weights
|
68 |
+
self.log_keys = [] if log_keys is None else log_keys
|
69 |
+
self.hubert_weight = hubert_weight
|
70 |
+
self.dec_weight = dec_weight
|
71 |
+
|
72 |
+
self.speech_criterion = TexttoSpeechLoss(
|
73 |
+
task,
|
74 |
+
sentence_avg,
|
75 |
+
use_masking,
|
76 |
+
use_weighted_masking,
|
77 |
+
loss_type,
|
78 |
+
bce_pos_weight,
|
79 |
+
)
|
80 |
+
|
81 |
+
def forward(self, model, sample, reduce=True, log_pred=False):
|
82 |
+
"""Compute the loss for the given sample.
|
83 |
+
Returns a tuple with three elements:
|
84 |
+
1) the loss
|
85 |
+
2) the sample size, which is used as the denominator for the gradient
|
86 |
+
3) logging outputs to display while training
|
87 |
+
"""
|
88 |
+
if self.dec_weight == 0:
|
89 |
+
sample["net_input"]["only_hubert"] = True
|
90 |
+
net_output, net_output_dec = model(target_list=sample["target_list"], **sample["net_input"])
|
91 |
+
loss = 0.
|
92 |
+
sample_size = 0
|
93 |
+
logging_output = {}
|
94 |
+
reduction = "sum" if reduce else "none"
|
95 |
+
|
96 |
+
loss_m_list = []
|
97 |
+
logp_m_list = model.get_logits(net_output, True)
|
98 |
+
targ_m_list = model.get_targets(None, net_output, True)
|
99 |
+
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
|
100 |
+
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
|
101 |
+
loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
|
102 |
+
loss_m_list.append(loss_m)
|
103 |
+
logging_output[f"loss_m_{i}"] = loss_m.detach().item()
|
104 |
+
if self.pred_masked_weight > 0:
|
105 |
+
loss += self.pred_masked_weight * sum(loss_m_list)
|
106 |
+
sample_size += targ_m_list[0].numel()
|
107 |
+
|
108 |
+
loss_u_list = []
|
109 |
+
logp_u_list = model.get_logits(net_output, False)
|
110 |
+
targ_u_list = model.get_targets(None, net_output, False)
|
111 |
+
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
|
112 |
+
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
|
113 |
+
loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
|
114 |
+
loss_u_list.append(loss_u)
|
115 |
+
logging_output[f"loss_u_{i}"] = loss_u.detach().item()
|
116 |
+
if self.pred_nomask_weight > 0:
|
117 |
+
loss += self.pred_nomask_weight * sum(loss_u_list)
|
118 |
+
sample_size += targ_u_list[0].numel()
|
119 |
+
|
120 |
+
if self.loss_weights is not None:
|
121 |
+
assert hasattr(model, "get_extra_losses")
|
122 |
+
extra_losses, names = model.get_extra_losses(net_output)
|
123 |
+
if torch.is_tensor(extra_losses):
|
124 |
+
extra_losses = [extra_losses]
|
125 |
+
names = [names]
|
126 |
+
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
|
127 |
+
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
|
128 |
+
if len(self.loss_weights) > len(extra_losses):
|
129 |
+
modified_loss_weight = self.loss_weights[:len(extra_losses)]
|
130 |
+
else:
|
131 |
+
modified_loss_weight = self.loss_weights
|
132 |
+
|
133 |
+
# assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
|
134 |
+
for p, n, coef in zip(extra_losses, names, modified_loss_weight):
|
135 |
+
# print(n + str(coef))
|
136 |
+
if coef != 0 and p is not None:
|
137 |
+
p = coef * p.float() * sample_size
|
138 |
+
loss += p
|
139 |
+
logging_output[f"loss_{n}"] = p.detach().item()
|
140 |
+
|
141 |
+
logging_output = {
|
142 |
+
"ntokens": sample_size,
|
143 |
+
"nsentences": sample["id"].numel(),
|
144 |
+
"sample_size": sample_size,
|
145 |
+
"ngpu": 1,
|
146 |
+
**logging_output,
|
147 |
+
}
|
148 |
+
|
149 |
+
if 'loss_prob_perplexity' in logging_output:
|
150 |
+
logging_output['code_perplexity'] = net_output['code_perplexity'].detach().item()
|
151 |
+
|
152 |
+
for lk in self.log_keys:
|
153 |
+
if lk in net_output:
|
154 |
+
logging_output[lk] = float((net_output[lk].item()))
|
155 |
+
|
156 |
+
def compute_correct(logits):
|
157 |
+
if logits.numel() == 0:
|
158 |
+
return 0, 0
|
159 |
+
else:
|
160 |
+
assert logits.dim() > 1, logits.shape
|
161 |
+
max = logits.argmax(-1) == 0
|
162 |
+
min = logits.argmin(-1) == 0
|
163 |
+
both = max & min
|
164 |
+
corr = max.long().sum().item() - both.long().sum().item()
|
165 |
+
count = max.numel()
|
166 |
+
return corr, count
|
167 |
+
|
168 |
+
with torch.no_grad():
|
169 |
+
for i, logp_m in enumerate(logp_m_list):
|
170 |
+
corr_m, count_m = compute_correct(logp_m)
|
171 |
+
logging_output[f"correct_m_{i}"] = corr_m
|
172 |
+
logging_output[f"count_m_{i}"] = count_m
|
173 |
+
|
174 |
+
for i, logp_u in enumerate(logp_u_list):
|
175 |
+
corr_u, count_u = compute_correct(logp_u)
|
176 |
+
logging_output[f"correct_u_{i}"] = corr_u
|
177 |
+
logging_output[f"count_u_{i}"] = count_u
|
178 |
+
|
179 |
+
if self.dec_weight == 0.0:
|
180 |
+
logging_output["loss"] = loss.item() if reduce else loss
|
181 |
+
return loss, sample_size, logging_output
|
182 |
+
|
183 |
+
# ## dec loss
|
184 |
+
        dec_loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss = self.speech_criterion.compute_loss(model, net_output_dec, sample)

        # Log tts loss
        logging_output['dec_loss'] = dec_loss.item()
        logging_output['l1_loss'] = l1_loss.item()
        logging_output['l2_loss'] = l2_loss.item()
        logging_output['bce_loss'] = bce_loss.item()
        if enc_dec_attn_loss is not None:
            logging_output['enc_dec_attn_loss'] = enc_dec_attn_loss.item()

        loss = self.hubert_weight * loss + self.dec_weight * sample_size * dec_loss
        logging_output["loss"] = loss.item() if reduce else loss
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
        l1_loss_sum = sum(log.get("l1_loss", 0) for log in logging_outputs)
        l2_loss_sum = sum(log.get("l2_loss", 0) for log in logging_outputs)
        bce_loss_sum = sum(log.get("bce_loss", 0) for log in logging_outputs)
        ngpu = sum(log.get("ngpu", 0) for log in logging_outputs)

        metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
        if sample_size != ntokens:
            metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
        else:
            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))

        counts = {}
        for lk in logging_outputs[0].keys():
            if lk.startswith("count_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val)
                counts[lk] = val

        for lk in logging_outputs[0].keys():
            if lk.startswith("loss_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
            elif lk.startswith("correct_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
            elif lk == 'code_perplexity':
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / len(logging_outputs), round=3)

        metrics.log_scalar(
            "dec_loss", dec_loss_sum / ngpu, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "l1_loss", l1_loss_sum / ngpu, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "l2_loss", l2_loss_sum / ngpu, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "bce_loss", bce_loss_sum / ngpu, sample_size, 2, round=5
        )
        if "enc_dec_attn_loss" in logging_outputs[0]:
            enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in logging_outputs)
            metrics.log_scalar(
                "enc_dec_attn_loss", enc_dec_attn_loss_sum / ngpu, sample_size, round=8
            )

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError()

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False
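For reference, a minimal numeric sketch of how the combined pretraining loss above is assembled from the HuBERT-style masked-prediction loss and the decoder (TTS) loss; the weights and values below are made up for illustration and are not taken from the repository's configs.

# hypothetical numbers, only to show why dec_loss is rescaled by sample_size
hubert_weight, dec_weight = 1.0, 1.0   # assumed criterion weights
hubert_loss = 250.0                    # summed masked-prediction loss for the batch
dec_loss = 0.8                         # per-batch decoder (TTS) loss
sample_size = 320                      # gradient denominator used by the trainer

total = hubert_weight * hubert_loss + dec_weight * sample_size * dec_loss
print(total)  # 506.0 -> scaling by sample_size keeps both terms on a comparable footing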
artst/criterions/speech_to_text_loss.py
ADDED
@@ -0,0 +1,473 @@
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import math
from argparse import Namespace
from dataclasses import dataclass, field
from omegaconf import II
from typing import Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import post_process
from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round

import logging
logger = logging.getLogger(__name__)

@dataclass
class SpeechtoTextLossConfig(FairseqDataclass):
    zero_infinity: bool = field(
        default=False,
        metadata={"help": "zero inf loss when source length <= target length"},
    )
    sentence_avg: bool = II("optimization.sentence_avg")
    post_process: Optional[str] = field(
        default="sentencepiece",
        metadata={
            "help": "how to post process predictions into words. can be letter, "
            "wordpiece, BPE symbols, etc. "
            "See fairseq.data.data_utils.post_process() for full list of options"
        },
    )
    wer_kenlm_model: Optional[str] = field(
        default=None,
        metadata={
            "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
        },
    )
    wer_lexicon: Optional[str] = field(
        default=None,
        metadata={"help": "lexicon to use with wer_kenlm_model"},
    )
    wer_lm_weight: float = field(
        default=2.0,
        metadata={"help": "lm weight to use with wer_kenlm_model"},
    )
    wer_word_score: float = field(
        default=-1.0,
        metadata={"help": "lm word score to use with wer_kenlm_model"},
    )

    wer_args: Optional[str] = field(
        default=None,
        metadata={
            "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
        },
    )

    label_smoothing: float = field(
        default=0.0,
        metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
    )
    report_accuracy: bool = field(
        default=False,
        metadata={"help": "report accuracy metric"},
    )
    ignore_prefix_size: int = field(
        default=0,
        metadata={"help": "Ignore first N tokens"},
    )
    #: bool = II("optimization.sentence_avg")

    ce_weight: float = field(
        default=1.0,
        metadata={"help": "loss weight for cross entropy"},
    )
    ctc_weight: float = field(
        default=0.0,
        metadata={"help": "loss weight for ctc in ASR"},
    )


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    eps_i = epsilon / (lprobs.size(-1) - 1)
    loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss


class SpeechtoTextLoss(FairseqCriterion):
    def __init__(
        self,
        cfg: SpeechtoTextLossConfig,
        task: FairseqTask,
        sentence_avg=True,
        label_smoothing=0.1,
        ignore_prefix_size=0,
        report_accuracy=False,
        ce_weight=1.0,
        ctc_weight=0.0,
    ):

        super().__init__(task)
        self.blank_idx = (
            task.target_dictionary.index(task.blank_symbol)
            if hasattr(task, "blank_symbol")
            else 0
        )
        #print ("self.blank_idx: ", self.blank_idx)

        self.pad_idx = task.target_dictionary.pad()
        self.eos_idx = task.target_dictionary.eos()
        self.post_process = cfg.post_process
        self.ce_weight = ce_weight
        self.ctc_weight = ctc_weight

        ## for ce
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing
        self.ignore_prefix_size = ignore_prefix_size
        self.report_accuracy = report_accuracy

        if cfg.wer_args is not None:
            (
                cfg.wer_kenlm_model,
                cfg.wer_lexicon,
                cfg.wer_lm_weight,
                cfg.wer_word_score,
            ) = eval(cfg.wer_args)

        if cfg.wer_kenlm_model is not None:
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            dec_args = Namespace()
            dec_args.nbest = 1
            dec_args.criterion = "ctc"
            dec_args.kenlm_model = cfg.wer_kenlm_model
            dec_args.lexicon = cfg.wer_lexicon
            dec_args.beam = 50
            dec_args.beam_size_token = min(50, len(task.target_dictionary))
            dec_args.beam_threshold = min(50, len(task.target_dictionary))
            dec_args.lm_weight = cfg.wer_lm_weight
            dec_args.word_score = cfg.wer_word_score
            dec_args.unk_weight = -math.inf
            dec_args.sil_weight = 0

            self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
        else:
            self.w2l_decoder = None

        self.zero_infinity = cfg.zero_infinity
        #self.sentence_avg = cfg.sentence_avg

        if self.ce_weight > 0 and self.ctc_weight > 0:
            logger.info("Using cross entropy loss and CTC loss for ASR")
        elif self.ce_weight > 0:
            logger.info("Only using CE loss")
        elif self.ctc_weight > 0:
            logger.info("Only using CTC loss for ASR")
        else:
            logger.info("ERROR")

    def forward(self, model, sample, reduce=True):

        if self.ce_weight == 0 and self.ctc_weight > 0:
            sample["only_ctc"] = True

        net_output_decoder, net_output = model(**sample["net_input"])

        if self.ce_weight > 0:
            loss_ce, nll_loss_ce = self.compute_loss(model, net_output_decoder, sample, reduce=reduce)
            #print ("loss_ce: ", loss_ce)
        else:
            nll_loss_ce = None

        if self.ctc_weight > 0:
            loss_ctc, lprobs, input_lengths = self.compute_loss_ctc(model, net_output, sample)

        if self.ce_weight > 0 and self.ctc_weight > 0:
            loss = self.ce_weight * loss_ce + self.ctc_weight * loss_ctc
        elif self.ce_weight > 0:
            loss = loss_ce
        elif self.ctc_weight > 0:
            loss = loss_ctc
        else:
            logger.info("ERROR: must ce_weight > 0 or ctc_weight > 0")

        ntokens = (
            sample["ntokens"] if "ntokens" in sample else sample["target_lengths"].sum().item()
        )

        sample_size = sample["target"].size(0) if self.sentence_avg else ntokens

        logging_output = {
            "loss": loss.item(),
            "ce_loss": loss_ce.item() if self.ce_weight > 0 else 0,
            "ctc_loss": loss_ctc.item() if self.ctc_weight > 0 else 0,
            "nll_loss": nll_loss_ce.item() if nll_loss_ce is not None else 0,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }

        if self.ce_weight > 0 and self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output_decoder, sample)
            logging_output["n_correct"] = utils.item(n_correct.item())
            logging_output["total"] = utils.item(total.data)

        if self.ctc_weight > 0 and not model.training:
            import editdistance

            with torch.no_grad():
                lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()

                c_err = 0
                c_len = 0
                w_errs = 0
                w_len = 0
                wv_errs = 0
                for lp, t, inp_l in zip(
                    lprobs_t,
                    sample["target_label"]
                    if "target_label" in sample
                    else sample["target"],
                    input_lengths,
                ):
                    lp = lp[:inp_l].unsqueeze(0)

                    decoded = None
                    if self.w2l_decoder is not None:
                        decoded = self.w2l_decoder.decode(lp)
                        if len(decoded) < 1:
                            decoded = None
                        else:
                            decoded = decoded[0]
                            if len(decoded) < 1:
                                decoded = None
                            else:
                                decoded = decoded[0]

                    p = (t != self.task.target_dictionary.pad()) & (
                        t != self.task.target_dictionary.eos()
                    )
                    targ = t[p]
                    targ_units = self.task.target_dictionary.string(targ)
                    targ_units_arr = targ.tolist()

                    toks = lp.argmax(dim=-1).unique_consecutive()
                    pred_units_arr = toks[toks != self.blank_idx].tolist()

                    c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                    c_len += len(targ_units_arr)

                    targ_words = post_process(targ_units, self.post_process).split()

                    pred_units = self.task.target_dictionary.string(pred_units_arr)
                    pred_words_raw = post_process(pred_units, self.post_process).split()

                    if decoded is not None and "words" in decoded:
                        pred_words = decoded["words"]
                        w_errs += editdistance.eval(pred_words, targ_words)
                        wv_errs += editdistance.eval(pred_words_raw, targ_words)
                    else:
                        dist = editdistance.eval(pred_words_raw, targ_words)
                        w_errs += dist
                        wv_errs += dist

                    w_len += len(targ_words)

                logging_output["wv_errors"] = wv_errs
                logging_output["w_errors"] = w_errs
                logging_output["w_total"] = w_len
                logging_output["c_errors"] = c_err
                logging_output["c_total"] = c_len

        return loss, sample_size, logging_output

    def compute_loss_ctc(self, model, net_output, sample):
        lprobs = model.get_normalized_probs_for_ctc(
            net_output, log_probs=True
        ).contiguous()  # (T, B, C) from the encoder

        if net_output["encoder_padding_mask"] is not None:
            non_padding_mask = ~net_output["encoder_padding_mask"][0]
            input_lengths = non_padding_mask.long().sum(-1)
        else:
            input_lengths = lprobs.new_full(
                (lprobs.size(1),), lprobs.size(0), dtype=torch.long
            )

        pad_mask = (sample["target"] != self.pad_idx) & (
            sample["target"] != self.eos_idx
        )
        targets_flat = sample["target"].masked_select(pad_mask)
        if "target_lengths" in sample:
            target_lengths = sample["target_lengths"]
        else:
            target_lengths = pad_mask.sum(-1)

        ## processing
        target_lengths = target_lengths - 1

        with torch.backends.cudnn.flags(enabled=False):
            loss_ctc = F.ctc_loss(
                lprobs,
                targets_flat,
                input_lengths,
                target_lengths,
                blank=self.blank_idx,
                reduction="sum",
                zero_infinity=True,
            )

        return loss_ctc, lprobs, input_lengths

    ## for ce
    def get_lprobs_and_target(self, model, net_output, sample):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = model.get_targets(sample, net_output)
        if self.ignore_prefix_size > 0:
            if getattr(lprobs, "batch_first", False):
                lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
                target = target[:, self.ignore_prefix_size :].contiguous()
            else:
                lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
                target = target[self.ignore_prefix_size :, :].contiguous()
        return lprobs.view(-1, lprobs.size(-1)), target.view(-1)

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs,
            target,
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        return loss, nll_loss

    def compute_accuracy(self, model, net_output, sample):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        mask = target.ne(self.padding_idx)
        n_correct = torch.sum(
            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
        )
        total = torch.sum(mask)
        return n_correct, total


    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
        ce_loss_sum = sum(log.get("ce_loss", 0) for log in logging_outputs)
        ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )

        metrics.log_scalar(
            "ctc_loss", ctc_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
        )
        metrics.log_scalar(
            "ce_loss", ce_loss_sum / ntokens, ntokens, 2, round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg, 2)
        )

        total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
        if total > 0:
            metrics.log_scalar("total", total)
            n_correct = utils.item(
                sum(log.get("n_correct", 0) for log in logging_outputs)
            )
            metrics.log_scalar("n_correct", n_correct)
            metrics.log_derived(
                "accuracy",
                lambda meters: round(
                    meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
                )
                if meters["total"].sum > 0
                else float("nan"),
                2
            )

        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log.get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                )
                if meters["_c_total"].sum > 0
                else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(
                    meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(
                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
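To illustrate the `label_smoothed_nll_loss` helper defined in this file, here is a small standalone sketch with dummy tensors; the shapes, vocabulary size, and pad index are assumptions for the example, not values from the repository.

import torch
import torch.nn.functional as F

lprobs = F.log_softmax(torch.randn(2, 4), dim=-1)   # (num_tokens, vocab) log-probabilities
target = torch.tensor([2, 1])                        # assume index 1 is the padding symbol
loss, nll_loss = label_smoothed_nll_loss(
    lprobs, target, epsilon=0.1, ignore_index=1, reduce=True
)
# loss blends the true-label NLL with a uniform smoothing term over the rest of
# the vocabulary; the padded position contributes nothing because of ignore_index.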
artst/criterions/text_pretrain_criterion.py
ADDED
@@ -0,0 +1,142 @@
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import math
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II


@dataclass
class TextPretrainCriterionConfig(FairseqDataclass):
    sentence_avg: bool = II("optimization.sentence_avg")
    loss_weights: Optional[List[float]] = field(
        default_factory=lambda: [0.1,],
        metadata={"help": "weights for additional loss terms (not first one)"},
    )
    bart_weight: float = field(
        default=1.0,
        metadata={"help": "loss weight for cross entropy"},
    )


class TextPretrainCriterion(FairseqCriterion):
    def __init__(self, task, sentence_avg, bart_weight, loss_weights=None):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.loss_weights = loss_weights
        self.bart_weight = bart_weight

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output, codebook_out, encoder_output = model(**sample["net_input"])
        bart_loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )

        loss = self.bart_weight * bart_loss
        logging_output = {
            "loss": loss.item(),
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "bart_loss": bart_loss.item(),
            "sample_size": sample_size,
        }

        if "prob_perplexity" in codebook_out:
            assert hasattr(model, "get_extra_losses")
            extra_losses, names = model.get_extra_losses(codebook_out)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = [names]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            if len(self.loss_weights) > len(extra_losses):
                modified_loss_weight = self.loss_weights[len(extra_losses):]
            else:
                modified_loss_weight = self.loss_weights

            # assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, modified_loss_weight):
                # print(n + str(coef))
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()

        if 'loss_prob_perplexity' in logging_output:
            logging_output['code_perplexity'] = codebook_out['code_perplexity'].item()

        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1)
        loss = F.nll_loss(
            lprobs,
            target,
            ignore_index=self.padding_idx,
            reduction="sum" if reduce else "none",
        )
        return loss, loss

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        bart_loss_sum = sum(log.get("bart_loss", 0) for log in logging_outputs)

        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "bart_loss", bart_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", bart_loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["bart_loss"].avg)
            )

        if "loss_prob_perplexity" in logging_outputs[0].keys():
            val = sum(log["loss_prob_perplexity"] for log in logging_outputs)
            metrics.log_scalar("loss_prob_perplexity", val / sample_size / math.log(2), round=3)
        if "code_perplexity" in logging_outputs[0].keys():
            val = sum(log["code_perplexity"] for log in logging_outputs)
            metrics.log_scalar("code_perplexity", val / len(logging_outputs), round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
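A short sketch of the core computation in `compute_loss` above: a summed NLL over the decoder's log-probabilities, which is what the "bart_loss" term logs. The tensors, vocabulary size, and pad index below are dummy values for illustration only.

import torch
import torch.nn.functional as F

lprobs = F.log_softmax(torch.randn(3, 5), dim=-1)  # (tokens, vocab) log-probabilities
target = torch.tensor([4, 1, 2])                   # assume index 1 is padding
bart_loss = F.nll_loss(lprobs, target, ignore_index=1, reduction="sum")
# reduce_metrics later divides the aggregated sum by sample_size and log(2),
# so the reported "loss" is in bits per token (or per sentence with sentence_avg).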
artst/criterions/text_to_speech_loss.py
ADDED
@@ -0,0 +1,425 @@
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

from dataclasses import dataclass, field

import torch
from fairseq import metrics, utils
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from artst.models.modules.speech_encoder_prenet import SpeechEncoderPrenet
from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss
from omegaconf import II
from typing import Any


@dataclass
class TexttoSpeechLossConfig(FairseqDataclass):
    use_masking: bool = field(
        default=True,
        metadata={"help": "Whether to use masking in calculation of loss"},
    )
    use_weighted_masking: bool = field(
        default=False,
        metadata={"help": "Whether to use weighted masking in calculation of loss"},
    )
    loss_type: str = field(
        default="L1",
        metadata={"help": "How to calc loss"},
    )
    bce_pos_weight: float = field(
        default=5.0,
        metadata={"help": "Positive sample weight in BCE calculation (only for use-masking=True)"},
    )
    bce_loss_lambda: float = field(
        default=1.0,
        metadata={"help": "Lambda in bce loss"},
    )
    use_guided_attn_loss: bool = field(
        default=False,
        metadata={"help": "Whether to use guided attention loss"},
    )
    guided_attn_loss_sigma: float = field(
        default=0.4,
        metadata={"help": "Sigma in guided attention loss"},
    )
    guided_attn_loss_lambda: float = field(
        default=10.0,
        metadata={"help": "Lambda in guided attention loss"},
    )
    num_layers_applied_guided_attn: int = field(
        default=2,
        metadata={"help": "Number of layers to be applied guided attention loss, if set -1, all of the layers will be applied."},
    )
    num_heads_applied_guided_attn: int = field(
        default=2,
        metadata={"help": "Number of heads in each layer to be applied guided attention loss, if set -1, all of the heads will be applied."},
    )
    modules_applied_guided_attn: Any = field(
        default=("encoder-decoder",),
        metadata={"help": "Module name list to be applied guided attention loss"},
    )
    sentence_avg: bool = II("optimization.sentence_avg")


class TexttoSpeechLoss(FairseqCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        use_masking=True,
        use_weighted_masking=False,
        loss_type="L1",
        bce_pos_weight=5.0,
        bce_loss_lambda=1.0,
        use_guided_attn_loss=False,
        guided_attn_loss_sigma=0.4,
        guided_attn_loss_lambda=1.0,
        num_layers_applied_guided_attn=2,
        num_heads_applied_guided_attn=2,
        modules_applied_guided_attn=["encoder-decoder"],
    ):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking
        self.loss_type = loss_type
        self.bce_pos_weight = bce_pos_weight
        self.bce_loss_lambda = bce_loss_lambda
        self.use_guided_attn_loss = use_guided_attn_loss
        self.guided_attn_loss_sigma = guided_attn_loss_sigma
        self.guided_attn_loss_lambda = guided_attn_loss_lambda
        # define loss function
        self.criterion = Tacotron2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
            bce_pos_weight=bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.num_layers_applied_guided_attn = num_layers_applied_guided_attn
            self.num_heads_applied_guided_attn = num_heads_applied_guided_attn
            self.modules_applied_guided_attn = modules_applied_guided_attn
        if self.use_guided_attn_loss:
            self.attn_criterion = GuidedMultiHeadAttentionLoss(
                sigma=guided_attn_loss_sigma,
                alpha=guided_attn_loss_lambda,
            )

    def forward(self, model, sample):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss = self.compute_loss(model, net_output, sample)
        # sample_size = (
        #     sample["target"].size(0) if self.sentence_avg else sample["nframes"]
        # )
        sample_size = 1
        logging_output = {
            "loss": loss.item(),
            "l1_loss": l1_loss.item(),
            "l2_loss": l2_loss.item(),
            "bce_loss": bce_loss.item(),
            "sample_size": 1,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
        }

        if enc_dec_attn_loss is not None:
            logging_output['enc_dec_attn_loss'] = enc_dec_attn_loss.item()

        if hasattr(model, 'text_encoder_prenet'):
            logging_output["encoder_alpha"] = model.text_encoder_prenet.encoder_prenet[-1].alpha.item()
            logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item()
        elif hasattr(model, "speech_encoder_prenet"):
            logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item()
        else:
            if 'task' not in sample:
                logging_output["encoder_alpha"] = model.encoder_prenet.encoder_prenet[-1].alpha.item()
                logging_output["decoder_alpha"] = model.decoder_prenet.decoder_prenet[-1].alpha.item()

        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample):
        before_outs, after_outs, logits, attn = net_output
        labels = sample["labels"]
        ys = sample["dec_target"]
        olens = sample["dec_target_lengths"]
        ilens = sample["src_lengths"]

        # modify mod part of ground truth
        if model.reduction_factor > 1:
            olens_in = olens.new([torch.div(olen, model.reduction_factor, rounding_mode='floor') for olen in olens])
            olens = olens.new([olen - olen % model.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]
            labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0)  # make sure at least one frame has 1
            # labels[:, -1] = 1.0
        else:
            olens_in = olens

        # calculate loss values
        l1_loss, l2_loss, bce_loss = self.criterion(
            after_outs, before_outs, logits, ys, labels, olens
        )

        # l1_loss = l1_loss / ys.size(2)
        # l2_loss = l2_loss / ys.size(2)

        if self.loss_type == "L1":
            loss = l1_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l1_loss
        elif self.loss_type == "L2":
            loss = l2_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l2_loss
        elif self.loss_type == "L1+L2":
            loss = l1_loss + l2_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l1_loss + l2_loss
        else:
            raise ValueError("unknown --loss-type " + self.loss_type)

        # calculate guided attention loss
        enc_dec_attn_loss = None
        if self.use_guided_attn_loss:
            # calculate the input lengths of encoder, which is determined by encoder prenet
            if hasattr(model, 'encoder_reduction_factor') and model.encoder_reduction_factor > 1:
                ilens_in = ilens.new([ilen // model.encoder_reduction_factor for ilen in ilens])
            else:
                ilens_in = ilens
            # work for speech to speech model's input
            if "task_name" in sample and sample["task_name"] == "s2s":
                m = None
                if hasattr(model, 'encoder_prenet'):
                    m = model.encoder_prenet
                elif hasattr(model, 'speech_encoder_prenet'):
                    m = model.speech_encoder_prenet
                if m is not None and isinstance(m, SpeechEncoderPrenet):
                    ilens_in = m.get_src_lengths(ilens_in)
            # calculate for encoder-decoder
            if "encoder-decoder" in self.modules_applied_guided_attn:
                attn = [att_l[:, : self.num_heads_applied_guided_attn] for att_l in attn]
                att_ws = torch.cat(attn, dim=1)  # (B, H*L, T_out, T_in)
                enc_dec_attn_loss = self.attn_criterion(att_ws, ilens_in, olens_in)
                loss = loss + enc_dec_attn_loss

        return loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        l1_loss_sum = sum(log.get("l1_loss", 0) for log in logging_outputs)
        l2_loss_sum = sum(log.get("l2_loss", 0) for log in logging_outputs)
        bce_loss_sum = sum(log.get("bce_loss", 0) for log in logging_outputs)
        sample_size = max(1, sum(log.get("sample_size", 0) for log in logging_outputs))
        metrics.log_scalar(
            "loss", loss_sum / sample_size, sample_size, 1, round=5
        )
        encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in logging_outputs)
        decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in logging_outputs)
        ngpu = sum(log.get("ngpu", 0) for log in logging_outputs)

        metrics.log_scalar(
            "l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "encoder_alpha", encoder_alpha_sum / sample_size, sample_size, round=5
        )
        metrics.log_scalar(
            "decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
        )

        if "enc_dec_attn_loss" in logging_outputs[0]:
            enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in logging_outputs)
            metrics.log_scalar(
                "enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
            )


    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True

class Tacotron2Loss(torch.nn.Module):
    """Loss function module for Tacotron2."""

    def __init__(
        self, use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0
    ):
        """Initialize Tacotron2 loss module.

        Args:
            use_masking (bool): Whether to apply masking
                for padded part in loss calculation.
            use_weighted_masking (bool):
                Whether to apply weighted masking in loss calculation.
            bce_pos_weight (float): Weight of positive sample of stop token.

        """
        super(Tacotron2Loss, self).__init__()
        assert (use_masking != use_weighted_masking) or not use_masking
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # define criterions
        # reduction = "none" if self.use_weighted_masking else "sum"
        reduction = "none" if self.use_weighted_masking else "mean"
        self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
        self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
        self.bce_criterion = torch.nn.BCEWithLogitsLoss(
            reduction=reduction, pos_weight=torch.tensor(bce_pos_weight)
        )

        # NOTE(kan-bayashi): register pre hook function for the compatibility
        self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)

    def forward(self, after_outs, before_outs, logits, ys, labels, olens):
        """Calculate forward propagation.

        Args:
            after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
            before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
            logits (Tensor): Batch of stop logits (B, Lmax).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).

        Returns:
            Tensor: L1 loss value.
            Tensor: Mean square error loss value.
            Tensor: Binary cross entropy loss value.

        """
        # make mask and apply it
        if self.use_masking:
            masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            ys = ys.masked_select(masks)
            after_outs = after_outs.masked_select(masks)
            before_outs = before_outs.masked_select(masks)
            labels = labels.masked_select(masks[:, :, 0])
            logits = logits.masked_select(masks[:, :, 0])

        # calculate loss
        l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(before_outs, ys)
        mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(
            before_outs, ys
        )
        bce_loss = self.bce_criterion(logits, labels)

        # make weighted mask and apply it
        if self.use_weighted_masking:
            masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            weights = masks.float() / masks.sum(dim=1, keepdim=True).float()
            out_weights = weights.div(ys.size(0) * ys.size(2))
            logit_weights = weights.div(ys.size(0))

            # apply weight
            l1_loss = l1_loss.mul(out_weights).masked_select(masks).sum()
            mse_loss = mse_loss.mul(out_weights).masked_select(masks).sum()
            bce_loss = (
                bce_loss.mul(logit_weights.squeeze(-1))
                .masked_select(masks.squeeze(-1))
                .sum()
            )

        return l1_loss, mse_loss, bce_loss

    def _load_state_dict_pre_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Apply pre hook function before loading state dict.

        From v.0.6.1 the `bce_criterion.pos_weight` param is registered as a parameter, but
        old models do not include it and as a result a missing key error is raised when
        loading old model parameters. This function solves the issue by adding the param to
        the state dict before loading, as a pre hook function
        of the `load_state_dict` method.

        """
        key = prefix + "bce_criterion.pos_weight"
        if key not in state_dict:
            state_dict[key] = self.bce_criterion.pos_weight

class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
    """Guided attention loss function module for multi head attention.
    Args:
        sigma (float, optional): Standard deviation to control
            how close attention to a diagonal.
        alpha (float, optional): Scaling coefficient (lambda).
        reset_always (bool, optional): Whether to always reset masks.
    """

    def forward(self, att_ws, ilens, olens):
        """Calculate forward propagation.
        Args:
            att_ws (Tensor):
                Batch of multi head attention weights (B, H, T_max_out, T_max_in).
            ilens (LongTensor): Batch of input lengths (B,).
            olens (LongTensor): Batch of output lengths (B,).
        Returns:
            Tensor: Guided attention loss value.
        """
        if self.guided_attn_masks is None:
            self.guided_attn_masks = (
                self._make_guided_attention_masks(ilens, olens)
                .to(att_ws.device)
                .unsqueeze(1)
            )
        if self.masks is None:
            self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
        losses = self.guided_attn_masks * att_ws
        loss = torch.mean(losses.masked_select(self.masks))
        if self.reset_always:
            self._reset_masks()

        return self.alpha * loss

    def _make_guided_attention_masks(self, ilens, olens):
        n_batches = len(ilens)
        max_ilen = max(ilens)
        max_olen = max(olens)
        guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen), device=olens.device)
        for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
            guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(
                ilen, olen, self.sigma
            )
        return guided_attn_masks

    @staticmethod
    def _make_guided_attention_mask(ilen, olen, sigma):
        grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device), torch.arange(ilen, device=olen.device))
        grid_x, grid_y = grid_x.float(), grid_y.float()
        return 1.0 - torch.exp(
            -((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2))
        )

    @staticmethod
    def _make_masks(ilens, olens):
        in_masks = make_non_pad_mask(ilens).to(ilens.device)  # (B, T_in)
        out_masks = make_non_pad_mask(olens).to(olens.device)  # (B, T_out)
        return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)  # (B, T_out, T_in)
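A small sketch of the diagonal prior built by `_make_guided_attention_mask` above, using toy lengths; the values of ilen, olen, and sigma below are illustrative only, not taken from the training config.

import torch

ilen, olen, sigma = 3, 4, 0.4
grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
mask = 1.0 - torch.exp(-((grid_y.float() / ilen - grid_x.float() / olen) ** 2) / (2 * sigma**2))
# mask is close to 0 near the diagonal (where attention is expected to land) and
# grows toward 1 off-diagonal; multiplying it with the attention weights therefore
# penalizes attention that is far from monotonic alignment.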
artst/data/__init__.py
ADDED
File without changes
artst/data/multitask_dataset.py
ADDED
@@ -0,0 +1,263 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import bisect
|
9 |
+
|
10 |
+
import logging
|
11 |
+
import numpy as np
|
12 |
+
from torch.utils.data.dataloader import default_collate
|
13 |
+
from fairseq.data import data_utils
|
14 |
+
|
15 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
class MultitaskDataset(FairseqDataset):
|
20 |
+
@staticmethod
|
21 |
+
def cumsum(sequence):
|
22 |
+
r, s = [], 0
|
23 |
+
for e in sequence:
|
24 |
+
curr_len = len(e)
|
25 |
+
r.append(curr_len + s)
|
26 |
+
s += curr_len
|
27 |
+
return r
|
28 |
+
|
29 |
+
def __init__(self, datasets, sample_ratios=1, batch_ratio=None):
|
30 |
+
super(MultitaskDataset, self).__init__()
|
31 |
+
assert len(datasets) > 0, "datasets should not be an empty iterable"
|
32 |
+
self.datasets = list(datasets)
|
33 |
+
if isinstance(sample_ratios, int):
|
34 |
+
sample_ratios = [sample_ratios] * len(self.datasets)
|
35 |
+
if batch_ratio is not None:
|
36 |
+
logger.info('batch ratio is ' + str(batch_ratio))
|
37 |
+
self.batch_ratio = batch_ratio
|
38 |
+
else:
|
39 |
+
self.batch_ratio = None
|
40 |
+
else:
|
41 |
+
logger.info('set sample ratio to ' + str(sample_ratios))
|
42 |
+
if batch_ratio is not None:
|
43 |
+
logger.info('batch ratio is ' + str(batch_ratio))
|
44 |
+
self.batch_ratio = batch_ratio
|
45 |
+
else:
|
46 |
+
self.batch_ratio = None
|
47 |
+
self.sample_ratios = sample_ratios
|
48 |
+
self._ordered_indices = None
|
49 |
+
self._update_size()
|
50 |
+
|
51 |
+
def __len__(self):
|
52 |
+
return self.cumulative_sizes[-1]
|
53 |
+
|
54 |
+
def __getitem__(self, idx):
|
55 |
+
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
56 |
+
sample = self.datasets[dataset_idx][sample_idx]
|
57 |
+
if isinstance(sample, dict):
|
58 |
+
sample["dataset_idx"] = dataset_idx
|
59 |
+
else:
|
60 |
+
sample = sample + (dataset_idx,)
|
61 |
+
return sample
|
62 |
+
|
63 |
+
def _update_size(self):
|
64 |
+
self.cumulative_sizes = self.cumsum(self.datasets)
|
65 |
+
self.real_sizes = [len(d) for d in self.datasets]
|
66 |
+
|
67 |
+
def _get_dataset_and_sample_index(self, idx: int):
|
68 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
69 |
+
if dataset_idx == 0:
|
70 |
+
sample_idx = idx
|
71 |
+
else:
|
72 |
+
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
73 |
+
sample_idx = sample_idx % self.real_sizes[dataset_idx]
|
74 |
+
return dataset_idx, sample_idx
|
75 |
+
|
76 |
+
def collater(self, samples, **extra_args):
|
77 |
+
# For now only supports datasets with same underlying collater implementations
|
78 |
+
if samples is not None and len(samples) > 0:
|
79 |
+
if isinstance(samples[0], dict):
|
80 |
+
dataset_idx = samples[0]["dataset_idx"]
|
81 |
+
else:
|
82 |
+
dataset_idx = samples[0][-1]
|
83 |
+
samples = [sample[:-1] for sample in samples]
|
84 |
+
else:
|
85 |
+
dataset_idx = 0
|
86 |
+
|
87 |
+
if hasattr(self.datasets[dataset_idx], "collater"):
|
88 |
+
return self.datasets[dataset_idx].collater(samples, **extra_args)
|
89 |
+
else:
|
90 |
+
return default_collate(samples, **extra_args)
|
91 |
+
|
92 |
+
def size(self, idx: int):
|
93 |
+
"""
|
94 |
+
Return an example's size as a float or tuple.
|
95 |
+
"""
|
96 |
+
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
97 |
+
return self.datasets[dataset_idx].size(sample_idx)
|
98 |
+
|
99 |
+
def num_tokens(self, index: int):
|
100 |
+
return np.max(self.size(index))
|
101 |
+
|
102 |
+
def attr(self, attr: str, index: int):
|
103 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
|
104 |
+
return getattr(self.datasets[dataset_idx], attr, None)
|
105 |
+
|
106 |
+
@property
|
107 |
+
def sizes(self):
|
108 |
+
_dataset_sizes = []
|
109 |
+
        for ds in self.datasets:
            if isinstance(ds.sizes, np.ndarray):
                _dataset_sizes.append(ds.sizes)
            else:
                # Only support underlying dataset with single size array.
                assert isinstance(ds.sizes, list)
                _dataset_sizes.append(ds.sizes[0])
        return np.concatenate(_dataset_sizes)

    @property
    def supports_prefetch(self):
        return all(d.supports_prefetch for d in self.datasets)

    def ordered_indices(self):
        # ordered_indices = []
        # for i, dataset in enumerate(self.datasets):
        #     indice = dataset.ordered_indices()
        #     ordered_indices.append(indice)
        if self._ordered_indices is None:
            # Call the underlying dataset's ordered_indices() here, so that we
            # get the same random ordering as we would have from using the
            # underlying sub-datasets directly.
            self._ordered_indices = [
                dataset.ordered_indices()
                for dataset in self.datasets
            ]
        return np.arange(len(self))

    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        if not hasattr(self, "max_tokens"):
            self.max_tokens = max_tokens
        if not hasattr(self, "max_sentences"):
            self.max_sentences = max_sentences
        if not hasattr(self, "required_batch_size_multiple"):
            self.required_batch_size_multiple = required_batch_size_multiple
        batch_samplers = []
        for i, dataset in enumerate(self.datasets):
            batch_sampler = dataset.batch_by_size(
                self._ordered_indices[i],
                max_tokens=max_tokens if self.batch_ratio is None else max_tokens * self.batch_ratio[i],
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
            if i > 0:
                for batch in batch_sampler:
                    batch += self.cumulative_sizes[i - 1]
            if self.sample_ratios[i] != 1.0:
                batch_sampler = np.array(batch_sampler)
                batch_sampler = np.random.choice(batch_sampler, int(len(batch_sampler) * self.sample_ratios[i]))
                batch_sampler = list(batch_sampler)
                logger.info('Adjust batch by ratio ' + str(self.sample_ratios[i]) + ' and the number of batch is ' + str(int(len(batch_sampler))) + ' for dataset ' + str(i))
            batch_samplers.extend(batch_sampler)
        return batch_samplers

    def filter_indices_by_size(self, indices, max_positions):
        """
        Filter each sub-dataset independently, then update the round robin to work
        on the filtered sub-datasets.
        """
        if not hasattr(self, "max_positions"):
            self.max_positions = max_positions
        ignored_some = False
        for i in range(len(self.datasets)):
            # ignored = []
            self._ordered_indices[i], ignored = self.datasets[i].filter_indices_by_size(
                self._ordered_indices[i], self.max_positions[i]
            )
            if len(ignored) > 0:
                ignored_some = True
                logger.warning(
                    f"{len(ignored)} samples from {i} have invalid sizes and will be skipped, "
                    f"max_positions={self.max_positions[i]}, first few sample ids={ignored[:10]}"
                )

        logger.info('update dataset size')
        self._update_size()

        # Since we are modifying in place the _ordered_indices,
        # it's not possible anymore to return valid ignored indices.
        # Hopefully the extra debug information print above should be enough to debug.
        # Ideally we would receive ignore_invalid_inputs so that we could have
        # a proper error message.
        return (np.arange(len(self)), [0] if ignored_some else [])

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)

    def shuffle_batches(self, batches, seed):
        logger.info("shuffle batches")
        new_batches_fromlist = []
        new_batches_notlist = []
        new_batches = []
        with data_utils.numpy_seed(seed):
            np.random.shuffle(batches)
            for batch in batches:
                if isinstance(batch, list):
                    # np.random.shuffle(batch)
                    new_batches_fromlist.append(batch)
                else:
                    new_batches_notlist.append(batch)
            logger.info("Get " + str(len(new_batches_fromlist)) + " chunk from speech sides")
            logger.info("Get " + str(sum([len(batch_list) for batch_list in new_batches_fromlist])) + " batches from speech sides")
            logger.info("Get " + str(len(new_batches_notlist)) + " batches from text sides")
            if len(new_batches_fromlist) == 0:
                return new_batches_notlist
            st_ratio = int(len(new_batches_notlist) / len(new_batches_fromlist))
            logger.info("Get st_ratio " + str(st_ratio))
            last_idx = 0
            for i in range(len(new_batches_fromlist)):
                if i == len(new_batches_fromlist) - 1:
                    new_batches_fromlist[i].extend(new_batches_notlist[last_idx:])
                else:
                    new_batches_fromlist[i].extend(new_batches_notlist[last_idx : last_idx + st_ratio])
                np.random.shuffle(new_batches_fromlist[i])
                new_batches.extend(new_batches_fromlist[i])
                last_idx = last_idx + st_ratio
        logger.info("Finish shuffle")
        return new_batches

    def reset_batch_sampler(self):
        logger.info("reset batch sampler")
        self._ordered_indices = [
            self.datasets[i].ordered_indices()
            for i in range(len(self.datasets))
        ]
        self.filter_indices_by_size(None, None)

        batch_samplers = self.batch_by_size(
            None,
            self.max_tokens,
            self.max_sentences,
            self.required_batch_size_multiple
        )
        return batch_samplers
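
Note (not part of the commit): the shuffle_batches method above interleaves a fixed number of text-side batches (st_ratio) after each chunk of speech-side batches. A minimal standalone sketch of that mixing, using plain Python lists instead of fairseq batch samplers; the names here are illustrative only.

import random

def mix_batches(speech_chunks, text_batches, seed=1):
    # speech_chunks: list of lists of batch ids; text_batches: flat list of batch ids
    random.seed(seed)
    if not speech_chunks:
        return list(text_batches)
    st_ratio = len(text_batches) // len(speech_chunks)
    mixed, last = [], 0
    for i, chunk in enumerate(speech_chunks):
        take = text_batches[last:] if i == len(speech_chunks) - 1 else text_batches[last:last + st_ratio]
        merged = chunk + list(take)
        random.shuffle(merged)  # shuffle within each chunk, as shuffle_batches does
        mixed.extend(merged)
        last += st_ratio
    return mixed

print(mix_batches([[0, 1], [2, 3]], [10, 11, 12, 13, 14]))
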
artst/data/speech_dataset.py
ADDED
@@ -0,0 +1,475 @@
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST

# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import itertools
import logging
import os
import sys
from typing import Any, List, Optional, Union

import numpy as np

import torch
import torch.nn.functional as F
import librosa
from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform
from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset

logger = logging.getLogger(__name__)

def _collate_frames(
    frames: List[torch.Tensor], is_audio_input: bool = False
):
    """
    Convert a list of 2D frames into a padded 3D tensor
    Args:
        frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
            length of i-th frame and f_dim is static dimension of features
    Returns:
        3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
    """
    max_len = max(frame.size(0) for frame in frames)
    if is_audio_input:
        out = frames[0].new_zeros((len(frames), max_len))
    else:
        out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
    for i, v in enumerate(frames):
        out[i, : v.size(0)] = v
    return out

def add_first_frame_and_remove_last_frame(ys):
    ys_in = torch.cat(
        [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1
    )
    return ys_in

def load_audio(manifest_path, max_keep, min_keep):
    n_long, n_short = 0, 0
    names, inds, sizes, spk_embeds = [], [], [], []
    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            items = line.strip().split("\t")
            assert len(items) == 3, line
            sz = int(items[1])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                names.append(items[0])
                spk_embeds.append(items[2])
                inds.append(ind)
                sizes.append(sz)
    tot = ind + 1
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
            f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
        )
    )
    return root, names, inds, tot, sizes, spk_embeds


def load_label(label_path, inds, tot):
    with open(label_path) as f:
        labels = [line.rstrip() for line in f]
        assert (
            len(labels) == tot
        ), f"number of labels does not match ({len(labels)} != {tot})"
        labels = [labels[i] for i in inds]
    return labels


def load_label_offset(label_path, inds, tot):
    with open(label_path) as f:
        code_lengths = [len(line.encode("utf-8")) for line in f]
        assert (
            len(code_lengths) == tot
        ), f"number of labels does not match ({len(code_lengths)} != {tot})"
        offsets = list(itertools.accumulate([0] + code_lengths))
        offsets = [(offsets[i], offsets[i + 1]) for i in inds]
    return offsets


def verify_label_lengths(
    audio_sizes,
    audio_rate,
    label_path,
    label_rate,
    inds,
    tot,
    tol=0.1,  # tolerance in seconds
):
    if label_rate < 0:
        logger.info(f"{label_path} is sequence label. skipped")
        return

    with open(label_path) as f:
        lengths = [len(line.rstrip().split()) for line in f]
        assert len(lengths) == tot
        lengths = [lengths[i] for i in inds]
    num_invalid = 0
    for i, ind in enumerate(inds):
        dur_from_audio = audio_sizes[i] / audio_rate
        dur_from_label = lengths[i] / label_rate
        if abs(dur_from_audio - dur_from_label) > tol:
            logger.warning(
                (
                    f"audio and label duration differ too much "
                    f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
                    f"in line {ind+1} of {label_path}. Check if `label_rate` "
                    f"is correctly set (currently {label_rate}). "
                    f"num. of samples = {audio_sizes[i]}; "
                    f"label length = {lengths[i]}"
                )
            )
            num_invalid += 1
    if num_invalid > 0:
        logger.warning(
            f"total {num_invalid} (audio, label) pairs with mismatched lengths"
        )


def logmelfilterbank(
    audio,
    sampling_rate,
    fft_size=1024,
    hop_size=256,
    win_length=None,
    window="hann",
    num_mels=80,
    fmin=80,
    fmax=7600,
    eps=1e-10,
):
    """Compute log-Mel filterbank feature.
    (https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py)

    Args:
        audio (ndarray): Audio signal (T,).
        sampling_rate (int): Sampling rate.
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length. If set to None, it will be the same as fft_size.
        window (str): Window function type.
        num_mels (int): Number of mel basis.
        fmin (int): Minimum frequency in mel basis calculation.
        fmax (int): Maximum frequency in mel basis calculation.
        eps (float): Epsilon value to avoid inf in log calculation.

    Returns:
        ndarray: Log Mel filterbank feature (#frames, num_mels).

    """
    # get amplitude spectrogram
    x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
                          win_length=win_length, window=window, pad_mode="reflect")
    spc = np.abs(x_stft).T  # (#frames, #bins)

    # get mel basis
    fmin = 0 if fmin is None else fmin
    fmax = sampling_rate / 2 if fmax is None else fmax
    mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)

    return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))


class SpeechPretrainDataset(FairseqDataset):
    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        random_crop: bool = False,
        single_target: bool = False,
        reduction_factor: int = 1,
    ):
        self.audio_root, self.audio_names, inds, tot, self.sizes, self.spk_embeds = load_audio(
            manifest_path, max_keep_sample_size, min_keep_sample_size
        )
        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.random_crop = random_crop

        self.num_labels = len(label_paths)
        self.pad_list = pad_list
        self.eos_list = eos_list
        self.label_processors = label_processors
        self.single_target = single_target
        self.label_rates = (
            [label_rates for _ in range(len(label_paths))]
            if isinstance(label_rates, float)
            else label_rates
        )
        self.store_labels = store_labels
        if store_labels:
            self.label_list = [load_label(p, inds, tot) for p in label_paths]
        else:
            self.label_paths = label_paths
            self.label_offsets_list = [
                load_label_offset(p, inds, tot) for p in label_paths
            ]
        assert label_processors is None or len(label_processors) == self.num_labels
        for label_path, label_rate in zip(label_paths, self.label_rates):
            verify_label_lengths(
                self.sizes, sample_rate, label_path, label_rate, inds, tot
            )

        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.pad_audio = pad_audio
        self.normalize = normalize
        self.reduction_factor = reduction_factor
        logger.info(
            f"pad_audio={pad_audio}, random_crop={random_crop}, reduction_factor={reduction_factor}, "
            f"normalize={normalize}, max_sample_size={self.max_sample_size}"
        )

    def get_audio(self, index):
        import soundfile as sf

        wav_path = os.path.join(self.audio_root, self.audio_names[index])
        wav, cur_sample_rate = sf.read(wav_path)
        wav = torch.from_numpy(wav).float()
        fbank = logmelfilterbank(
            wav.view(-1).cpu().numpy(), 16000
        )
        fbank = torch.from_numpy(fbank).float()
        wav = self.postprocess(wav, cur_sample_rate)
        return wav, fbank

    def get_label(self, index, label_idx):
        if self.store_labels:
            label = self.label_list[label_idx][index]
        else:
            with open(self.label_paths[label_idx]) as f:
                offset_s, offset_e = self.label_offsets_list[label_idx][index]
                f.seek(offset_s)
                label = f.read(offset_e - offset_s)

        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        return label

    def get_labels(self, index):
        return [self.get_label(index, i) for i in range(self.num_labels)]

    def __getitem__(self, index):
        wav, fbank = self.get_audio(index)
        labels = self.get_labels(index)
        spkembs = get_features_or_waveform(
            os.path.join(self.audio_root, self.spk_embeds[index])
        )
        spkembs = torch.from_numpy(spkembs).float()
        return {"id": index, "source": wav, "target": fbank, "label_list": labels, 'spkembs': spkembs}

    def __len__(self):
        return len(self.sizes)

    def crop_to_max_size(self, wav, target_size):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav, 0

        start, end = 0, target_size
        if self.random_crop:
            start = np.random.randint(0, diff + 1)
            end = size - diff + start
        return wav[start:end], start

    def collater(self, samples):
        # target = max(sizes) -> random_crop not used
        # target = max_sample_size -> random_crop used for long
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]

        fbanks = [s["target"] for s in samples]
        fbank_sizes = [len(s) for s in fbanks]

        if self.pad_audio:
            audio_size = min(max(audio_sizes), self.max_sample_size)
        else:
            audio_size = min(min(audio_sizes), self.max_sample_size)
        collated_audios, padding_mask, audio_starts = self.collater_audio(
            audios, audio_size
        )

        collated_fbanks = []
        collated_audios_size = []
        for i in range(len(fbanks)):
            fbank_start = int(audio_starts[i] / (audio_sizes[i] / fbank_sizes[i]))
            fbank_size = int(audio_size / (audio_sizes[i] / fbank_sizes[i]))
            fbank_end = min(fbank_start + fbank_size, fbank_sizes[i])
            collated_fbanks.append(fbanks[i][fbank_start : fbank_end])
            collated_audios_size.append(audio_size)
        collated_fbanks_size = [len(s) for s in collated_fbanks]
        collated_fbanks = _collate_frames(collated_fbanks)
        collated_fbanks_size = torch.tensor(collated_fbanks_size, dtype=torch.long)

        # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
        if self.reduction_factor > 1:
            collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor]
            collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size])
        else:
            collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size

        prev_output_tokens = torch.cat(
            [collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1
        )

        # make labels for stop prediction
        labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1))
        for i, l in enumerate(fbank_sizes):
            labels[i, l - 1 :] = 1.0

        spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True)

        targets_by_label = [
            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label, audio_size, audio_starts
        )

        net_input = {
            "source": collated_audios,
            "padding_mask": padding_mask,
            "prev_output_tokens": prev_output_tokens,
            "spkembs": spkembs,
            "tgt_lengths": collated_fbanks_size_in,
        }

        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
            "labels": labels,
            "dec_target": collated_fbanks,
            "dec_target_lengths": collated_fbanks_size,
            "src_lengths": collated_audios_size,
            "task_name": 'speech_pretrain',
        }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list
        return batch

    def collater_audio(self, audios, audio_size):
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = (
            torch.BoolTensor(collated_audios.shape).fill_(False)
            # if self.pad_audio else None
        )
        audio_starts = [0 for _ in audios]
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                assert self.pad_audio
                collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
                padding_mask[i, diff:] = True
            else:
                collated_audios[i], audio_starts[i] = self.crop_to_max_size(
                    audio, audio_size
                )
        return collated_audios, padding_mask, audio_starts

    def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
        assert label_rate > 0
        s2f = label_rate / self.sample_rate
        frm_starts = [int(round(s * s2f)) for s in audio_starts]
        frm_size = int(round(audio_size * s2f))
        if not self.pad_audio:
            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
            frm_size = min(frm_size, *rem_size)
        targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
        logger.debug(f"audio_starts={audio_starts}")
        logger.debug(f"frame_starts={frm_starts}")
        logger.debug(f"frame_size={frm_size}")

        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_seq_label(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label, audio_size, audio_starts):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, label_rate, pad in itr:
            if label_rate == -1.0:
                targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            else:
                targets, lengths, ntokens = self.collater_frm_label(
                    targets, audio_size, audio_starts, label_rate, pad
                )
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        if self.pad_audio:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        order.append(self.sizes)
        return np.lexsort(order)[::-1]

    def postprocess(self, wav, cur_sample_rate):
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav
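
Note (not part of the commit): a rough sketch of calling logmelfilterbank the way get_audio does, assuming a 16 kHz mono waveform; the random array stands in for real audio read with soundfile.

import numpy as np

# Stand-in for a 1-second, 16 kHz mono waveform.
wav = np.random.randn(16000).astype(np.float32)

# With the defaults above (fft_size=1024, hop_size=256, num_mels=80),
# the result has shape (#frames, 80) with #frames roughly len(wav) / hop_size.
fbank = logmelfilterbank(wav, 16000)
print(fbank.shape)  # e.g. (63, 80)
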
artst/data/speech_to_class_dataset.py
ADDED
@@ -0,0 +1,260 @@
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import logging
import os
from typing import Any, List, Optional

import numpy as np

import torch
import torch.nn.functional as F
from fairseq.data import data_utils, Dictionary
from fairseq.data.fairseq_dataset import FairseqDataset

logger = logging.getLogger(__name__)


def load_audio(manifest_path, max_keep, min_keep):
    """manifest tsv: wav_path, wav_nframe, wav_class

    Args
        manifest_path: str
        max_keep: int
        min_keep: int

    Return
        root, names, inds, tot, sizes, classes
    """
    n_long, n_short = 0, 0
    names, inds, sizes, classes = [], [], [], []
    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            items = line.strip().split("\t")
            assert len(items) >= 2, line
            sz = int(items[1])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                names.append(items[0])
                if len(items) > 2:
                    classes.append(items[2])
                inds.append(ind)
                sizes.append(sz)
    tot = ind + 1
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
            f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
        )
    )
    if len(classes) == 0:
        logger.warn("no classes loaded only if inference")
    return root, names, inds, tot, sizes, classes


def sample_from_feature(x: np.ndarray, max_segment_length: int = 300):
    """Load a segment within 300-400/51200-76800 frames or the corresponding samples from a utterance.

    Args:
        x (np.ndarray): feature or waveform (frames[, features]), e.g., log mel filter bank or waveform
        max_segment_length (int, optional): maximum segment length. Defaults to 400.

    Returns:
        np.ndarray: segmented features
    """
    if len(x) <= max_segment_length:
        return x
    start = np.random.randint(0, x.shape[0] - max_segment_length)
    return x[start: start + max_segment_length]


class SpeechToClassDataset(FairseqDataset):
    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        shuffle: bool = True,
        normalize: bool = False,
        tgt_dict: Optional[Dictionary] = None,
        max_length: Optional[int] = None
    ):
        self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.wav_classes = load_audio(
            manifest_path, max_keep_sample_size, min_keep_sample_size
        )
        self.sample_rate = sample_rate
        self.shuffle = shuffle

        self.label_processors = label_processors

        self.normalize = normalize
        self.tgt_dict = tgt_dict
        self.max_length = max_length
        logger.info(
            f"max_length={max_length}, normalize={normalize}"
        )

    def get_audio(self, index):
        import soundfile as sf

        wav_path = os.path.join(self.audio_root, self.audio_names[index])
        wav, cur_sample_rate = sf.read(wav_path)
        if self.max_length is not None:
            wav = sample_from_feature(wav, self.max_length)
        wav = torch.from_numpy(wav).float()
        wav = self.postprocess(wav, cur_sample_rate)
        return wav

    def get_label(self, index):
        label = self.wav_classes[index]

        if self.label_processors is not None:
            label = self.label_processors(label)
        return label

    def __getitem__(self, index):
        wav = self.get_audio(index)
        label = None
        if len(self.wav_classes) == len(self.audio_names):
            label = self.get_label(index)
        return {"id": index, "source": wav, "label": label}

    def __len__(self):
        return len(self.wav_sizes)

    def collater(self, samples):
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]

        audio_size = max(audio_sizes)
        collated_audios, padding_mask = self.collater_audio(
            audios, audio_size
        )

        decoder_label = None
        decoder_target = None
        decoder_target_lengths = None
        if samples[0]["label"] is not None:
            targets_by_label = [
                [s["label"] for s in samples]
            ]
            targets_list, lengths_list, ntokens_list = self.collater_label(targets_by_label)

            decoder_label = [
                (targets_list[0][i, :lengths_list[0][i]]).long()
                for i in range(targets_list[0].size(0))
            ]

            decoder_target = data_utils.collate_tokens(
                decoder_label,
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            decoder_target_lengths = torch.tensor(
                [x.size(0) for x in decoder_label], dtype=torch.long
            )
        prev_output_tokens = data_utils.collate_tokens(
            [torch.LongTensor([-1]) for _ in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=True,
        )

        net_input = {
            "source": collated_audios,
            "padding_mask": padding_mask,
            "prev_output_tokens": prev_output_tokens,
            "task_name": "s2c",
        }
        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
            "target": decoder_target,
            "target_lengths": decoder_target_lengths,
            "task_name": "s2c",
            "ntokens": len(samples),
        }

        return batch

    def collater_audio(self, audios, audio_size):
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = (
            torch.BoolTensor(collated_audios.shape).fill_(False)
        )
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
                padding_mask[i, diff:] = True
            else:
                raise Exception("Diff should not be larger than 0")
        return collated_audios, padding_mask

    def collater_seq_label(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, [self.tgt_dict.pad()])
        for targets, pad in itr:
            targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        return self.wav_sizes[index]

    @property
    def sizes(self):
        return np.array(self.wav_sizes)

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        order.append(self.wav_sizes)
        return np.lexsort(order)[::-1]

    def postprocess(self, wav, cur_sample_rate):
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav
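
Note (not part of the commit): a small usage sketch of sample_from_feature, which get_audio uses to crop long utterances to max_length samples before collation; the lengths below are illustrative values only.

import numpy as np

utt = np.zeros(76800)                           # ~4.8 s at 16 kHz
seg = sample_from_feature(utt, max_segment_length=51200)
print(len(seg))                                 # 51200: a random crop of the utterance

short = np.zeros(16000)
print(len(sample_from_feature(short, 51200)))   # 16000: short inputs pass through unchanged
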
artst/data/speech_to_speech_dataset.py
ADDED
@@ -0,0 +1,280 @@
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import logging
import os
from typing import Any, List, Optional

import librosa
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data.fairseq_dataset import FairseqDataset

logger = logging.getLogger(__name__)

def _collate_frames(
    frames: List[torch.Tensor], is_audio_input: bool = False
):
    """
    Convert a list of 2D frames into a padded 3D tensor
    Args:
        frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
            length of i-th frame and f_dim is static dimension of features
    Returns:
        3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
    """
    max_len = max(frame.size(0) for frame in frames)
    if is_audio_input:
        out = frames[0].new_zeros((len(frames), max_len))
    else:
        out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
    for i, v in enumerate(frames):
        out[i, : v.size(0)] = v
    return out

def load_audio(manifest_path, max_keep, min_keep):
    """manifest tsv: src_wav, src_nframe, tgt_wav, tgt_nframe, tgt_spkemb"""
    n_long, n_short = 0, 0
    src_names, tgt_names, inds, sizes, tgt_sizes, spk_embeds = [], [], [], [], [], []
    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            items = line.strip().split("\t")
            assert len(items) >= 2, line
            sz = int(items[1])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                src_names.append(items[0])
                tgt_names.append(items[2])
                tgt_sizes.append(items[3])
                spk_embeds.append(items[4])
                inds.append(ind)
                sizes.append(sz)
    tot = ind + 1
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(src_names)}, skipped {n_short} short and {n_long} long, "
            f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
        )
    )
    return root, src_names, inds, tot, sizes, tgt_names, tgt_sizes, spk_embeds


def logmelfilterbank(
    audio,
    sampling_rate,
    fft_size=1024,
    hop_size=256,
    win_length=None,
    window="hann",
    num_mels=80,
    fmin=80,
    fmax=7600,
    eps=1e-10,
):
    """Compute log-Mel filterbank feature.
    (https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py)

    Args:
        audio (ndarray): Audio signal (T,).
        sampling_rate (int): Sampling rate.
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length. If set to None, it will be the same as fft_size.
        window (str): Window function type.
        num_mels (int): Number of mel basis.
        fmin (int): Minimum frequency in mel basis calculation.
        fmax (int): Maximum frequency in mel basis calculation.
        eps (float): Epsilon value to avoid inf in log calculation.

    Returns:
        ndarray: Log Mel filterbank feature (#frames, num_mels).

    """
    # get amplitude spectrogram
    x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
                          win_length=win_length, window=window, pad_mode="reflect")
    spc = np.abs(x_stft).T  # (#frames, #bins)

    # get mel basis
    fmin = 0 if fmin is None else fmin
    fmax = sampling_rate / 2 if fmax is None else fmax
    mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)

    return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))


class SpeechToSpeechDataset(FairseqDataset):
    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        shuffle: bool = True,
        normalize: bool = False,
        reduction_factor: int = 1,
    ):
        self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.tgt_audios, self.tgt_sizes, self.tgt_spkembs = load_audio(
            manifest_path, max_keep_sample_size, min_keep_sample_size
        )
        self.sample_rate = sample_rate
        self.shuffle = shuffle

        self.normalize = normalize
        self.reduction_factor = reduction_factor
        logger.info(
            f"reduction_factor={reduction_factor}, normalize={normalize}"
        )

    def get_audio(self, index):
        import soundfile as sf

        wav_fbank = []
        for name in [self.audio_names[index], self.tgt_audios[index]]:
            wav_path = os.path.join(self.audio_root, name)
            wav, cur_sample_rate = sf.read(wav_path)
            wav = torch.from_numpy(wav).float()
            fbank = logmelfilterbank(
                wav.view(-1).cpu().numpy(), 16000
            )
            fbank = torch.from_numpy(fbank).float()
            wav = self.postprocess(wav, cur_sample_rate)
            wav_fbank.append(wav)
            wav_fbank.append(fbank)
        src_wav, src_fbank, tgt_wav, tgt_fbank = wav_fbank
        return src_wav, src_fbank, tgt_wav, tgt_fbank

    def __getitem__(self, index):
        src_wav, src_fbank, tgt_wav, tgt_fbank = self.get_audio(index)
        spkembs = np.load(os.path.join(self.audio_root, self.tgt_spkembs[index]))
        spkembs = torch.from_numpy(spkembs).float()
        name = self.audio_names[index].replace("/", ".").replace(".wav", "") + "-" + self.tgt_audios[index].replace("/", ".").replace(".wav", "") + ".wav"
        return {"id": index, "source": src_wav, "target": tgt_fbank, "spkembs": spkembs, "audio_name": name, "tgt_name": self.tgt_audios[index]}

    def __len__(self):
        return len(self.wav_sizes)

    def collater(self, samples):
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]

        audio_size = max(audio_sizes)
        collated_audios, padding_mask = self.collater_audio(
            audios, audio_size
        )

        fbanks = [s["target"] for s in samples]
        fbank_sizes = [len(s) for s in fbanks]

        collated_fbanks = _collate_frames(fbanks)
        collated_fbanks_size = torch.tensor(fbank_sizes, dtype=torch.long)

        # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
        if self.reduction_factor > 1:
            collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor]
            collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size])
        else:
            collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size

        prev_output_tokens = torch.cat(
            [collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1
        )

        # make labels for stop prediction
        labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1))
        for i, l in enumerate(fbank_sizes):
            labels[i, l - 1 :] = 1.0

        spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True)

        net_input = {
            "source": collated_audios,
            "padding_mask": padding_mask,
            "prev_output_tokens": prev_output_tokens,
            "tgt_lengths": collated_fbanks_size_in,
            "spkembs": spkembs,
            "task_name": "s2s",
        }
        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "name": [s["audio_name"] for s in samples],
            "tgt_name": [s["tgt_name"] for s in samples],
            "net_input": net_input,
            "labels": labels,
            "dec_target": collated_fbanks,
            "dec_target_lengths": collated_fbanks_size,
            "src_lengths": torch.LongTensor(audio_sizes),
            "task_name": "s2s",
            "ntokens": sum(audio_sizes),
            "target": collated_fbanks,
        }

        return batch

    def collater_audio(self, audios, audio_size):
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = (
            torch.BoolTensor(collated_audios.shape).fill_(False)
        )
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
                padding_mask[i, diff:] = True
            else:
                raise Exception("Diff should not be larger than 0")
        return collated_audios, padding_mask


    def num_tokens(self, index):
        return self.wav_sizes[index]

    def size(self, index):
        return self.wav_sizes[index], self.tgt_sizes[index]

    @property
    def sizes(self):
        return np.array(self.wav_sizes)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """No cache dataset if dataset is large-scale. Cache dataset for small dataset."""
        return True

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        order.append(self.wav_sizes)
        return np.lexsort(order)[::-1]

    def postprocess(self, wav, cur_sample_rate):
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav
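
Note (not part of the commit): the reduction-factor branch in collater keeps every r-th filterbank frame as decoder input. A minimal sketch of the same slicing on a dummy tensor with assumed shapes.

import torch

r = 2
fbanks = torch.arange(2 * 6 * 3, dtype=torch.float).view(2, 6, 3)  # (B, Lmax, odim)
fbanks_in = fbanks[:, r - 1 :: r]                                  # keep frames r-1, 2r-1, ...
print(fbanks_in.shape)                                             # torch.Size([2, 3, 3])
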
artst/data/speech_to_text_dataset.py
ADDED
@@ -0,0 +1,298 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import itertools
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import mmap
|
12 |
+
from typing import Any, List, Optional
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
import torch
|
17 |
+
torch.set_printoptions(profile="full")
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from fairseq.data import data_utils, Dictionary
|
20 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
def load_audio(manifest_path, max_keep, min_keep):
|
26 |
+
n_long, n_short = 0, 0
|
27 |
+
names, inds, sizes = [], [], []
|
28 |
+
with open(manifest_path) as f:
|
29 |
+
root = f.readline().strip()
|
30 |
+
for ind, line in enumerate(f):
|
31 |
+
items = line.strip().split("\t")
|
32 |
+
assert len(items) >= 2, line
|
33 |
+
sz = int(items[1])
|
34 |
+
if min_keep is not None and sz < min_keep:
|
35 |
+
n_short += 1
|
36 |
+
elif max_keep is not None and sz > max_keep:
|
37 |
+
n_long += 1
|
38 |
+
else:
|
39 |
+
names.append(items[0])
|
40 |
+
inds.append(ind)
|
41 |
+
sizes.append(sz)
|
42 |
+
tot = ind + 1
|
43 |
+
logger.info(
|
44 |
+
(
|
45 |
+
f"max_keep={max_keep}, min_keep={min_keep}, "
|
46 |
+
f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
|
47 |
+
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
|
48 |
+
)
|
49 |
+
)
|
50 |
+
return root, names, inds, tot, sizes
|
51 |
+
|
52 |
+
|
53 |
+
def load_label(label_path, inds, tot):
|
54 |
+
with open(label_path) as f:
|
55 |
+
labels = [line.rstrip() for line in f]
|
56 |
+
assert (
|
57 |
+
len(labels) == tot
|
58 |
+
), f"number of labels does not match ({len(labels)} != {tot})"
|
59 |
+
labels = [labels[i] for i in inds]
|
60 |
+
return labels
|
61 |
+
|
62 |
+
|
63 |
+
def load_label_offset(label_path, inds, tot):
|
64 |
+
with open(label_path) as f:
|
65 |
+
# Hawau:
|
66 |
+
# changed line length reading as it's incorrect
|
67 |
+
code_lengths = [len(line.encode("utf-8")) for line in f] #original
|
68 |
+
# code_lengths = [len(line) for line in f] #fix
|
69 |
+
assert (
|
70 |
+
len(code_lengths) == tot
|
71 |
+
), f"number of labels does not match ({len(code_lengths)} != {tot})"
|
72 |
+
offsets = list(itertools.accumulate([0] + code_lengths))
|
73 |
+
offsets = [(offsets[i], offsets[i + 1]) for i in inds]
|
74 |
+
return offsets
|
75 |
+
|
76 |
+
|
77 |
+
class SpeechToTextDataset(FairseqDataset):
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
manifest_path: str,
|
81 |
+
sample_rate: float,
|
82 |
+
label_paths: List[str],
|
83 |
+
label_processors: Optional[List[Any]] = None,
|
84 |
+
max_keep_sample_size: Optional[int] = None,
|
85 |
+
min_keep_sample_size: Optional[int] = None,
|
86 |
+
shuffle: bool = True,
|
87 |
+
normalize: bool = False,
|
88 |
+
store_labels: bool = True,
|
89 |
+
tgt_dict: Optional[Dictionary] = None,
|
90 |
+
tokenizer = None,
|
91 |
+
):
|
92 |
+
self.audio_root, self.audio_names, inds, tot, self.wav_sizes = load_audio(
|
93 |
+
manifest_path, max_keep_sample_size, min_keep_sample_size
|
94 |
+
)
|
95 |
+
|
96 |
+
self.sample_rate = sample_rate
|
97 |
+
self.shuffle = shuffle
|
98 |
+
self.tgt_dict = tgt_dict
|
99 |
+
self.tokenizer = tokenizer
|
100 |
+
|
101 |
+
self.num_labels = len(label_paths)
|
102 |
+
self.label_processors = label_processors
|
103 |
+
self.store_labels = store_labels
|
104 |
+
|
105 |
+
if store_labels:
|
106 |
+
self.label_list = [load_label(p, inds, tot) for p in label_paths]
|
107 |
+
logger.info(f"label_list: {self.label_list}")
|
108 |
+
else:
|
109 |
+
self.label_paths = label_paths
|
110 |
+
self.label_offsets_list = [
|
111 |
+
load_label_offset(p, inds, tot) for p in label_paths
|
112 |
+
]
|
113 |
+
# logger.info(f"label_offsets_list: {self.label_offsets_list}")
|
114 |
+
assert label_processors is None or len(label_processors) == self.num_labels
|
115 |
+
|
116 |
+
self.normalize = normalize
|
117 |
+
logger.info(
|
118 |
+
f"normalize={normalize}"
|
119 |
+
)
|
120 |
+
|
121 |
+
def get_audio(self, index):
|
122 |
+
import soundfile as sf
|
123 |
+
# Hawau:
|
124 |
+
# logger.info(f"loaded_audio: {self.audio_names[index]}")
|
125 |
+
wav_path = os.path.join(self.audio_root, self.audio_names[index])
|
126 |
+
wav, cur_sample_rate = sf.read(wav_path)
|
127 |
+
wav = torch.from_numpy(wav).float()
|
128 |
+
wav = self.postprocess(wav, cur_sample_rate)
|
129 |
+
return wav
|
130 |
+
|
131 |
+
def get_label(self, index, label_idx):
|
132 |
+
if self.store_labels:
|
133 |
+
label = self.label_list[label_idx][index]
|
134 |
+
else:
|
135 |
+
# list slicing method
|
136 |
+
# with open(self.label_paths[label_idx]) as f:
|
137 |
+
            #     offset_s, offset_e = self.label_offsets_list[label_idx][index]
            # # Hawau:
            # #     f.seek(offset_s)
            # #     label = f.read(offset_e - offset_s)
            #     label = f.read()[offset_s : offset_e]
            # Hawau:
            # mmap method
            with open(self.label_paths[label_idx], encoding='utf-8') as f:
                with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
                    offset_s, offset_e = self.label_offsets_list[label_idx][index]
                    label = mm[offset_s:offset_e].decode("utf-8")


        # Hawau:
        # logger.info(f"loaded_label: {label}")
        if self.tokenizer is not None:
            label = self.tokenizer.encode(label)

        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        # logger.info(f"processed_label: {label}")
        return label

    def get_labels(self, index):
        return [self.get_label(index, i) for i in range(self.num_labels)]

    def __getitem__(self, index):
        wav = self.get_audio(index)
        labels = self.get_labels(index)
        return {"id": index, "source": wav, "label_list": labels}

    def __len__(self):
        return len(self.wav_sizes)

    def collater(self, samples):
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]

        audio_size = max(audio_sizes)
        collated_audios, padding_mask = self.collater_audio(
            audios, audio_size
        )

        targets_by_label = [
            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(targets_by_label)

        # Hawau:
        # logger.info(f'targets_list: {targets_list}')


        decoder_label = [
            torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
            for i in range(targets_list[0].size(0))
        ]

        decoder_target = data_utils.collate_tokens(
            decoder_label,
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        )
        decoder_target_lengths = torch.tensor(
            [x.size(0) for x in decoder_label], dtype=torch.long
        )
        prev_output_tokens = data_utils.collate_tokens(
            decoder_label,
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=True,
        )

        net_input = {
            "source": collated_audios,
            "padding_mask": padding_mask,
            "prev_output_tokens": prev_output_tokens,
            "task_name": "s2t",
        }
        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
            "target": decoder_target,
            "target_lengths": decoder_target_lengths,
            "task_name": "s2t",
            "ntokens": ntokens_list[0]
        }

        return batch

    def collater_audio(self, audios, audio_size):
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = (
            torch.BoolTensor(collated_audios.shape).fill_(False)
        )
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
                padding_mask[i, diff:] = True
            else:
                raise Exception("Diff should not be larger than 0")
        return collated_audios, padding_mask

    def collater_seq_label(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, [self.tgt_dict.pad()])

        for targets, pad in itr:
            # Hawau:
            # logger.info(f'targets: {targets}')
            targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        return self.wav_sizes[index]

    @property
    def sizes(self):
        return np.array(self.wav_sizes)

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        order.append(self.wav_sizes)
        return np.lexsort(order)[::-1]

    def postprocess(self, wav, cur_sample_rate):
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav
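For reference, a minimal self-contained sketch of the byte-offset + mmap label lookup used in get_label above. It is not part of the repo; the temporary file and label strings are invented for illustration.

# Standalone sketch of offset-based label lookup via mmap (illustrative only).
import itertools
import mmap
import tempfile

labels = ["marhaba", "kayfa al-hal", "shukran"]  # hypothetical label lines

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
    f.write("\n".join(labels) + "\n")
    label_path = f.name

# Same idea as load_label_offset: cumulative byte lengths give a (start, end) span per line.
with open(label_path, encoding="utf-8") as f:
    code_lengths = [len(line.encode("utf-8")) for line in f]
offsets = list(itertools.accumulate([0] + code_lengths))
offsets = [(offsets[i], offsets[i + 1]) for i in range(len(code_lengths))]

# Same idea as get_label: slice the memory-mapped file instead of seek()/read().
with open(label_path, encoding="utf-8") as f:
    with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        s, e = offsets[1]
        assert mm[s:e].decode("utf-8").rstrip("\n") == "kayfa al-hal"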
artst/data/text_dataset.py
ADDED
@@ -0,0 +1,474 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import math
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from fairseq.data import FairseqDataset, data_utils
|
14 |
+
|
15 |
+
|
16 |
+
def collate(
|
17 |
+
samples,
|
18 |
+
pad_idx,
|
19 |
+
eos_idx,
|
20 |
+
vocab,
|
21 |
+
left_pad_source=False,
|
22 |
+
left_pad_target=False,
|
23 |
+
input_feeding=True,
|
24 |
+
pad_to_length=None,
|
25 |
+
):
|
26 |
+
assert input_feeding
|
27 |
+
if len(samples) == 0:
|
28 |
+
return {}
|
29 |
+
|
30 |
+
def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
|
31 |
+
return data_utils.collate_tokens(
|
32 |
+
[s[key] for s in samples],
|
33 |
+
pad_idx,
|
34 |
+
eos_idx=None, # use eos_idx of each sample instead of vocab.eos()
|
35 |
+
left_pad=left_pad,
|
36 |
+
move_eos_to_beginning=move_eos_to_beginning,
|
37 |
+
pad_to_length=pad_to_length,
|
38 |
+
)
|
39 |
+
|
40 |
+
id = torch.LongTensor([s["id"] for s in samples])
|
41 |
+
src_tokens = merge(
|
42 |
+
"source",
|
43 |
+
left_pad=left_pad_source,
|
44 |
+
pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
|
45 |
+
)
|
46 |
+
# sort by descending source length
|
47 |
+
src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
|
48 |
+
src_lengths, sort_order = src_lengths.sort(descending=True)
|
49 |
+
id = id.index_select(0, sort_order)
|
50 |
+
src_tokens = src_tokens.index_select(0, sort_order)
|
51 |
+
|
52 |
+
prev_output_tokens = None
|
53 |
+
target = None
|
54 |
+
if samples[0].get("target", None) is not None:
|
55 |
+
target = merge(
|
56 |
+
"target",
|
57 |
+
left_pad=left_pad_target,
|
58 |
+
pad_to_length=pad_to_length["target"]
|
59 |
+
if pad_to_length is not None
|
60 |
+
else None,
|
61 |
+
)
|
62 |
+
target = target.index_select(0, sort_order)
|
63 |
+
ntokens = sum(len(s["target"]) for s in samples)
|
64 |
+
|
65 |
+
if input_feeding:
|
66 |
+
# we create a shifted version of targets for feeding the
|
67 |
+
# previous output token(s) into the next decoder step
|
68 |
+
prev_output_tokens = merge(
|
69 |
+
"target",
|
70 |
+
left_pad=left_pad_target,
|
71 |
+
move_eos_to_beginning=True,
|
72 |
+
pad_to_length=pad_to_length["target"]
|
73 |
+
if pad_to_length is not None
|
74 |
+
else None,
|
75 |
+
)
|
76 |
+
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
|
77 |
+
else:
|
78 |
+
ntokens = sum(len(s["source"]) for s in samples)
|
79 |
+
|
80 |
+
batch = {
|
81 |
+
"id": id,
|
82 |
+
"ntokens": ntokens,
|
83 |
+
"net_input": {
|
84 |
+
"src_tokens": src_tokens,
|
85 |
+
"src_lengths": src_lengths,
|
86 |
+
},
|
87 |
+
"target": target,
|
88 |
+
"nsentences": samples[0]["source"].size(0),
|
89 |
+
"sort_order": sort_order,
|
90 |
+
"task_name": 'text_pretrain',
|
91 |
+
}
|
92 |
+
if prev_output_tokens is not None:
|
93 |
+
batch["net_input"]["prev_output_tokens"] = prev_output_tokens
|
94 |
+
|
95 |
+
return batch
|
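A tiny plain-PyTorch illustration (not fairseq's collate_tokens) of what the input-feeding shift above produces: for a target that ends in EOS, the decoder input is the same sequence rotated so EOS comes first, and each step predicts the next target token. The token ids are invented for illustration.

# Illustration of move_eos_to_beginning-style shifting (assumes eos id = 2).
import torch

target = torch.tensor([11, 12, 13, 2])      # hypothetical target ending in eos
prev_output_tokens = torch.roll(target, 1)  # tensor([ 2, 11, 12, 13])
assert prev_output_tokens[0].item() == 2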
96 |
+
|
97 |
+
|
98 |
+
class TextPretrainDataset(FairseqDataset):
|
99 |
+
"""
|
100 |
+
A wrapper around TokenBlockDataset for BART dataset.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
dataset (TokenBlockDataset): dataset to wrap
|
104 |
+
sizes (List[int]): sentence lengths
|
105 |
+
vocab (~fairseq.data.Dictionary): vocabulary
|
106 |
+
mask_idx (int): dictionary index used for masked token
|
107 |
+
mask_whole_words: only mask whole words. This should be a byte mask
|
108 |
+
over vocab indices, indicating whether it is the beginning of a
|
109 |
+
word. We will extend any mask to encompass the whole word.
|
110 |
+
shuffle (bool, optional): shuffle the elements before batching.
|
111 |
+
Default: ``True``
|
112 |
+
seed: Seed for random number generator for reproducibility.
|
113 |
+
args: argparse arguments.
|
114 |
+
"""
|
115 |
+
|
116 |
+
def __init__(
|
117 |
+
self,
|
118 |
+
dataset,
|
119 |
+
sizes,
|
120 |
+
vocab,
|
121 |
+
mask_idx,
|
122 |
+
mask_whole_words,
|
123 |
+
shuffle,
|
124 |
+
seed,
|
125 |
+
args,
|
126 |
+
eos=None,
|
127 |
+
item_transform_func=None,
|
128 |
+
iid_noise_target=False,
|
129 |
+
uni_mask_idxs=None,
|
130 |
+
):
|
131 |
+
self.dataset = dataset
|
132 |
+
|
133 |
+
self.sizes = sizes
|
134 |
+
|
135 |
+
self.vocab = vocab
|
136 |
+
self.shuffle = shuffle
|
137 |
+
self.seed = seed
|
138 |
+
if iid_noise_target:
|
139 |
+
assert isinstance(uni_mask_idxs, torch.Tensor), "if iid_noise_target is used, uni_mask_idxs must be a tensor containing the mask indices"
|
140 |
+
self.iid_noise_target = iid_noise_target
|
141 |
+
self.uni_mask_idxs = uni_mask_idxs
|
142 |
+
self.mask_idx = mask_idx
|
143 |
+
self.mask_whole_word = mask_whole_words
|
144 |
+
self.mask_ratio = args.mask
|
145 |
+
self.random_ratio = args.mask_random
|
146 |
+
self.insert_ratio = args.insert
|
147 |
+
self.rotate_ratio = args.rotate
|
148 |
+
self.permute_sentence_ratio = args.permute_sentences
|
149 |
+
self.eos = eos if eos is not None else vocab.eos()
|
150 |
+
self.item_transform_func = item_transform_func
|
151 |
+
|
152 |
+
if args.bpe != "gpt2":
|
153 |
+
self.full_stop_index = self.vocab.eos()
|
154 |
+
else:
|
155 |
+
assert args.bpe == "gpt2"
|
156 |
+
self.full_stop_index = self.vocab.index("13")
|
157 |
+
|
158 |
+
self.replace_length = args.replace_length
|
159 |
+
if self.replace_length not in [-1, 0, 1]:
|
160 |
+
raise ValueError(f"invalid arg: replace_length={self.replace_length}")
|
161 |
+
if args.mask_length not in ["subword", "word", "span-poisson"]:
|
162 |
+
raise ValueError(f"invalid arg: mask-length={args.mask_length}")
|
163 |
+
if args.mask_length == "subword" and args.replace_length not in [0, 1]:
|
164 |
+
raise ValueError(f"if using subwords, use replace-length=1 or 0")
|
165 |
+
|
166 |
+
self.mask_span_distribution = None
|
167 |
+
if args.mask_length == "span-poisson":
|
168 |
+
_lambda = args.poisson_lambda
|
169 |
+
|
170 |
+
lambda_to_the_k = 1
|
171 |
+
e_to_the_minus_lambda = math.exp(-_lambda)
|
172 |
+
k_factorial = 1
|
173 |
+
ps = []
|
174 |
+
for k in range(0, 128):
|
175 |
+
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
|
176 |
+
lambda_to_the_k *= _lambda
|
177 |
+
k_factorial *= k + 1
|
178 |
+
if ps[-1] < 0.0000001:
|
179 |
+
break
|
180 |
+
ps = torch.FloatTensor(ps)
|
181 |
+
self.mask_span_distribution = torch.distributions.Categorical(ps)
|
182 |
+
|
183 |
+
self.epoch = 0
|
184 |
+
|
185 |
+
@property
|
186 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
187 |
+
return True # only the noise changes, not item sizes
|
188 |
+
|
189 |
+
def set_epoch(self, epoch, **unused):
|
190 |
+
self.epoch = epoch
|
191 |
+
|
192 |
+
def __getitem__(self, index):
|
193 |
+
with data_utils.numpy_seed(self.seed, self.epoch, index):
|
194 |
+
tokens = self.dataset[index]
|
195 |
+
assert tokens[-1] == self.eos
|
196 |
+
source, target = tokens, tokens.clone()
|
197 |
+
|
198 |
+
if self.permute_sentence_ratio > 0.0:
|
199 |
+
source = self.permute_sentences(source, self.permute_sentence_ratio)
|
200 |
+
|
201 |
+
if self.mask_ratio > 0:
|
202 |
+
source, new_target = self.add_whole_word_mask(source, self.mask_ratio)
|
203 |
+
if new_target is not None:
|
204 |
+
target = new_target
|
205 |
+
|
206 |
+
if self.insert_ratio > 0:
|
207 |
+
source = self.add_insertion_noise(source, self.insert_ratio)
|
208 |
+
|
209 |
+
if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
|
210 |
+
source = self.add_rolling_noise(source)
|
211 |
+
# there can be additional changes to make:
|
212 |
+
if self.item_transform_func is not None:
|
213 |
+
source, target = self.item_transform_func(source, target)
|
214 |
+
|
215 |
+
assert (source >= 0).all()
|
216 |
+
assert (source[1:-1] >= 1).all()
|
217 |
+
assert (source <= len(self.vocab)).all()
|
218 |
+
assert source[0] == self.vocab.bos()
|
219 |
+
assert source[-1] == self.eos
|
220 |
+
return {
|
221 |
+
"id": index,
|
222 |
+
"source": source,
|
223 |
+
"target": target,
|
224 |
+
}
|
225 |
+
|
226 |
+
def __len__(self):
|
227 |
+
return len(self.dataset)
|
228 |
+
|
229 |
+
def permute_sentences(self, source, p=1.0):
|
230 |
+
full_stops = source == self.full_stop_index
|
231 |
+
# Pretend it ends with a full stop so last span is a sentence
|
232 |
+
full_stops[-2] = 1
|
233 |
+
|
234 |
+
# Tokens that are full stops, where the previous token is not
|
235 |
+
sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
|
236 |
+
result = source.clone()
|
237 |
+
|
238 |
+
num_sentences = sentence_ends.size(0)
|
239 |
+
num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
|
240 |
+
substitutions = torch.randperm(num_sentences)[:num_to_permute]
|
241 |
+
ordering = torch.arange(0, num_sentences)
|
242 |
+
ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]
|
243 |
+
|
244 |
+
# Ignore <bos> at start
|
245 |
+
index = 1
|
246 |
+
for i in ordering:
|
247 |
+
sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
|
248 |
+
result[index : index + sentence.size(0)] = sentence
|
249 |
+
index += sentence.size(0)
|
250 |
+
return result
|
251 |
+
|
252 |
+
def word_starts(self, source):
|
253 |
+
if self.mask_whole_word is not None:
|
254 |
+
is_word_start = self.mask_whole_word.gather(0, source)
|
255 |
+
else:
|
256 |
+
is_word_start = torch.ones(source.size())
|
257 |
+
is_word_start[0] = 0
|
258 |
+
is_word_start[-1] = 0
|
259 |
+
return is_word_start
|
260 |
+
|
261 |
+
def add_whole_word_mask(self, source, p):
|
262 |
+
source_ori = source.clone()
|
263 |
+
is_word_start = self.word_starts(source)
|
264 |
+
num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
|
265 |
+
num_inserts = 0
|
266 |
+
if num_to_mask == 0:
|
267 |
+
return source
|
268 |
+
|
269 |
+
if self.mask_span_distribution is not None:
|
270 |
+
lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
|
271 |
+
|
272 |
+
# Make sure we have enough to mask
|
273 |
+
cum_length = torch.cumsum(lengths, 0)
|
274 |
+
while cum_length[-1] < num_to_mask:
|
275 |
+
lengths = torch.cat(
|
276 |
+
[
|
277 |
+
lengths,
|
278 |
+
self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
|
279 |
+
],
|
280 |
+
dim=0,
|
281 |
+
)
|
282 |
+
cum_length = torch.cumsum(lengths, 0)
|
283 |
+
|
284 |
+
# Trim to masking budget
|
285 |
+
i = 0
|
286 |
+
while cum_length[i] < num_to_mask:
|
287 |
+
i += 1
|
288 |
+
lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
|
289 |
+
num_to_mask = i + 1
|
290 |
+
lengths = lengths[:num_to_mask]
|
291 |
+
|
292 |
+
# Handle 0-length mask (inserts) separately
|
293 |
+
lengths = lengths[lengths > 0]
|
294 |
+
num_inserts = num_to_mask - lengths.size(0)
|
295 |
+
num_to_mask -= num_inserts
|
296 |
+
if num_to_mask == 0:
|
297 |
+
return self.add_insertion_noise(source, num_inserts / source.size(0))
|
298 |
+
|
299 |
+
assert (lengths > 0).all()
|
300 |
+
else:
|
301 |
+
lengths = torch.ones((num_to_mask,)).long()
|
302 |
+
assert is_word_start[-1] == 0
|
303 |
+
word_starts = is_word_start.nonzero(as_tuple=False)
|
304 |
+
indices = word_starts[
|
305 |
+
torch.randperm(word_starts.size(0))[:num_to_mask]
|
306 |
+
].squeeze(1)
|
307 |
+
mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
|
308 |
+
|
309 |
+
source_length = source.size(0)
|
310 |
+
assert source_length - 1 not in indices
|
311 |
+
to_keep = torch.ones(source_length, dtype=torch.bool)
|
312 |
+
is_word_start[
|
313 |
+
-1
|
314 |
+
] = 255 # acts as a long length, so spans don't go over the end of doc
|
315 |
+
if self.replace_length == 0:
|
316 |
+
to_keep[indices] = 0
|
317 |
+
else:
|
318 |
+
# keep index, but replace it with [MASK]
|
319 |
+
source[indices] = self.mask_idx
|
320 |
+
source[indices[mask_random]] = torch.randint(
|
321 |
+
1, len(self.vocab), size=(mask_random.sum(),)
|
322 |
+
)
|
323 |
+
|
324 |
+
if self.mask_span_distribution is not None:
|
325 |
+
assert len(lengths.size()) == 1
|
326 |
+
assert lengths.size() == indices.size()
|
327 |
+
lengths -= 1
|
328 |
+
while indices.size(0) > 0:
|
329 |
+
assert lengths.size() == indices.size()
|
330 |
+
lengths -= is_word_start[indices + 1].long()
|
331 |
+
uncompleted = lengths >= 0
|
332 |
+
indices = indices[uncompleted] + 1
|
333 |
+
mask_random = mask_random[uncompleted]
|
334 |
+
lengths = lengths[uncompleted]
|
335 |
+
if self.replace_length != -1:
|
336 |
+
# delete token
|
337 |
+
to_keep[indices] = 0
|
338 |
+
else:
|
339 |
+
# keep index, but replace it with [MASK]
|
340 |
+
source[indices] = self.mask_idx
|
341 |
+
source[indices[mask_random]] = torch.randint(
|
342 |
+
1, len(self.vocab), size=(mask_random.sum(),)
|
343 |
+
)
|
344 |
+
else:
|
345 |
+
# A bit faster when all lengths are 1
|
346 |
+
while indices.size(0) > 0:
|
347 |
+
uncompleted = is_word_start[indices + 1] == 0
|
348 |
+
indices = indices[uncompleted] + 1
|
349 |
+
mask_random = mask_random[uncompleted]
|
350 |
+
if self.replace_length != -1:
|
351 |
+
# delete token
|
352 |
+
to_keep[indices] = 0
|
353 |
+
else:
|
354 |
+
# keep index, but replace it with [MASK]
|
355 |
+
source[indices] = self.mask_idx
|
356 |
+
source[indices[mask_random]] = torch.randint(
|
357 |
+
1, len(self.vocab), size=(mask_random.sum(),)
|
358 |
+
)
|
359 |
+
|
360 |
+
assert source_length - 1 not in indices
|
361 |
+
|
362 |
+
if not self.iid_noise_target:
|
363 |
+
source = source[to_keep]
|
364 |
+
target = None
|
365 |
+
else:
|
366 |
+
## Prepare source
|
367 |
+
source_mask_idx = (source == self.mask_idx).nonzero().view(-1)
|
368 |
+
source[source_mask_idx] = self.uni_mask_idxs[:source_mask_idx.size(0)]
|
369 |
+
source = source[to_keep]
|
370 |
+
|
371 |
+
## Prepare target
|
372 |
+
to_keep[source_mask_idx] = 0
|
373 |
+
|
374 |
+
# source_mask_idx: from [a, b, c, ...] to [a, b + 1, c + 2, ...]
|
375 |
+
source_mask_idx = source_mask_idx + torch.arange(source_mask_idx.size(0))
|
376 |
+
# target: source_length + mask_length
|
377 |
+
target = source_ori.new_zeros(source_mask_idx.size(0) + source_ori.size(0))
|
378 |
+
# target: [0, 0, 0, X, 0, 0, Y, ....]
|
379 |
+
target[source_mask_idx] = self.uni_mask_idxs[:source_mask_idx.size(0)]
|
380 |
+
|
381 |
+
target_to_keep = to_keep.new_zeros(source_mask_idx.size(0) + source_ori.size(0))
|
382 |
+
|
383 |
+
# Copy original value to target and target_to_keep
|
384 |
+
target_to_keep[target == 0] = to_keep
|
385 |
+
target_to_keep[-1] = 0
|
386 |
+
target[target == 0] = source_ori
|
387 |
+
|
388 |
+
target = target[~target_to_keep]
|
389 |
+
|
390 |
+
if num_inserts > 0:
|
391 |
+
source = self.add_insertion_noise(source, num_inserts / source.size(0))
|
392 |
+
|
393 |
+
return source, target
|
394 |
+
|
395 |
+
def add_permuted_noise(self, tokens, p):
|
396 |
+
num_words = len(tokens)
|
397 |
+
num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
|
398 |
+
substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
|
399 |
+
tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
|
400 |
+
return tokens
|
401 |
+
|
402 |
+
def add_rolling_noise(self, tokens):
|
403 |
+
offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
|
404 |
+
tokens = torch.cat(
|
405 |
+
(tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
|
406 |
+
dim=0,
|
407 |
+
)
|
408 |
+
return tokens
|
409 |
+
|
410 |
+
def add_insertion_noise(self, tokens, p):
|
411 |
+
if p == 0.0:
|
412 |
+
return tokens
|
413 |
+
|
414 |
+
num_tokens = len(tokens)
|
415 |
+
n = int(math.ceil(num_tokens * p))
|
416 |
+
|
417 |
+
noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
|
418 |
+
noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
|
419 |
+
noise_mask[noise_indices] = 1
|
420 |
+
result = torch.LongTensor(n + len(tokens)).fill_(-1)
|
421 |
+
|
422 |
+
num_random = int(math.ceil(n * self.random_ratio))
|
423 |
+
result[noise_indices[num_random:]] = self.mask_idx
|
424 |
+
result[noise_indices[:num_random]] = torch.randint(
|
425 |
+
low=1, high=len(self.vocab), size=(num_random,)
|
426 |
+
)
|
427 |
+
|
428 |
+
result[~noise_mask] = tokens
|
429 |
+
|
430 |
+
assert (result >= 0).all()
|
431 |
+
return result
|
432 |
+
|
433 |
+
def collater(self, samples, pad_to_length=None):
|
434 |
+
"""Merge a list of samples to form a mini-batch.
|
435 |
+
Args:
|
436 |
+
samples (List[dict]): samples to collate
|
437 |
+
Returns:
|
438 |
+
dict: a mini-batch of data
|
439 |
+
"""
|
440 |
+
return collate(
|
441 |
+
samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
|
442 |
+
)
|
443 |
+
|
444 |
+
def num_tokens(self, index):
|
445 |
+
"""Return the number of tokens in a sample. This value is used to
|
446 |
+
enforce ``--max-tokens`` during batching."""
|
447 |
+
return self.sizes[index]
|
448 |
+
|
449 |
+
def size(self, index):
|
450 |
+
"""Return an example's size as a float or tuple. This value is used when
|
451 |
+
filtering a dataset with ``--max-positions``."""
|
452 |
+
return self.sizes[index]
|
453 |
+
|
454 |
+
def ordered_indices(self):
|
455 |
+
"""Return an ordered list of indices. Batches will be constructed based
|
456 |
+
on this order."""
|
457 |
+
if self.shuffle:
|
458 |
+
indices = np.random.permutation(len(self))
|
459 |
+
else:
|
460 |
+
indices = np.arange(len(self))
|
461 |
+
return indices[np.argsort(self.sizes[indices], kind="mergesort")]
|
462 |
+
|
463 |
+
def prefetch(self, indices):
|
464 |
+
self.src.prefetch(indices)
|
465 |
+
self.tgt.prefetch(indices)
|
466 |
+
|
467 |
+
@property
|
468 |
+
def supports_prefetch(self):
|
469 |
+
return (
|
470 |
+
hasattr(self.src, "supports_prefetch")
|
471 |
+
and self.src.supports_prefetch
|
472 |
+
and hasattr(self.tgt, "supports_prefetch")
|
473 |
+
and self.tgt.supports_prefetch
|
474 |
+
)
|
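TextPretrainDataset reads its noising hyper-parameters from an args object. The sketch below lists the fields the constructor accesses, with illustrative values only (these are assumptions, not the repo's defaults).

# Hypothetical argparse fields consumed by TextPretrainDataset.__init__ (values invented).
from argparse import Namespace

bart_noise_args = Namespace(
    mask=0.3,                      # args.mask: fraction of whole words to mask
    mask_random=0.1,               # args.mask_random: fraction of masks replaced by random tokens
    insert=0.0,                    # args.insert: insertion-noise ratio
    rotate=0.0,                    # args.rotate: document-rotation probability
    permute_sentences=1.0,         # args.permute_sentences: fraction of sentences to permute
    mask_length="span-poisson",    # "subword", "word", or "span-poisson"
    poisson_lambda=3.5,            # lambda for the span-poisson length distribution
    replace_length=1,              # -1, 0, or 1
    bpe="sentencepiece",           # anything != "gpt2" keeps vocab.eos() as the full stop
)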
artst/data/text_to_speech_dataset.py
ADDED
@@ -0,0 +1,344 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import itertools
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
from typing import Any, List, Optional
|
12 |
+
import mmap
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
import librosa
|
19 |
+
from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform
|
20 |
+
from fairseq.data import data_utils, Dictionary
|
21 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
def _collate_frames(
|
27 |
+
frames: List[torch.Tensor], is_audio_input: bool = False
|
28 |
+
):
|
29 |
+
"""
|
30 |
+
Convert a list of 2D frames into a padded 3D tensor
|
31 |
+
Args:
|
32 |
+
frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
|
33 |
+
length of i-th frame and f_dim is static dimension of features
|
34 |
+
Returns:
|
35 |
+
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
|
36 |
+
"""
|
37 |
+
max_len = max(frame.size(0) for frame in frames)
|
38 |
+
if is_audio_input:
|
39 |
+
out = frames[0].new_zeros((len(frames), max_len))
|
40 |
+
else:
|
41 |
+
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
|
42 |
+
for i, v in enumerate(frames):
|
43 |
+
out[i, : v.size(0)] = v
|
44 |
+
return out
|
45 |
+
|
46 |
+
def load_audio(manifest_path, max_keep, min_keep):
|
47 |
+
n_long, n_short = 0, 0
|
48 |
+
names, inds, sizes, spk_embeds = [], [], [], []
|
49 |
+
with open(manifest_path) as f:
|
50 |
+
root = f.readline().strip()
|
51 |
+
for ind, line in enumerate(f):
|
52 |
+
items = line.strip().split("\t")
|
53 |
+
assert len(items) == 3, line
|
54 |
+
sz = int(items[1])
|
55 |
+
if min_keep is not None and sz < min_keep:
|
56 |
+
n_short += 1
|
57 |
+
elif max_keep is not None and sz > max_keep:
|
58 |
+
n_long += 1
|
59 |
+
else:
|
60 |
+
names.append(items[0])
|
61 |
+
spk_embeds.append(items[2])
|
62 |
+
inds.append(ind)
|
63 |
+
sizes.append(sz)
|
64 |
+
tot = ind + 1
|
65 |
+
logger.info(
|
66 |
+
(
|
67 |
+
f"max_keep={max_keep}, min_keep={min_keep}, "
|
68 |
+
f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
|
69 |
+
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
|
70 |
+
)
|
71 |
+
)
|
72 |
+
return root, names, inds, tot, sizes, spk_embeds
|
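For reference, a hypothetical manifest in the format load_audio expects here: the root directory on the first line, then one tab-separated row per utterance with the audio path, sample count, and speaker-embedding path. The paths and sizes below are invented for illustration.

# Parse a made-up TTS manifest the same way load_audio does (three tab-separated fields).
manifest = (
    "/data/artst_tts\n"
    "wavs/utt1.wav\t52480\tspk_embeds/utt1.npy\n"
    "wavs/utt2.wav\t80000\tspk_embeds/utt2.npy\n"
)
root, *rows = manifest.strip().split("\n")
for row in rows:
    name, num_samples, spk_embed = row.split("\t")
    assert int(num_samples) > 0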
73 |
+
|
74 |
+
|
75 |
+
def load_label(label_path, inds, tot):
|
76 |
+
with open(label_path) as f:
|
77 |
+
labels = [line.rstrip() for line in f]
|
78 |
+
assert (
|
79 |
+
len(labels) == tot
|
80 |
+
), f"number of labels does not match ({len(labels)} != {tot})"
|
81 |
+
labels = [labels[i] for i in inds]
|
82 |
+
return labels
|
83 |
+
|
84 |
+
|
85 |
+
def load_label_offset(label_path, inds, tot):
|
86 |
+
with open(label_path, encoding='utf-8') as f:
|
87 |
+
code_lengths = [len(line.encode("utf-8")) for line in f] #changed as in speech_to_text_dataset.py
|
88 |
+
assert (
|
89 |
+
len(code_lengths) == tot
|
90 |
+
), f"number of labels does not match ({len(code_lengths)} != {tot})"
|
91 |
+
offsets = list(itertools.accumulate([0] + code_lengths))
|
92 |
+
offsets = [(offsets[i], offsets[i + 1]) for i in inds]
|
93 |
+
return offsets
|
94 |
+
|
95 |
+
|
96 |
+
def logmelfilterbank(
|
97 |
+
audio,
|
98 |
+
sampling_rate,
|
99 |
+
fft_size=1024,
|
100 |
+
hop_size=256,
|
101 |
+
win_length=None,
|
102 |
+
window="hann",
|
103 |
+
num_mels=80,
|
104 |
+
fmin=80,
|
105 |
+
fmax=7600,
|
106 |
+
eps=1e-10,
|
107 |
+
):
|
108 |
+
"""Compute log-Mel filterbank feature.
|
109 |
+
(https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py)
|
110 |
+
|
111 |
+
Args:
|
112 |
+
audio (ndarray): Audio signal (T,).
|
113 |
+
sampling_rate (int): Sampling rate.
|
114 |
+
fft_size (int): FFT size.
|
115 |
+
hop_size (int): Hop size.
|
116 |
+
win_length (int): Window length. If set to None, it will be the same as fft_size.
|
117 |
+
window (str): Window function type.
|
118 |
+
num_mels (int): Number of mel basis.
|
119 |
+
fmin (int): Minimum frequency in mel basis calculation.
|
120 |
+
fmax (int): Maximum frequency in mel basis calculation.
|
121 |
+
eps (float): Epsilon value to avoid inf in log calculation.
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
ndarray: Log Mel filterbank feature (#frames, num_mels).
|
125 |
+
|
126 |
+
"""
|
127 |
+
# get amplitude spectrogram
|
128 |
+
x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
|
129 |
+
win_length=win_length, window=window, pad_mode="reflect")
|
130 |
+
spc = np.abs(x_stft).T # (#frames, #bins)
|
131 |
+
|
132 |
+
# get mel basis
|
133 |
+
fmin = 0 if fmin is None else fmin
|
134 |
+
fmax = sampling_rate / 2 if fmax is None else fmax
|
135 |
+
mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
136 |
+
|
137 |
+
return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
|
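A usage sketch for logmelfilterbank above, assuming the repo root is on PYTHONPATH and its librosa/fairseq dependencies are installed; the test tone is invented for illustration.

# Compute log-Mel features for one second of a synthetic 440 Hz tone at 16 kHz.
import numpy as np
from artst.data.text_to_speech_dataset import logmelfilterbank

sr = 16000
t = np.linspace(0, 1.0, sr, endpoint=False)
audio = 0.1 * np.sin(2 * np.pi * 440.0 * t).astype(np.float32)

fbank = logmelfilterbank(audio, sr)  # defaults: 1024-point FFT, hop 256, 80 mel bins
print(fbank.shape)                   # (#frames, 80); roughly 63 frames for 1 s of audio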
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
class TextToSpeechDataset(FairseqDataset):
|
142 |
+
def __init__(
|
143 |
+
self,
|
144 |
+
manifest_path: str,
|
145 |
+
sample_rate: float,
|
146 |
+
label_paths: List[str],
|
147 |
+
label_processors: Optional[List[Any]] = None,
|
148 |
+
max_keep_sample_size: Optional[int] = None,
|
149 |
+
min_keep_sample_size: Optional[int] = None,
|
150 |
+
shuffle: bool = True,
|
151 |
+
normalize: bool = False,
|
152 |
+
store_labels: bool = True,
|
153 |
+
src_dict: Optional[Dictionary] = None,
|
154 |
+
tokenizer = None,
|
155 |
+
reduction_factor: int = 1,
|
156 |
+
inference: bool = False,
|
157 |
+
):
|
158 |
+
|
159 |
+
self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.spk_embeds = load_audio(
|
160 |
+
manifest_path, max_keep_sample_size, min_keep_sample_size
|
161 |
+
)
|
162 |
+
self.inference = inference
|
163 |
+
|
164 |
+
self.sample_rate = sample_rate
|
165 |
+
self.shuffle = shuffle
|
166 |
+
self.src_dict = src_dict
|
167 |
+
self.tokenizer = tokenizer
|
168 |
+
|
169 |
+
self.num_labels = len(label_paths)
|
170 |
+
self.label_processors = label_processors
|
171 |
+
self.store_labels = store_labels
|
172 |
+
if store_labels:
|
173 |
+
self.label_list = [load_label(p, inds, tot) for p in label_paths]
|
174 |
+
else:
|
175 |
+
self.label_paths = label_paths
|
176 |
+
self.label_offsets_list = [
|
177 |
+
load_label_offset(p, inds, tot) for p in label_paths
|
178 |
+
]
|
179 |
+
assert label_processors is None or len(label_processors) == self.num_labels
|
180 |
+
|
181 |
+
self.normalize = normalize
|
182 |
+
self.reduction_factor = reduction_factor
|
183 |
+
logger.info(
|
184 |
+
f"reduction_factor={reduction_factor}, normalize={normalize}"
|
185 |
+
)
|
186 |
+
|
187 |
+
def get_audio(self, index):
|
188 |
+
import soundfile as sf
|
189 |
+
|
190 |
+
wav_path = os.path.join(self.audio_root, self.audio_names[index])
|
191 |
+
wav, cur_sample_rate = sf.read(wav_path)
|
192 |
+
wav = torch.from_numpy(wav).float()
|
193 |
+
fbank = logmelfilterbank(
|
194 |
+
wav.view(-1).cpu().numpy(), 16000
|
195 |
+
)
|
196 |
+
fbank = torch.from_numpy(fbank).float()
|
197 |
+
wav = self.postprocess(wav, cur_sample_rate)
|
198 |
+
return wav, fbank
|
199 |
+
|
200 |
+
def get_label(self, index, label_idx):
|
201 |
+
if self.store_labels:
|
202 |
+
label = self.label_list[label_idx][index]
|
203 |
+
else:
|
204 |
+
# with open(self.label_paths[label_idx]) as f:
|
205 |
+
# offset_s, offset_e = self.label_offsets_list[label_idx][index]
|
206 |
+
# f.seek(offset_s)
|
207 |
+
# label = f.read(offset_e - offset_s)
|
208 |
+
|
209 |
+
# Hawau:
|
210 |
+
# mmap method
|
211 |
+
with open(self.label_paths[label_idx], encoding='utf-8') as f:
|
212 |
+
with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
|
213 |
+
offset_s, offset_e = self.label_offsets_list[label_idx][index]
|
214 |
+
label = mm[offset_s:offset_e].decode("utf-8")
|
215 |
+
|
216 |
+
|
217 |
+
if self.tokenizer is not None:
|
218 |
+
label = self.tokenizer.encode(label)
|
219 |
+
|
220 |
+
if self.label_processors is not None:
|
221 |
+
label = self.label_processors[label_idx](label)
|
222 |
+
return label
|
223 |
+
|
224 |
+
def get_labels(self, index):
|
225 |
+
return [self.get_label(index, i) for i in range(self.num_labels)]
|
226 |
+
|
227 |
+
def __getitem__(self, index):
|
228 |
+
wav, fbank = self.get_audio(index)
|
229 |
+
labels = self.get_labels(index)
|
230 |
+
spkembs = get_features_or_waveform(
|
231 |
+
os.path.join(self.audio_root, self.spk_embeds[index])
|
232 |
+
)
|
233 |
+
spkembs = torch.from_numpy(spkembs).float()
|
234 |
+
|
235 |
+
return {"id": index, "source": labels, "target": fbank, "spkembs": spkembs, "audio_name": self.audio_names[index]}
|
236 |
+
|
237 |
+
|
238 |
+
def __len__(self):
|
239 |
+
return len(self.wav_sizes)
|
240 |
+
|
241 |
+
def collater(self, samples):
|
242 |
+
samples = [s for s in samples if s["source"] is not None]
|
243 |
+
if len(samples) == 0:
|
244 |
+
return {}
|
245 |
+
|
246 |
+
fbanks = [s["target"] for s in samples]
|
247 |
+
fbank_sizes = [len(s) for s in fbanks]
|
248 |
+
|
249 |
+
collated_fbanks = _collate_frames(fbanks)
|
250 |
+
collated_fbanks_size = torch.tensor(fbank_sizes, dtype=torch.long)
|
251 |
+
|
252 |
+
# thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
|
253 |
+
if self.reduction_factor > 1:
|
254 |
+
collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor]
|
255 |
+
collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size])
|
256 |
+
else:
|
257 |
+
collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size
|
258 |
+
|
259 |
+
prev_output_tokens = torch.cat(
|
260 |
+
[collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1
|
261 |
+
)
|
262 |
+
|
263 |
+
# make labels for stop prediction
|
264 |
+
labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1))
|
265 |
+
for i, l in enumerate(fbank_sizes):
|
266 |
+
labels[i, l - 1 :] = 1.0
|
267 |
+
|
268 |
+
spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True)
|
269 |
+
|
270 |
+
sources_by_label = [
|
271 |
+
[s["source"][i] for s in samples] for i in range(self.num_labels)
|
272 |
+
]
|
273 |
+
sources_list, lengths_list, ntokens_list = self.collater_label(sources_by_label)
|
274 |
+
|
275 |
+
net_input = {
|
276 |
+
"src_tokens": sources_list[0],
|
277 |
+
"src_lengths": lengths_list[0],
|
278 |
+
"prev_output_tokens": prev_output_tokens,
|
279 |
+
"tgt_lengths": collated_fbanks_size_in,
|
280 |
+
"spkembs": spkembs,
|
281 |
+
"task_name": "t2s",
|
282 |
+
}
|
283 |
+
batch = {
|
284 |
+
"id": torch.LongTensor([s["id"] for s in samples]),
|
285 |
+
"name": [s["audio_name"] for s in samples],
|
286 |
+
"net_input": net_input,
|
287 |
+
"labels": labels,
|
288 |
+
"dec_target": collated_fbanks,
|
289 |
+
"dec_target_lengths": collated_fbanks_size,
|
290 |
+
"src_lengths": lengths_list[0],
|
291 |
+
"task_name": "t2s",
|
292 |
+
"ntokens": ntokens_list[0],
|
293 |
+
"target": collated_fbanks,
|
294 |
+
}
|
295 |
+
|
296 |
+
return batch
|
297 |
+
|
298 |
+
def collater_seq_label(self, targets, pad):
|
299 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
300 |
+
ntokens = lengths.sum().item()
|
301 |
+
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
|
302 |
+
return targets, lengths, ntokens
|
303 |
+
|
304 |
+
def collater_label(self, targets_by_label):
|
305 |
+
targets_list, lengths_list, ntokens_list = [], [], []
|
306 |
+
itr = zip(targets_by_label, [self.src_dict.pad()])
|
307 |
+
for targets, pad in itr:
|
308 |
+
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
|
309 |
+
targets_list.append(targets)
|
310 |
+
lengths_list.append(lengths)
|
311 |
+
ntokens_list.append(ntokens)
|
312 |
+
return targets_list, lengths_list, ntokens_list
|
313 |
+
|
314 |
+
def num_tokens(self, index):
|
315 |
+
return self.size(index)
|
316 |
+
|
317 |
+
def size(self, index):
|
318 |
+
return self.wav_sizes[index]
|
319 |
+
|
320 |
+
@property
|
321 |
+
def sizes(self):
|
322 |
+
return np.array(self.wav_sizes)
|
323 |
+
|
324 |
+
def ordered_indices(self):
|
325 |
+
if self.shuffle:
|
326 |
+
order = [np.random.permutation(len(self))]
|
327 |
+
else:
|
328 |
+
order = [np.arange(len(self))]
|
329 |
+
|
330 |
+
order.append(self.wav_sizes)
|
331 |
+
return np.lexsort(order)[::-1]
|
332 |
+
|
333 |
+
def postprocess(self, wav, cur_sample_rate):
|
334 |
+
if wav.dim() == 2:
|
335 |
+
wav = wav.mean(-1)
|
336 |
+
assert wav.dim() == 1, wav.dim()
|
337 |
+
|
338 |
+
if cur_sample_rate != self.sample_rate:
|
339 |
+
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
340 |
+
|
341 |
+
if self.normalize:
|
342 |
+
with torch.no_grad():
|
343 |
+
wav = F.layer_norm(wav, wav.shape)
|
344 |
+
return wav
|
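A standalone plain-PyTorch illustration of the reduction-factor thinning performed in the collater above: with r = 2 the decoder is fed every second frame (the last of each group) and target lengths are floor-divided. Shapes are invented for illustration.

# Thin out decoder-input frames by a reduction factor, as in the t2s collater.
import torch

r = 2
fbanks = torch.randn(3, 7, 80)                  # (batch, frames, mel bins)
lengths = torch.tensor([7, 5, 4])

fbanks_in = fbanks[:, r - 1 :: r]               # keeps frames 1, 3, 5 -> shape (3, 3, 80)
lengths_in = torch.div(lengths, r, rounding_mode="floor")  # tensor([3, 2, 2])
assert fbanks_in.size(1) == 7 // r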
artst/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .artst import * # noqa
from .t5_transformer_lm import * # noqa
artst/models/artst.py
ADDED
@@ -0,0 +1,1448 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import logging
|
9 |
+
from ast import literal_eval
|
10 |
+
from typing import Dict, List, Optional, Tuple
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from fairseq import utils
|
15 |
+
from fairseq.models import (
|
16 |
+
FairseqEncoderDecoderModel,
|
17 |
+
FairseqIncrementalDecoder,
|
18 |
+
register_model,
|
19 |
+
register_model_architecture,
|
20 |
+
)
|
21 |
+
from .modules.text_encoder_prenet import TextEncoderPrenet
|
22 |
+
from .modules.text_decoder_prenet import TextDecoderPrenet
|
23 |
+
from .modules.text_decoder_postnet import TextDecoderPostnet
|
24 |
+
from .modules.speech_encoder_prenet import SpeechEncoderPrenet
|
25 |
+
from .modules.speech_encoder_postnet import SpeechEncoderPostnet
|
26 |
+
from .modules.speech_decoder_prenet import SpeechDecoderPrenet
|
27 |
+
from .modules.speech_decoder_postnet import SpeechDecoderPostnet
|
28 |
+
from .modules.speaker_decoder_postnet import SpeakerDecoderPostnet
|
29 |
+
from .modules.encoder import TransformerEncoder
|
30 |
+
from .modules.decoder import TransformerDecoder
|
31 |
+
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
32 |
+
from fairseq.models.transformer import Embedding
|
33 |
+
from fairseq.modules import (
|
34 |
+
GumbelVectorQuantizer,
|
35 |
+
)
|
36 |
+
from torch import Tensor
|
37 |
+
|
38 |
+
|
39 |
+
logger = logging.getLogger(__name__)
|
40 |
+
|
41 |
+
DEFAULT_MAX_TEXT_POSITIONS = 450
|
42 |
+
DEFAULT_MAX_SPEECH_POSITIONS = 4000
|
43 |
+
|
44 |
+
|
45 |
+
@register_model("artst_transformer")
|
46 |
+
class ArTSTTransformerModel(FairseqEncoderDecoderModel):
|
47 |
+
"""Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
|
48 |
+
speech-to-text tasks. The Transformer encoder/decoder remains the same.
|
49 |
+
A trainable input subsampler is prepended to the Transformer encoder to
|
50 |
+
project inputs into the encoder dimension as well as downsample input
|
51 |
+
sequence for computational efficiency."""
|
52 |
+
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
args,
|
56 |
+
encoder, decoder,
|
57 |
+
text_encoder_prenet, speech_encoder_prenet,
|
58 |
+
text_decoder_prenet, speech_decoder_prenet,
|
59 |
+
text_decoder_postnet, speech_decoder_postnet,
|
60 |
+
speaker_decoder_postnet, speech_encoder_postnet,
|
61 |
+
):
|
62 |
+
super().__init__(encoder, decoder)
|
63 |
+
|
64 |
+
self.encoder = encoder
|
65 |
+
self.decoder = decoder
|
66 |
+
|
67 |
+
self.text_encoder_prenet = text_encoder_prenet
|
68 |
+
self.speech_encoder_prenet = speech_encoder_prenet
|
69 |
+
|
70 |
+
self.text_decoder_prenet = text_decoder_prenet
|
71 |
+
self.speech_decoder_prenet = speech_decoder_prenet
|
72 |
+
|
73 |
+
self.text_decoder_postnet = text_decoder_postnet
|
74 |
+
self.speech_decoder_postnet = speech_decoder_postnet
|
75 |
+
self.speaker_decoder_postnet = speaker_decoder_postnet
|
76 |
+
|
77 |
+
self.hubert_layer = speech_encoder_postnet
|
78 |
+
|
79 |
+
self.reduction_factor = args.reduction_factor
|
80 |
+
self.spk_embed_dim = args.spk_embed_dim
|
81 |
+
|
82 |
+
# define projection layer
|
83 |
+
self.spk_embed_integration_type = args.spk_embed_integration_type
|
84 |
+
if self.spk_embed_dim is not None and self.spk_embed_integration_type != 'pre':
|
85 |
+
if self.spk_embed_integration_type == "add":
|
86 |
+
self.projection = torch.nn.Linear(self.spk_embed_dim, args.decoder_embed_dim)
|
87 |
+
else:
|
88 |
+
self.projection = torch.nn.Linear(
|
89 |
+
args.decoder_embed_dim + self.spk_embed_dim, args.decoder_embed_dim
|
90 |
+
)
|
91 |
+
|
92 |
+
# Hawau: here we can add language embedding integration
|
93 |
+
|
94 |
+
self.use_codebook = args.use_codebook
|
95 |
+
self.codebook_prob = getattr(args, "codebook_prob", 0.5) # args.codebook_prob
|
96 |
+
if self.use_codebook:
|
97 |
+
vq_dim = args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
|
98 |
+
self.quantizer = GumbelVectorQuantizer(
|
99 |
+
dim=args.encoder_embed_dim,
|
100 |
+
num_vars=args.latent_vars,
|
101 |
+
temp=args.latent_temp,
|
102 |
+
groups=args.latent_groups,
|
103 |
+
combine_groups=False,
|
104 |
+
vq_dim=vq_dim,
|
105 |
+
time_first=True,
|
106 |
+
weight_proj_depth=args.quantizer_depth,
|
107 |
+
weight_proj_factor=args.quantizer_factor,
|
108 |
+
)
|
109 |
+
|
110 |
+
self.num_updates = 0
|
111 |
+
|
112 |
+
# # Follow BERT's random weight initialization (for BART)
|
113 |
+
if args.bert_init:
|
114 |
+
self.apply(init_bert_params)
|
115 |
+
self.args = args
|
116 |
+
self.prune_modules(args.modules_filter)
|
117 |
+
|
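As an aside, a standalone plain-PyTorch sketch of the two speaker-embedding integration branches configured above ("add" vs. concatenation followed by projection). The dimensions are invented, and this is not the model's actual forward code.

# Integrate a per-utterance speaker embedding into decoder states (illustrative only).
import torch

B, T, D, S = 2, 5, 768, 512        # batch, frames, decoder dim, speaker-embedding dim
hs = torch.randn(B, T, D)          # hypothetical decoder prenet states
spk = torch.randn(B, S)            # one speaker embedding per utterance

# "add": project the speaker embedding to the decoder dimension, then add it.
proj_add = torch.nn.Linear(S, D)
out_add = hs + proj_add(spk).unsqueeze(1)

# concat variant: tile over time, concatenate, project back down to the decoder dimension.
proj_cat = torch.nn.Linear(D + S, D)
out_cat = proj_cat(torch.cat([hs, spk.unsqueeze(1).expand(-1, T, -1)], dim=-1))
assert out_add.shape == out_cat.shape == (B, T, D)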
118 |
+
@staticmethod
|
119 |
+
def add_args(parser):
|
120 |
+
"""Add model-specific arguments to the parser."""
|
121 |
+
# Transformer
|
122 |
+
parser.add_argument(
|
123 |
+
"--activation-fn",
|
124 |
+
type=str,
|
125 |
+
choices=utils.get_available_activation_fns(),
|
126 |
+
help="activation function to use",
|
127 |
+
)
|
128 |
+
parser.add_argument(
|
129 |
+
"--dropout", type=float, metavar="D", help="dropout probability"
|
130 |
+
)
|
131 |
+
parser.add_argument(
|
132 |
+
"--attention-dropout",
|
133 |
+
type=float,
|
134 |
+
metavar="D",
|
135 |
+
help="dropout probability for attention weights",
|
136 |
+
)
|
137 |
+
parser.add_argument(
|
138 |
+
"--activation-dropout",
|
139 |
+
"--relu-dropout",
|
140 |
+
type=float,
|
141 |
+
metavar="D",
|
142 |
+
help="dropout probability after activation in FFN.",
|
143 |
+
)
|
144 |
+
parser.add_argument(
|
145 |
+
"--encoder-embed-dim",
|
146 |
+
type=int,
|
147 |
+
metavar="N",
|
148 |
+
help="encoder embedding dimension",
|
149 |
+
)
|
150 |
+
parser.add_argument(
|
151 |
+
"--encoder-ffn-embed-dim",
|
152 |
+
type=int,
|
153 |
+
metavar="N",
|
154 |
+
help="encoder embedding dimension for FFN",
|
155 |
+
)
|
156 |
+
parser.add_argument(
|
157 |
+
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
|
158 |
+
)
|
159 |
+
parser.add_argument(
|
160 |
+
"--encoder-attention-heads",
|
161 |
+
type=int,
|
162 |
+
metavar="N",
|
163 |
+
help="num encoder attention heads",
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--encoder-normalize-before",
|
167 |
+
action="store_true",
|
168 |
+
help="apply layernorm before each encoder block",
|
169 |
+
)
|
170 |
+
parser.add_argument(
|
171 |
+
"--decoder-normalize-before",
|
172 |
+
action="store_true",
|
173 |
+
help="apply layernorm before each decoder block",
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
"--decoder-embed-dim",
|
177 |
+
type=int,
|
178 |
+
metavar="N",
|
179 |
+
help="decoder embedding dimension",
|
180 |
+
)
|
181 |
+
parser.add_argument(
|
182 |
+
"--decoder-ffn-embed-dim",
|
183 |
+
type=int,
|
184 |
+
metavar="N",
|
185 |
+
help="decoder embedding dimension for FFN",
|
186 |
+
)
|
187 |
+
parser.add_argument(
|
188 |
+
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
|
189 |
+
)
|
190 |
+
parser.add_argument(
|
191 |
+
"--decoder-attention-heads",
|
192 |
+
type=int,
|
193 |
+
metavar="N",
|
194 |
+
help="num decoder attention heads",
|
195 |
+
)
|
196 |
+
parser.add_argument(
|
197 |
+
"--reduction-factor",
|
198 |
+
type=int,
|
199 |
+
help="reduction factor for decoder",
|
200 |
+
)
|
201 |
+
parser.add_argument(
|
202 |
+
"--spk-embed-dim",
|
203 |
+
type=int,
|
204 |
+
help="speaker embedding dimension",
|
205 |
+
)
|
206 |
+
parser.add_argument(
|
207 |
+
"--layernorm-embedding",
|
208 |
+
action="store_true",
|
209 |
+
help="add layernorm to embedding",
|
210 |
+
)
|
211 |
+
parser.add_argument(
|
212 |
+
"--load-pretrained-encoder-from",
|
213 |
+
type=str,
|
214 |
+
metavar="STR",
|
215 |
+
help="model to take encoder weights from (for initialization)",
|
216 |
+
)
|
217 |
+
parser.add_argument(
|
218 |
+
'--freeze-encoder-updates',
|
219 |
+
type=int,
|
220 |
+
help='number of steps to freeze encoder before finetune'
|
221 |
+
)
|
222 |
+
parser.add_argument(
|
223 |
+
'--freeze-decoder-updates',
|
224 |
+
type=int,
|
225 |
+
help='number of steps to freeze decoder before finetune'
|
226 |
+
)
|
227 |
+
parser.add_argument(
|
228 |
+
'--no-freeze-encoder-layer',
|
229 |
+
type=str,
|
230 |
+
help='which encoder layers not to freeze during finetuning'
|
231 |
+
)
|
232 |
+
parser.add_argument(
|
233 |
+
"--share-input-output-embed",
|
234 |
+
action="store_true",
|
235 |
+
help="share decoder input and output embeddings",
|
236 |
+
)
|
237 |
+
parser.add_argument(
|
238 |
+
"--share-ctc-embed",
|
239 |
+
action="store_true",
|
240 |
+
help="share ctc embed and decoder embed",
|
241 |
+
)
|
242 |
+
parser.add_argument(
|
243 |
+
"--encoder-sliding-window-attn",
|
244 |
+
default=None,
|
245 |
+
type=int,
|
246 |
+
help="If not None but a even number, set sliding window attention to encoder's attn_mask, e.g., 4, 10, and 20",
|
247 |
+
)
|
248 |
+
|
249 |
+
# Convolutional subsampler
|
250 |
+
parser.add_argument(
|
251 |
+
"--encoder-speech-prenet",
|
252 |
+
default="conv",
|
253 |
+
type=str,
|
254 |
+
choices=["conv", "linear"],
|
255 |
+
help="The type of encoder speech prenet, e.g., conv or linear."
|
256 |
+
)
|
257 |
+
parser.add_argument(
|
258 |
+
"--conv-kernel-sizes",
|
259 |
+
default="5,5",
|
260 |
+
type=str,
|
261 |
+
help="The layer of convolution of encoder speech prenet."
|
262 |
+
)
|
263 |
+
parser.add_argument(
|
264 |
+
"--conv-channels",
|
265 |
+
default=1024,
|
266 |
+
type=int,
|
267 |
+
help="The channels of encoder speech prenet."
|
268 |
+
)
|
269 |
+
parser.add_argument(
|
270 |
+
"--subsample-stride",
|
271 |
+
default="2,2",
|
272 |
+
type=str,
|
273 |
+
help="The subsample stride for conv1dsubsample."
|
274 |
+
)
|
275 |
+
parser.add_argument(
|
276 |
+
"--spk-embed-integration-type",
|
277 |
+
type=str,
|
278 |
+
choices=["pre", "add"],
|
279 |
+
help="speaker embedding integration type"
|
280 |
+
)
|
281 |
+
parser.add_argument(
|
282 |
+
"--dprenet-dropout-rate",
|
283 |
+
default=0.5,
|
284 |
+
type=float,
|
285 |
+
help="The dropout rate of decoder speech prenet."
|
286 |
+
)
|
287 |
+
|
288 |
+
## SE
|
289 |
+
parser.add_argument(
|
290 |
+
"--se-predict",
|
291 |
+
default=None,
|
292 |
+
choices=["masking", "target", "delta"],
|
293 |
+
help="If set, source speech inputs decoder to predict the masking/target/delta of corresponding inputs."
|
294 |
+
+ "masking is [0, 1], target is predicted output, delta is difference between inputs and outputs",
|
295 |
+
)
|
296 |
+
parser.add_argument(
|
297 |
+
"--se-decoder-input",
|
298 |
+
type=str,
|
299 |
+
default="previous_target",
|
300 |
+
choices=["previous_target", "source"],
|
301 |
+
)
|
302 |
+
|
303 |
+
## SID
|
304 |
+
parser.add_argument(
|
305 |
+
"--modules-filter",
|
306 |
+
default=None,
|
307 |
+
type=str,
|
308 |
+
help="Remove unused modules for, e.g., SID.",
|
309 |
+
)
|
310 |
+
parser.add_argument(
|
311 |
+
"--sid-pad-prenet",
|
312 |
+
action="store_true",
|
313 |
+
help="If set, the size of text dictionary is as small as for <pad> token.",
|
314 |
+
)
|
315 |
+
parser.add_argument(
|
316 |
+
"--encoder-attn-branch",
|
317 |
+
type=str,
|
318 |
+
default="identity,full",
|
319 |
+
help="encoder attention branch sliding window, e.g., 'identity,0,2,4,full'",
|
320 |
+
)
|
321 |
+
parser.add_argument(
|
322 |
+
"--encoder-block-branch",
|
323 |
+
type=str,
|
324 |
+
help="average the output of encoder, e.g., '4,5,6'",
|
325 |
+
)
|
326 |
+
parser.add_argument(
|
327 |
+
"--sid-encoder-cls",
|
328 |
+
default=None,
|
329 |
+
choices=["encoder"],
|
330 |
+
help="If set, add cls vector to the encoder input, e.g., constant vector.",
|
331 |
+
)
|
332 |
+
parser.add_argument(
|
333 |
+
"--sid-shuffle-encoder-input",
|
334 |
+
action="store_true",
|
335 |
+
help="If set, shuffle encoder input in time.",
|
336 |
+
)
|
337 |
+
parser.add_argument(
|
338 |
+
"--sid-decoder-speaker",
|
339 |
+
action="store_true",
|
340 |
+
help="If set, apply speaker decoder as transformer decoder.",
|
341 |
+
)
|
342 |
+
parser.add_argument(
|
343 |
+
"--sid-decoder-attn-dim",
|
344 |
+
default=128,
|
345 |
+
type=int,
|
346 |
+
help="Attention dimension in attensive statistics pooling of speaker decoder.",
|
347 |
+
)
|
348 |
+
parser.add_argument(
|
349 |
+
"--sid-t5-postnet",
|
350 |
+
action="store_true",
|
351 |
+
help="If set, apply TextDecoderPostnet as speaker classification.",
|
352 |
+
)
|
353 |
+
parser.add_argument(
|
354 |
+
"--sid-embed-dim",
|
355 |
+
default=128,
|
356 |
+
type=int,
|
357 |
+
help="Embedding dimension in speaker postnet for speaker identification if embed postnet.",
|
358 |
+
)
|
359 |
+
parser.add_argument(
|
360 |
+
"--sid-pooling-layer",
|
361 |
+
default="decoder",
|
362 |
+
type=str,
|
363 |
+
choices=["decoder-las", "decoder", "encoder", "encoder-cls", "encoder-speaker"],
|
364 |
+
help="The output of decoder or encoder uses as SID pooling layer over temporal dimension.",
|
365 |
+
)
|
366 |
+
parser.add_argument(
|
367 |
+
"--sid-no-pooling-bn",
|
368 |
+
action="store_true",
|
369 |
+
help="If set, not attention batchnorm.",
|
370 |
+
)
|
371 |
+
parser.add_argument(
|
372 |
+
"--sid-no-embed-postnet",
|
373 |
+
action="store_true",
|
374 |
+
help="If set, no layer between decoder output and classification layer.",
|
375 |
+
)
|
376 |
+
parser.add_argument(
|
377 |
+
"--sid-normalize-postnet",
|
378 |
+
action="store_true",
|
379 |
+
help="If set, normalize input and weight in postnet/classifier.",
|
380 |
+
)
|
381 |
+
parser.add_argument(
|
382 |
+
"--sid-softmax-type",
|
383 |
+
default="softmax",
|
384 |
+
choices=["softmax", "amsoftmax", "aamsoftmax"],
|
385 |
+
help="If using amsoftmax or aamsoftmax, the target should be given.",
|
386 |
+
)
|
387 |
+
parser.add_argument(
|
388 |
+
"--softmax-scale",
|
389 |
+
default=1.0,
|
390 |
+
type=float,
|
391 |
+
help="Scale for AMSoftmax or AAMSoftmax.",
|
392 |
+
)
|
393 |
+
parser.add_argument(
|
394 |
+
"--softmax-margin",
|
395 |
+
default=0.0,
|
396 |
+
type=float,
|
397 |
+
help="Margin for AMSoftmax or AAMSoftmax.",
|
398 |
+
)
|
399 |
+
parser.add_argument(
|
400 |
+
"--softmax-easy-margin",
|
401 |
+
action="store_true",
|
402 |
+
help="Enable easy margin for AAMSoftmax.",
|
403 |
+
)
|
404 |
+
parser.add_argument(
|
405 |
+
"--encoder-layerdrop",
|
406 |
+
type=float,
|
407 |
+
metavar="D",
|
408 |
+
help="LayerDrop probability for encoder",
|
409 |
+
)
|
410 |
+
parser.add_argument(
|
411 |
+
"--decoder-layerdrop",
|
412 |
+
type=float,
|
413 |
+
metavar="D",
|
414 |
+
help="LayerDrop probability for decoder",
|
415 |
+
)
|
416 |
+
|
417 |
+
## Hubert
|
418 |
+
parser.add_argument(
|
419 |
+
'--feature-grad-mult',
|
420 |
+
type=float,
|
421 |
+
help='multiply feature extractor var grads by this'
|
422 |
+
)
|
423 |
+
parser.add_argument(
|
424 |
+
'--logit-temp',
|
425 |
+
type=float,
|
426 |
+
help='temperature to divide logits by'
|
427 |
+
)
|
428 |
+
parser.add_argument(
|
429 |
+
'--final-dim',
|
430 |
+
type=int,
|
431 |
+
help="project final representations and targets to this many "
|
432 |
+
"dimensions. set to encoder_embed_dim is <= 0"
|
433 |
+
)
|
434 |
+
|
435 |
+
# mask
|
436 |
+
parser.add_argument(
|
437 |
+
'--hubert-mask-length',
|
438 |
+
type=int,
|
439 |
+
help='mask length'
|
440 |
+
)
|
441 |
+
parser.add_argument(
|
442 |
+
'--mask-prob',
|
443 |
+
type=float,
|
444 |
+
help='probability of replacing a token with mask'
|
445 |
+
)
|
446 |
+
parser.add_argument(
|
447 |
+
"--mask-selection",
|
448 |
+
choices=["static", "uniform", "normal", "poisson"],
|
449 |
+
help="how to choose mask length",
|
450 |
+
)
|
451 |
+
parser.add_argument(
|
452 |
+
'--mask-other',
|
453 |
+
type=float,
|
454 |
+
help="secondary mask argument "
|
455 |
+
"(used for more complex distributions), "
|
456 |
+
"see help in compute_mask_indices"
|
457 |
+
)
|
458 |
+
parser.add_argument(
|
459 |
+
'--mask-min-space',
|
460 |
+
type=int,
|
461 |
+
help='min space between spans (if no overlap is enabled)'
|
462 |
+
)
|
463 |
+
|
464 |
+
# channel masking
|
465 |
+
parser.add_argument(
|
466 |
+
'--mask-channel-length',
|
467 |
+
type=int,
|
468 |
+
help='length of the mask for features (channels)'
|
469 |
+
)
|
470 |
+
parser.add_argument(
|
471 |
+
'--mask-channel-prob',
|
472 |
+
type=float,
|
473 |
+
help="probability of replacing a feature with 0"
|
474 |
+
)
|
475 |
+
parser.add_argument(
|
476 |
+
"--mask-channel-selection",
|
477 |
+
choices=["static", "uniform", "normal", "poisson"],
|
478 |
+
help="how to choose mask length for channel masking",
|
479 |
+
)
|
480 |
+
parser.add_argument(
|
481 |
+
'--mask-channel-other',
|
482 |
+
type=float,
|
483 |
+
help="secondary mask argument "
|
484 |
+
"(used for more complex distributions), "
|
485 |
+
"see help in compute_mask_indices"
|
486 |
+
)
|
487 |
+
parser.add_argument(
|
488 |
+
'--mask-channel-min-space',
|
489 |
+
type=int,
|
490 |
+
help='min space between spans (if no overlap is enabled)'
|
491 |
+
)
|
492 |
+
|
493 |
+
# abs positional embeddings
|
494 |
+
parser.add_argument(
|
495 |
+
'--conv-pos',
|
496 |
+
type=int,
|
497 |
+
help='number of filters for convolutional positional embeddings'
|
498 |
+
)
|
499 |
+
parser.add_argument(
|
500 |
+
'--conv-pos-groups',
|
501 |
+
type=int,
|
502 |
+
help='number of groups for convolutional positional embedding'
|
503 |
+
)
|
504 |
+
|
505 |
+
# codebook related
|
506 |
+
parser.add_argument(
|
507 |
+
"--use-codebook",
|
508 |
+
action="store_true",
|
509 |
+
help="whether to use codebook",
|
510 |
+
)
|
511 |
+
parser.add_argument(
|
512 |
+
"--codebook-prob",
|
513 |
+
type=float,
|
514 |
+
help="probability to use codebook",
|
515 |
+
)
|
516 |
+
parser.add_argument(
|
517 |
+
"--latent-vars",
|
518 |
+
type=int,
|
519 |
+
help="number of latent variables V in each group of the codebook",
|
520 |
+
)
|
521 |
+
parser.add_argument(
|
522 |
+
"--latent-groups",
|
523 |
+
type=int,
|
524 |
+
help="number of groups G of latent variables in the codebook",
|
525 |
+
)
|
526 |
+
parser.add_argument(
|
527 |
+
"--latent-dim",
|
528 |
+
type=int,
|
529 |
+
help="if > 0, uses this dimensionality for latent variables. "
|
530 |
+
"otherwise uses final_dim / latent_groups",
|
531 |
+
)
|
532 |
+
parser.add_argument(
|
533 |
+
"--latent-temp",
|
534 |
+
type=literal_eval,
|
535 |
+
help="temperature for latent variable sampling. "
|
536 |
+
"can be tuple of 3 values (start, end, decay)",
|
537 |
+
)
|
538 |
+
parser.add_argument(
|
539 |
+
"--quantizer-depth",
|
540 |
+
type=int,
|
541 |
+
help="number of quantizer layers",
|
542 |
+
)
|
543 |
+
parser.add_argument(
|
544 |
+
"--quantizer-factor",
|
545 |
+
type=int,
|
546 |
+
help="number of quantizer layers",
|
547 |
+
)
|
548 |
+
parser.add_argument(
|
549 |
+
"--get-code-distribution",
|
550 |
+
action='store_true',
|
551 |
+
help="whether to get the code distribution (for test)",
|
552 |
+
)
|
553 |
+
|
554 |
+
# relative pos enc
|
555 |
+
parser.add_argument(
|
556 |
+
"--relative-position-embedding",
|
557 |
+
action='store_true',
|
558 |
+
help="whether to use relative position embedding",
|
559 |
+
)
|
560 |
+
parser.add_argument(
|
561 |
+
"--num-buckets",
|
562 |
+
type=int,
|
563 |
+
default=320,
|
564 |
+
help="num of buckets for relative position embedding",
|
565 |
+
)
|
566 |
+
parser.add_argument(
|
567 |
+
"--max-distance",
|
568 |
+
type=int,
|
569 |
+
default=1280,
|
570 |
+
help="max distance for relative position embedding",
|
571 |
+
)
|
572 |
+
parser.add_argument(
|
573 |
+
"--encoder-max-relative-position",
|
574 |
+
type=int,
|
575 |
+
help="max distance for relative position embedding in encoder",
|
576 |
+
)
|
577 |
+
parser.add_argument(
|
578 |
+
"--decoder-max-relative-position",
|
579 |
+
type=int,
|
580 |
+
help="max distance for relative position embedding in decoder",
|
581 |
+
)
|
582 |
+
|
583 |
+
# hubert feature extractor
|
584 |
+
parser.add_argument(
|
585 |
+
"--conv-feature-layers",
|
586 |
+
type=str,
|
587 |
+
help= "string describing convolutional feature extraction "
|
588 |
+
"layers in form of a python list that contains "
|
589 |
+
"[(dim, kernel_size, stride), ...]",
|
590 |
+
)
|
591 |
+
parser.add_argument(
|
592 |
+
"--conv-bias",
|
593 |
+
action='store_true',
|
594 |
+
help="include bias in conv encoder",
|
595 |
+
)
|
596 |
+
parser.add_argument(
|
597 |
+
"--extractor-mode",
|
598 |
+
choices=["default", "layer_norm"],
|
599 |
+
help="mode for feature extractor. default has a single group "
|
600 |
+
"norm with d groups in the first conv block, whereas layer_norm "
|
601 |
+
"has layer norms in every block (meant to use with normalize=True)"
|
602 |
+
)
|
603 |
+
|
604 |
+
# others
|
605 |
+
parser.add_argument(
|
606 |
+
"--bert-init",
|
607 |
+
action='store_true',
|
608 |
+
help="initilize as bert",
|
609 |
+
)
|
610 |
+
parser.add_argument(
|
611 |
+
"--unb-enc-layer",
|
612 |
+
type=int,
|
613 |
+
default=-1,
|
614 |
+
help="which layer's output is used as the input of decoder",
|
615 |
+
)
|
616 |
+
|
617 |
+
# Encoder, Decoder
|
618 |
+
@classmethod
|
619 |
+
def build_encoder(cls, args, dictionary=None, embed_tokens=None):
|
620 |
+
return TransformerEncoder(args, dictionary, embed_tokens)
|
621 |
+
|
622 |
+
@classmethod
|
623 |
+
def build_decoder(cls, args):
|
624 |
+
return TransformerDecoder(args)
|
625 |
+
|
626 |
+
# Encoder Prenet
|
627 |
+
@classmethod
|
628 |
+
def build_text_encoder_prenet(cls, embed_tokens, args):
|
629 |
+
return TextEncoderPrenet(embed_tokens, args)
|
630 |
+
|
631 |
+
@classmethod
|
632 |
+
def build_speech_encoder_prenet(cls, args):
|
633 |
+
return SpeechEncoderPrenet(args)
|
634 |
+
|
635 |
+
# Decoder Prenet
|
636 |
+
@classmethod
|
637 |
+
def build_text_decoder_prenet(cls, embed_tokens, args):
|
638 |
+
return TextDecoderPrenet(embed_tokens, args)
|
639 |
+
|
640 |
+
@classmethod
|
641 |
+
def build_speech_decoder_prenet(cls, odim, args):
|
642 |
+
return SpeechDecoderPrenet(odim, args)
|
643 |
+
|
644 |
+
# Decoder Postnet
|
645 |
+
@classmethod
|
646 |
+
def build_text_decoder_postnet(cls, embed_tokens, dictionary, args):
|
647 |
+
return TextDecoderPostnet(embed_tokens, dictionary, args)
|
648 |
+
|
649 |
+
@classmethod
|
650 |
+
def build_speaker_decoder_postnet(cls, embed_dim, class_num, args):
|
651 |
+
return SpeakerDecoderPostnet(embed_dim, class_num, args)
|
652 |
+
|
653 |
+
@classmethod
|
654 |
+
def build_speech_decoder_postnet(cls, odim, args):
|
655 |
+
return SpeechDecoderPostnet(odim, args)
|
656 |
+
|
657 |
+
@classmethod
|
658 |
+
def build_speech_encoder_postnet(cls, dictionaries, args):
|
659 |
+
return SpeechEncoderPostnet(dictionaries, args)
|
660 |
+
|
661 |
+
@classmethod
|
662 |
+
def build_model(cls, args, task):
|
663 |
+
"""Build a new model instance."""
|
664 |
+
|
665 |
+
# make sure all arguments are present in older models
|
666 |
+
base_architecture(args)
|
667 |
+
|
668 |
+
def build_embedding(dictionary, embed_dim, max_num_embeddings=None):
|
669 |
+
num_embeddings = len(dictionary)
|
670 |
+
if max_num_embeddings is not None and isinstance(max_num_embeddings, int):
|
671 |
+
num_embeddings = min(num_embeddings, max_num_embeddings)
|
672 |
+
padding_idx = dictionary.pad()
|
673 |
+
return Embedding(num_embeddings, embed_dim, padding_idx)
|
674 |
+
|
675 |
+
if hasattr(args, "sid_pad_prenet") and args.sid_pad_prenet:
|
676 |
+
max_num_embeddings = 3 # <pad> at index 2
|
677 |
+
else:
|
678 |
+
max_num_embeddings = None
|
679 |
+
|
680 |
+
text_decoder_embed_tokens = build_embedding(
|
681 |
+
task.dicts["text"], args.decoder_embed_dim, max_num_embeddings
|
682 |
+
)
|
683 |
+
|
684 |
+
if args.share_input_output_embed:
|
685 |
+
text_encoder_embed_tokens = text_decoder_embed_tokens
|
686 |
+
else:
|
687 |
+
text_encoder_embed_tokens = build_embedding(
|
688 |
+
task.dicts["text"], args.encoder_embed_dim
|
689 |
+
)
|
690 |
+
|
691 |
+
speech_odim = args.speech_odim
|
692 |
+
if "text" in task.dicts:
|
693 |
+
encoder = cls.build_encoder(args, task.dicts["text"], text_encoder_embed_tokens)
|
694 |
+
else:
|
695 |
+
encoder = cls.build_encoder(args)
|
696 |
+
decoder = cls.build_decoder(args)
|
697 |
+
|
698 |
+
text_encoder_prenet = cls.build_text_encoder_prenet(text_encoder_embed_tokens, args)
|
699 |
+
speech_encoder_prenet = cls.build_speech_encoder_prenet(args)
|
700 |
+
|
701 |
+
text_decoder_prenet = cls.build_text_decoder_prenet(text_decoder_embed_tokens, args)
|
702 |
+
if getattr(args, "sid_pooling_layer", None) == "decoder-las":
|
703 |
+
speech_decoder_prenet = cls.build_speech_encoder_prenet(args)
|
704 |
+
else:
|
705 |
+
speech_decoder_prenet = cls.build_speech_decoder_prenet(speech_odim, args)
|
706 |
+
|
707 |
+
text_decoder_postnet = cls.build_text_decoder_postnet(text_decoder_embed_tokens, task.dicts['text'], args)
|
708 |
+
speech_decoder_postnet = cls.build_speech_decoder_postnet(speech_odim, args)
|
709 |
+
|
710 |
+
if getattr(args, "sid_t5_postnet", False):
|
711 |
+
speaker_decoder_postnet = None
|
712 |
+
else:
|
713 |
+
if task.t5_task == "s2c":
|
714 |
+
speaker_decoder_postnet = cls.build_speaker_decoder_postnet(args.sid_embed_dim, len(task.dicts['text']), args)
|
715 |
+
else:
|
716 |
+
speaker_decoder_postnet = None
|
717 |
+
|
718 |
+
if "hubert" in task.dicts:
|
719 |
+
speech_encoder_postnet = cls.build_speech_encoder_postnet(task.dicts['hubert'], args)
|
720 |
+
else:
|
721 |
+
speech_encoder_postnet = None
|
722 |
+
|
723 |
+
return cls(
|
724 |
+
args,
|
725 |
+
encoder, decoder,
|
726 |
+
text_encoder_prenet, speech_encoder_prenet,
|
727 |
+
text_decoder_prenet, speech_decoder_prenet,
|
728 |
+
text_decoder_postnet, speech_decoder_postnet,
|
729 |
+
speaker_decoder_postnet, speech_encoder_postnet,
|
730 |
+
)
|
731 |
+
|
732 |
+
def get_normalized_probs(
|
733 |
+
self,
|
734 |
+
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
|
735 |
+
log_probs: bool,
|
736 |
+
sample: Optional[Dict[str, Tensor]] = None,
|
737 |
+
):
|
738 |
+
# net_output['encoder_out'] is a (B, T, D) tensor
|
739 |
+
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
|
740 |
+
lprobs.batch_first = True
|
741 |
+
return lprobs
|
742 |
+
|
743 |
+
def get_normalized_probs_for_ctc(self, net_output, log_probs):
|
744 |
+
"""Get normalized probabilities (or log probs) from a net's output."""
|
745 |
+
|
746 |
+
logits = net_output["encoder_out_for_ctc"][0]
|
747 |
+
if log_probs:
|
748 |
+
return utils.log_softmax(logits.float(), dim=-1)
|
749 |
+
else:
|
750 |
+
return utils.softmax(logits.float(), dim=-1)
|
751 |
+
|
752 |
+
def get_logits(self, net_output, is_masked=True):
|
753 |
+
if is_masked:
|
754 |
+
logits_list = net_output["logit_m_list"]
|
755 |
+
else:
|
756 |
+
logits_list = net_output["logit_u_list"]
|
757 |
+
logits_list = [x.float() for x in logits_list if x is not None]
|
758 |
+
return logits_list
|
759 |
+
|
760 |
+
def get_targets(self, sample, net_output, is_masked=True):
|
761 |
+
if "logit_m_list" in net_output:
|
762 |
+
logits_list = self.get_logits(net_output, is_masked)
|
763 |
+
targets_list = [
|
764 |
+
x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list
|
765 |
+
]
|
766 |
+
return targets_list
|
767 |
+
else:
|
768 |
+
return sample["target"]
|
769 |
+
|
770 |
+
def get_extra_losses(self, net_output):
|
771 |
+
extra_losses = []
|
772 |
+
names = []
|
773 |
+
|
774 |
+
if "features_pen" in net_output:
|
775 |
+
extra_losses.append(net_output["features_pen"])
|
776 |
+
names.append("features_pen")
|
777 |
+
|
778 |
+
if "prob_perplexity" in net_output:
|
779 |
+
extra_losses.append(
|
780 |
+
(net_output["num_vars"] - net_output["prob_perplexity"])
|
781 |
+
/ net_output["num_vars"]
|
782 |
+
)
|
783 |
+
names.append("prob_perplexity")
|
784 |
+
|
785 |
+
return extra_losses, names
|
786 |
+
|
787 |
+
def forward(self, source=None, src_tokens=None, src_lengths=None, prev_output_tokens=None, tgt_lengths=None, spkembs=None, target_list=None, task_name=None, padding_mask=None, only_hubert=False, only_ctc=False, feature_only=False, tgt_enc_layer=None, mask=True):
|
788 |
+
"""
|
789 |
+
The forward method inherited from the base class has a **kwargs
|
790 |
+
argument in its input, which is not supported in torchscript. This
|
791 |
+
method overwrites the forward method definition without **kwargs.
|
792 |
+
"""
|
793 |
+
assert source is not None or src_tokens is not None
|
794 |
+
# padding_mask is not None only when the input is a waveform
|
795 |
+
if source is None and padding_mask is None and not feature_only:
|
796 |
+
input_type = 'text'
|
797 |
+
else:
|
798 |
+
input_type = 'speech'
|
799 |
+
|
800 |
+
if prev_output_tokens is not None and len(prev_output_tokens.size()) == 2:
|
801 |
+
output_type = 'text'
|
802 |
+
codebook_out = {}
|
803 |
+
else:
|
804 |
+
output_type = 'speech'
|
805 |
+
|
806 |
+
if task_name is not None and task_name == "s2c":
|
807 |
+
if target_list is not None and target_list.size(1) == 1 and not getattr(self.args, "sid_t5_postnet", False):
|
808 |
+
sid_target = F.one_hot(target_list.squeeze(1), num_classes=self.speaker_decoder_postnet.class_num)
|
809 |
+
else:
|
810 |
+
sid_target = None
|
811 |
+
target_list = None
|
812 |
+
|
813 |
+
# Encoder Prenet
|
814 |
+
if input_type == 'text':
|
815 |
+
encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens)
|
816 |
+
else:
|
817 |
+
if target_list is not None:
|
818 |
+
encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, require_feat_pen=True, target_list=target_list, padding_mask=padding_mask, mask=mask)
|
819 |
+
encoder_input, features_pen, mask_indices, target_list = encoder_input
|
820 |
+
else:
|
821 |
+
encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=self.training)
|
822 |
+
# shuffle a batch of inputs of encoder
|
823 |
+
if self.training and hasattr(self.args, "sid_shuffle_encoder_input") and getattr(self.args, "sid_shuffle_encoder_input", False):
|
824 |
+
shuffle_index = torch.randperm(encoder_padding_mask.size(1), device=encoder_padding_mask.device)
|
825 |
+
encoder_input = torch.index_select(encoder_input, 1, shuffle_index)
|
826 |
+
encoder_padding_mask = torch.index_select(encoder_padding_mask, 1, shuffle_index)
|
827 |
+
if getattr(self.args, "sid_encoder_cls", None) == "encoder":
|
828 |
+
prev_output_tokens = torch.zeros_like(prev_output_tokens)
|
829 |
+
encoder_input, encoder_padding_mask = self._integrate_with_speaker_cls(prev_output_tokens, encoder_input, encoder_padding_mask)
|
830 |
+
|
831 |
+
# Encoder: T x B x C
|
832 |
+
encoder_output = self.encoder(encoder_input, encoder_padding_mask, tgt_layer=tgt_enc_layer)
|
833 |
+
|
834 |
+
if task_name is not None and task_name == 'speech_pretrain' and feature_only:
|
835 |
+
return encoder_output["encoder_out"][0].transpose(0, 1)
|
836 |
+
|
837 |
+
if task_name is not None and task_name == 's2c':
|
838 |
+
if self.args.sid_pooling_layer == "encoder":
|
839 |
+
return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1).mean(1), sid_target), None
|
840 |
+
elif self.args.sid_pooling_layer == "encoder-cls":
|
841 |
+
return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1)[:,0], sid_target), None
|
842 |
+
elif self.args.sid_pooling_layer == "encoder-speaker" or getattr(self.args, "sid_decoder_speaker", False):
|
843 |
+
return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1), sid_target), None
|
844 |
+
|
845 |
+
if target_list is not None:
|
846 |
+
hubert_results = self.hubert_layer(
|
847 |
+
encoder_output["encoder_out"][0].transpose(0, 1),
|
848 |
+
encoder_padding_mask,
|
849 |
+
mask_indices,
|
850 |
+
target_list
|
851 |
+
)
|
852 |
+
|
853 |
+
hubert_results['features_pen'] = features_pen
|
854 |
+
|
855 |
+
if "decoder_input" in encoder_output and encoder_output["decoder_input"][0] is not None:
|
856 |
+
# Replace the encoder output with the intermediate decoder input when unb-enc-layer is set
|
857 |
+
encoder_output["encoder_out"] = encoder_output["decoder_input"]
|
858 |
+
|
859 |
+
if self.use_codebook:
|
860 |
+
q = self.quantizer(encoder_output["encoder_out"][0].transpose(0, 1))
|
861 |
+
|
862 |
+
# q["x"]: B x T x C
|
863 |
+
# Sample indices according to the codebook prob
|
864 |
+
random_idx = torch.randperm(q["x"].size(1))[:int(q["x"].size(1) * self.codebook_prob)]
|
865 |
+
# Make weight for q
|
866 |
+
q_w = q["x"].new_zeros(q["x"].size(1))
|
867 |
+
q_w[random_idx] = 1.0
|
868 |
+
# Combine quantized codes and encoder output
|
869 |
+
encoder_output["encoder_out"][0] = (
|
870 |
+
q_w.view(-1, 1) * q["x"] + (- q_w + 1).view(-1, 1) * encoder_output["encoder_out"][0].transpose(0, 1)
|
871 |
+
).transpose(0, 1)
|
872 |
+
|
873 |
+
# encoder_output["encoder_out"][0] = q["x"].transpose(0, 1)
|
874 |
+
if output_type == 'speech':
|
875 |
+
hubert_results["prob_perplexity"] = q["prob_perplexity"]
|
876 |
+
hubert_results["code_perplexity"] = q["code_perplexity"]
|
877 |
+
hubert_results["num_vars"] = q["num_vars"]
|
878 |
+
hubert_results["temp"] = q["temp"]
|
879 |
+
elif output_type == 'text':
|
880 |
+
codebook_out["prob_perplexity"] = q["prob_perplexity"]
|
881 |
+
codebook_out["code_perplexity"] = q["code_perplexity"]
|
882 |
+
codebook_out["num_vars"] = q["num_vars"]
|
883 |
+
codebook_out["temp"] = q["temp"]
|
884 |
+
|
885 |
+
if only_hubert and target_list is not None:
|
886 |
+
return hubert_results, None
|
887 |
+
|
888 |
+
if only_ctc and task_name is not None and task_name == "s2t":
|
889 |
+
return None, encoder_output
|
890 |
+
elif not self.training and prev_output_tokens is None and task_name == "s2t" and task_name is not None:
|
891 |
+
return encoder_output
|
892 |
+
|
893 |
+
# Decoder Prenet
|
894 |
+
if output_type == 'text':
|
895 |
+
# _ is the incremental state
|
896 |
+
prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens)
|
897 |
+
if task_name is not None and task_name == 's2c':
|
898 |
+
prev_output_tokens = torch.zeros_like(prev_output_tokens)
|
899 |
+
else:
|
900 |
+
# integrate speaker embedding
|
901 |
+
if self.spk_embed_integration_type == "pre" and self.spk_embed_dim is not None:
|
902 |
+
# Decoder Prenet
|
903 |
+
prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths, spkembs)
|
904 |
+
else:
|
905 |
+
if self.spk_embed_dim is not None:
|
906 |
+
encoder_output["encoder_out"] = [self._integrate_with_spk_embed(
|
907 |
+
encoder_output["encoder_out"][0].transpose(0, 1), spkembs
|
908 |
+
).transpose(0, 1)]
|
909 |
+
|
910 |
+
prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths)
|
911 |
+
|
912 |
+
# BART Sequence Classification: cat <pad> + feature before decoder
|
913 |
+
if task_name is not None and task_name == 's2c' and self.args.sid_pooling_layer == "decoder-las":
|
914 |
+
decoder_feat_input, decoder_feat_mask = self.speech_decoder_prenet(src_tokens, src_lengths)
|
915 |
+
prev_output_tokens, tgt_mask = self._integrate_with_speaker_cls((prev_output_tokens, tgt_mask), decoder_feat_input, decoder_feat_mask, cls_first=False)
|
916 |
+
|
917 |
+
# SE: predict masking of the corresponding inputs; the source speech replaces prev_output_tokens as the decoder input
|
918 |
+
if task_name is not None and task_name == "s2s" and getattr(self.args, "se_decoder_input", "previous_target") == "source":
|
919 |
+
prev_output_tokens, tgt_mask = self.speech_decoder_prenet(src_tokens, src_lengths)
|
920 |
+
|
921 |
+
# Decoder
|
922 |
+
decoder_output, extra = self.decoder(prev_output_tokens, tgt_mask, encoder_output,
|
923 |
+
full_context_alignment=getattr(self.args, "decoder_full_context_alignment", False),
|
924 |
+
alignment_layer=(-1 if target_list is None and output_type == 'speech' else None))
|
925 |
+
# Decoder Postnet
|
926 |
+
if task_name is not None and task_name == 's2c':
|
927 |
+
if not getattr(self.args, "sid_t5_postnet", False):
|
928 |
+
if self.args.sid_pooling_layer == "decoder":
|
929 |
+
return self.speaker_decoder_postnet(decoder_output.mean(1), sid_target), None
|
930 |
+
elif self.args.sid_pooling_layer == "decoder-las":
|
931 |
+
indices = (tgt_mask.eq(False).float().sum(1) - 1.0).type(torch.int64)
|
932 |
+
indices = indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, decoder_output.size(2))
|
933 |
+
return self.speaker_decoder_postnet(decoder_output.gather(1, indices), sid_target), None
|
934 |
+
else:
|
935 |
+
return (self.text_decoder_postnet(decoder_output), None), encoder_output
|
936 |
+
|
937 |
+
# SE predict: masking, target, delta. Ensure reduction factor 1
|
938 |
+
if task_name is not None and task_name == 's2s' and getattr(self.args, "se_predict", None) is not None:
|
939 |
+
assert self.reduction_factor == 1, f"{self.reduction_factor} != 1"
|
940 |
+
before_outs, after_outs, logits = self.speech_decoder_postnet(decoder_output)
|
941 |
+
se_predict = getattr(self.args, "se_predict")
|
942 |
+
if se_predict == "masking":
|
943 |
+
before_outs = torch.sigmoid(before_outs) * src_tokens
|
944 |
+
after_outs = torch.sigmoid(after_outs) * src_tokens
|
945 |
+
return before_outs, after_outs, logits, extra['attn'][0]
|
946 |
+
elif se_predict == "target":
|
947 |
+
return before_outs, after_outs, logits, extra['attn'][0]
|
948 |
+
elif se_predict == "delta":
|
949 |
+
before_outs = before_outs - src_tokens
|
950 |
+
after_outs = after_outs - src_tokens
|
951 |
+
return before_outs, after_outs, logits, extra['attn'][0]
|
952 |
+
else:
|
953 |
+
raise ValueError(f"{se_predict} not in [masking, target, delta]")
|
954 |
+
|
955 |
+
if task_name is not None and task_name == 's2t':
|
956 |
+
#return self.text_decoder_postnet(decoder_output), None
|
957 |
+
return (self.text_decoder_postnet(decoder_output), None), encoder_output
|
958 |
+
if output_type == 'text':
|
959 |
+
return (self.text_decoder_postnet(decoder_output), None), codebook_out, encoder_output
|
960 |
+
else:
|
961 |
+
if target_list is not None:
|
962 |
+
return hubert_results, (self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],))
|
963 |
+
else:
|
964 |
+
return self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],)
|
965 |
+
|
966 |
+
def _integrate_with_speaker_cls(self, pad_input, encoder_input, encoder_padding_mask=None, cls_first=True):
|
967 |
+
"""
|
968 |
+
encoder_input: [B, T, C]
|
969 |
+
encoder_padding_mask: [B, T]
|
970 |
+
"""
|
971 |
+
if hasattr(self, "text_decoder_prenet"):
|
972 |
+
if isinstance(pad_input, tuple):
|
973 |
+
repeat_cls_vector, repeat_cls_mask = pad_input
|
974 |
+
else:
|
975 |
+
repeat_cls_vector, repeat_cls_mask, _ = self.text_decoder_prenet(pad_input)
|
976 |
+
|
977 |
+
if encoder_padding_mask is not None:
|
978 |
+
bsz = encoder_input.size(0)
|
979 |
+
tsz = encoder_input.size(1)
|
980 |
+
encoder_padding_mask = encoder_input.new_zeros((bsz, tsz)) == 1.0
|
981 |
+
if repeat_cls_mask is None:
|
982 |
+
mask_size = (encoder_padding_mask.size(0), 1)
|
983 |
+
mask_type = encoder_padding_mask.dtype
|
984 |
+
repeat_cls_mask = encoder_padding_mask.new_zeros(mask_size) == 1.0
|
985 |
+
ret_encoder_padding_mask = torch.cat([repeat_cls_mask, encoder_padding_mask], dim=1)
|
986 |
+
|
987 |
+
if cls_first:
|
988 |
+
ret_encoder_input = torch.cat([repeat_cls_vector, encoder_input], dim=1)
|
989 |
+
else:
|
990 |
+
ret_encoder_input = torch.cat([encoder_input, encoder_input[:,-1:,:]], dim=1)
|
991 |
+
mask_size = (encoder_padding_mask.size(0), 1)
|
992 |
+
mask_type = encoder_padding_mask.dtype
|
993 |
+
repeat_cls_mask_ = encoder_padding_mask.new_ones(mask_size) == 1.0
|
994 |
+
encoder_padding_mask_ = torch.cat([encoder_padding_mask, repeat_cls_mask_], dim=1)
|
995 |
+
indices = encoder_padding_mask.eq(False).float().sum(1).type(torch.int64).unsqueeze(1)
|
996 |
+
indices_mask = torch.zeros_like(ret_encoder_padding_mask).scatter(1, indices, 1.0)
|
997 |
+
ret_encoder_input = ret_encoder_input * (1.0 - encoder_padding_mask_.type(ret_encoder_input.dtype).unsqueeze(2)) \
|
998 |
+
+ repeat_cls_vector * indices_mask.type(repeat_cls_vector.dtype).unsqueeze(2)
|
999 |
+
|
1000 |
+
return ret_encoder_input, ret_encoder_padding_mask
|
1001 |
+
|
1002 |
+
def _integrate_with_spk_embed(self, hs, spembs):
|
1003 |
+
"""Integrate speaker embedding with hidden states.
|
1004 |
+
Args:
|
1005 |
+
hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
|
1006 |
+
spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).
|
1007 |
+
Returns:
|
1008 |
+
Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)
|
1009 |
+
"""
|
1010 |
+
if self.spk_embed_integration_type == "add":
|
1011 |
+
# apply projection and then add to hidden states
|
1012 |
+
spembs = self.projection(F.normalize(spembs))
|
1013 |
+
hs = hs + spembs.unsqueeze(1)
|
1014 |
+
elif self.spk_embed_integration_type == "concat":
|
1015 |
+
# concat hidden states with spk embeds and then apply projection
|
1016 |
+
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
|
1017 |
+
hs = self.projection(torch.cat([hs, spembs], dim=-1))
|
1018 |
+
else:
|
1019 |
+
raise NotImplementedError("support only add or concat.")
|
1020 |
+
|
1021 |
+
return hs
|
1022 |
+
|
1023 |
+
def load_state_dict(
|
1024 |
+
self,
|
1025 |
+
state_dict,
|
1026 |
+
strict=True,
|
1027 |
+
model_cfg=None,
|
1028 |
+
args=None,
|
1029 |
+
):
|
1030 |
+
"""NOT STRICT Copies parameters and buffers from *state_dict* into this module and
|
1031 |
+
its descendants.
|
1032 |
+
|
1033 |
+
Overrides the method in :class:`nn.Module`. Compared with that method
|
1034 |
+
this additionally "upgrades" *state_dicts* from old checkpoints.
|
1035 |
+
"""
|
1036 |
+
# self.prune_modules(model_cfg.modules_filter)
|
1037 |
+
model_dict_size = self.text_decoder_postnet.output_projection.out_features
|
1038 |
+
ckpt_dict_size = state_dict["text_decoder_postnet.output_projection.weight"].size(0)
|
1039 |
+
if model_dict_size != ckpt_dict_size:
|
1040 |
+
# reset dictionary-related modules, such as embedding table and encoder ctc embed
|
1041 |
+
logger.warn(f"not equal dictionary between model and checkpoint: {model_dict_size} vs {ckpt_dict_size}")
|
1042 |
+
logger.info(f"reset model dictionary with size of {model_dict_size}")
|
1043 |
+
removed_keys = [
|
1044 |
+
key for key in state_dict.keys() if any(
|
1045 |
+
key.startswith(previ) for previ in [
|
1046 |
+
"encoder.proj", "text_encoder_prenet", "text_decoder_prenet", "text_decoder_postnet"
|
1047 |
+
]
|
1048 |
+
)
|
1049 |
+
]
|
1050 |
+
for key in removed_keys:
|
1051 |
+
state_dict.pop(key, None)
|
1052 |
+
logger.info(f"removed loaded checkpoint: {key}")
|
1053 |
+
for m in self._modules.keys():
|
1054 |
+
m_state_dict = {
|
1055 |
+
key.replace(f"{m}.", ""): value for key, value in state_dict.items() if key.startswith(f"{m}.")
|
1056 |
+
}
|
1057 |
+
if hasattr(self, m):
|
1058 |
+
self._modules[m].load_state_dict(m_state_dict, False)
|
1059 |
+
return self
|
1060 |
+
|
1061 |
+
def prune_modules(self, modules_filter=None):
|
1062 |
+
"""Prune unused modules for specific tasks."""
|
1063 |
+
if modules_filter is None:
|
1064 |
+
return
|
1065 |
+
elif modules_filter == "s2c":
|
1066 |
+
if hasattr(self, "text_encoder_prenet"): del self.text_encoder_prenet
|
1067 |
+
if hasattr(self, "speech_decoder_prenet") and getattr(self.args, "sid_pooling_layer", None) != "decoder-las":
|
1068 |
+
del self.speech_decoder_prenet
|
1069 |
+
if hasattr(self, "speech_decoder_postnet"): del self.speech_decoder_postnet
|
1070 |
+
if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
|
1071 |
+
if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
|
1072 |
+
if hasattr(self.encoder, "proj"): self.encoder.proj = None
|
1073 |
+
if hasattr(self, "projection"): del self.projection
|
1074 |
+
if hasattr(self, "quantizer"): del self.quantizer
|
1075 |
+
if getattr(self.args, "sid_pooling_layer", "decoder").startswith("encoder") or getattr(self.args, "sid_decoder_speaker", False):
|
1076 |
+
if hasattr(self.decoder, "dropout_module"): del self.decoder.dropout_module
|
1077 |
+
if hasattr(self.decoder, "layers"): del self.decoder.layers
|
1078 |
+
if hasattr(self.decoder, "layer_norm"): del self.decoder.layer_norm
|
1079 |
+
if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
|
1080 |
+
elif modules_filter == "s2s":
|
1081 |
+
if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
|
1082 |
+
if hasattr(self, "text_encoder_prenet"): del self.text_encoder_prenet
|
1083 |
+
if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
|
1084 |
+
if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
|
1085 |
+
if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
|
1086 |
+
if hasattr(self.encoder, "proj"): self.encoder.proj = None
|
1087 |
+
if hasattr(self, "projection"): del self.projection
|
1088 |
+
if hasattr(self, "quantizer"): del self.quantizer
|
1089 |
+
elif modules_filter == "t2s":
|
1090 |
+
if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
|
1091 |
+
if hasattr(self, "speech_encoder_prenet"): del self.speech_encoder_prenet
|
1092 |
+
if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
|
1093 |
+
if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
|
1094 |
+
if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
|
1095 |
+
if hasattr(self.encoder, "proj"): self.encoder.proj = None
|
1096 |
+
if hasattr(self, "projection"): del self.projection
|
1097 |
+
if hasattr(self, "quantizer"): del self.quantizer
|
1098 |
+
elif modules_filter == "s3prl":
|
1099 |
+
# keep only the encoder and the pre/post nets
|
1100 |
+
if hasattr(self.decoder, "dropout_module"): del self.decoder.dropout_module
|
1101 |
+
if hasattr(self.decoder, "layers"): del self.decoder.layers
|
1102 |
+
if hasattr(self.decoder, "layer_norm"): del self.decoder.layer_norm
|
1103 |
+
if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
|
1104 |
+
if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
|
1105 |
+
if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
|
1106 |
+
if hasattr(self, "speech_decoder_prenet"): del self.speech_decoder_prenet
|
1107 |
+
if hasattr(self, "speech_decoder_postnet"): del self.speech_decoder_postnet
|
1108 |
+
if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
|
1109 |
+
if hasattr(self.encoder, "proj"): self.encoder.proj = None
|
1110 |
+
if hasattr(self, "projection"): del self.projection
|
1111 |
+
if hasattr(self, "quantizer"): del self.quantizer
|
1112 |
+
|
1113 |
+
def forward_encoder_torchscript(self, net_input: Dict[str, Tensor]):
|
1114 |
+
"""A TorchScript-compatible version of forward.
|
1115 |
+
|
1116 |
+
Encoders which use additional arguments may want to override
|
1117 |
+
this method for TorchScript compatibility.
|
1118 |
+
"""
|
1119 |
+
if torch.jit.is_scripting():
|
1120 |
+
return self.forward_encoder(
|
1121 |
+
source=net_input["source"],
|
1122 |
+
padding_mask=net_input["padding_mask"]
|
1123 |
+
)
|
1124 |
+
else:
|
1125 |
+
return self.forward_encoder_non_torchscript(net_input)
|
1126 |
+
|
1127 |
+
@torch.jit.unused
|
1128 |
+
def forward_encoder_non_torchscript(self, net_input: Dict[str, Tensor]):
|
1129 |
+
encoder_input = {
|
1130 |
+
k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
|
1131 |
+
}
|
1132 |
+
return self.forward_encoder(**encoder_input)
|
1133 |
+
|
1134 |
+
def forward_encoder(self, source, padding_mask=None):
|
1135 |
+
# Encoder Prenet
|
1136 |
+
encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=False)
|
1137 |
+
|
1138 |
+
# Encoder
|
1139 |
+
encoder_output = self.encoder(encoder_input, encoder_padding_mask)
|
1140 |
+
|
1141 |
+
return encoder_output
|
1142 |
+
|
1143 |
+
def forward_text_encoder(self, src_tokens):
|
1144 |
+
# Text Encoder Prenet
|
1145 |
+
encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens)
|
1146 |
+
|
1147 |
+
# Encoder
|
1148 |
+
encoder_output = self.encoder(encoder_input, encoder_padding_mask)
|
1149 |
+
|
1150 |
+
return encoder_output
|
1151 |
+
|
1152 |
+
def forward_decoder(self, tokens, encoder_out, incremental_state):
|
1153 |
+
# Decoder Prenet
|
1154 |
+
prev_output_tokens, tgt_mask, incremental_state = self.text_decoder_prenet(tokens, incremental_state)
|
1155 |
+
|
1156 |
+
# Decoder
|
1157 |
+
decoder_output, extra = self.decoder(
|
1158 |
+
prev_output_tokens,
|
1159 |
+
tgt_mask,
|
1160 |
+
encoder_out=encoder_out,
|
1161 |
+
incremental_state=incremental_state,
|
1162 |
+
)
|
1163 |
+
|
1164 |
+
# Decoder Postnet
|
1165 |
+
return self.text_decoder_postnet(decoder_output), extra
|
1166 |
+
|
1167 |
+
def set_num_updates(self, num_updates):
|
1168 |
+
"""Set the number of parameters updates."""
|
1169 |
+
super().set_num_updates(num_updates)
|
1170 |
+
self.num_updates = num_updates
|
1171 |
+
|
1172 |
+
def generate_class(self, source, prev_output_tokens, **kwargs):
|
1173 |
+
encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])
|
1174 |
+
|
1175 |
+
prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens, {})
|
1176 |
+
prev_output_tokens = torch.zeros_like(prev_output_tokens) # s2c use zero vector as [CLS]
|
1177 |
+
|
1178 |
+
decoder_output, extra = self.decoder(
|
1179 |
+
prev_output_tokens,
|
1180 |
+
tgt_mask,
|
1181 |
+
encoder_out=encoder_out,
|
1182 |
+
)
|
1183 |
+
|
1184 |
+
decoder_out, embed = self.speaker_decoder_postnet(decoder_output.mean(1))
|
1185 |
+
|
1186 |
+
pred_class = decoder_out.argmax(1)
|
1187 |
+
return pred_class
|
1188 |
+
|
1189 |
+
def generate_speech(self, source=None, src_tokens=None, spkembs=None, **kwargs):
|
1190 |
+
assert source is not None or src_tokens is not None
|
1191 |
+
|
1192 |
+
threshold = kwargs.get("threshold", 0.5)
|
1193 |
+
minlenratio = kwargs.get("minlenratio", 0.0)
|
1194 |
+
|
1195 |
+
if source is None:
|
1196 |
+
assert src_tokens.size(0) == 1
|
1197 |
+
encoder_out = self.forward_text_encoder(src_tokens)
|
1198 |
+
maxlenratio = kwargs.get("maxlenratio", 20.0)
|
1199 |
+
else:
|
1200 |
+
assert source.size(0) == 1
|
1201 |
+
encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])
|
1202 |
+
maxlenratio = kwargs.get("maxlenratio", 10.0)
|
1203 |
+
|
1204 |
+
if spkembs is not None and self.spk_embed_integration_type != "pre":
|
1205 |
+
encoder_out["encoder_out"] = [self._integrate_with_spk_embed(
|
1206 |
+
encoder_out["encoder_out"][0].transpose(0, 1), spkembs
|
1207 |
+
).transpose(0, 1)]
|
1208 |
+
spkembs = None
|
1209 |
+
|
1210 |
+
maxlen = int(encoder_out["encoder_out"][0].size(0) * maxlenratio / self.reduction_factor)
|
1211 |
+
minlen = int(encoder_out["encoder_out"][0].size(0) * minlenratio / self.reduction_factor)
|
1212 |
+
|
1213 |
+
idx = 0
|
1214 |
+
ys = encoder_out["encoder_out"][0].new_zeros(1, 1, self.speech_decoder_postnet.odim)
|
1215 |
+
outs, probs = [], []
|
1216 |
+
|
1217 |
+
# forward decoder step-by-step
|
1218 |
+
if isinstance(self.decoder, FairseqIncrementalDecoder):
|
1219 |
+
incremental_states = {}
|
1220 |
+
else:
|
1221 |
+
incremental_states = None
|
1222 |
+
attns = []
|
1223 |
+
while True:
|
1224 |
+
# update index
|
1225 |
+
idx += 1
|
1226 |
+
# calculate output and stop prob at idx-th step
|
1227 |
+
decoder_in, _ = self.speech_decoder_prenet(ys, spkembs=spkembs)
|
1228 |
+
z, extra = self.decoder(decoder_in[:,-1:], None, encoder_out, incremental_states, alignment_layer=-1)
|
1229 |
+
outs += [self.speech_decoder_postnet.feat_out(z[0, -1]).view(self.reduction_factor, self.speech_decoder_postnet.odim)] # [(r, odim), ...]
|
1230 |
+
probs += [torch.sigmoid(self.speech_decoder_postnet.prob_out(z[0, -1]))] # [(r), ...]
|
1231 |
+
|
1232 |
+
# update next inputs
|
1233 |
+
ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.speech_decoder_postnet.odim)), dim=1) # (1, idx + 1, odim)
|
1234 |
+
attns.append(torch.stack([att_l[0] for att_l in extra['attn'][0]], dim=0))
|
1235 |
+
# check whether to finish generation
|
1236 |
+
if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
|
1237 |
+
# check minimum length
|
1238 |
+
if idx < minlen:
|
1239 |
+
continue
|
1240 |
+
outs = (torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2)) # (L, odim) -> (1, L, odim) -> (1, odim, L)
|
1241 |
+
if self.speech_decoder_postnet.postnet is not None:
|
1242 |
+
outs = outs + self.speech_decoder_postnet.postnet(outs) # (1, odim, L)
|
1243 |
+
outs = outs.transpose(2, 1).squeeze(0) # (L, odim)
|
1244 |
+
probs = torch.cat(probs, dim=0)
|
1245 |
+
attn = torch.cat(attns, dim=2)
|
1246 |
+
break
|
1247 |
+
|
1248 |
+
if outs.size(0) == maxlen:
|
1249 |
+
logging.warning("output length reaches maximum length")
|
1250 |
+
return outs, probs, attn
|
1251 |
+
|
1252 |
+
|
1253 |
+
@register_model_architecture(model_name="artst_transformer", arch_name="artst_transformer")
|
1254 |
+
def base_architecture(args):
|
1255 |
+
# Transformer
|
1256 |
+
args.bert_init = getattr(args, "bert_init", False)
|
1257 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
1258 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768 * 4)
|
1259 |
+
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
1260 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
|
1261 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
1262 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
|
1263 |
+
args.decoder_ffn_embed_dim = getattr(
|
1264 |
+
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
|
1265 |
+
)
|
1266 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
1267 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
|
1268 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
|
1269 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
1270 |
+
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
|
1271 |
+
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
|
1272 |
+
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
1273 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
|
1274 |
+
args.decoder_output_dim = getattr(
|
1275 |
+
args, "decoder_output_dim", args.decoder_embed_dim
|
1276 |
+
)
|
1277 |
+
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
|
1278 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
|
1279 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
|
1280 |
+
args.max_text_positions = getattr(args, "max_text_positions", DEFAULT_MAX_TEXT_POSITIONS)
|
1281 |
+
args.max_speech_positions = getattr(args, "max_speech_positions", DEFAULT_MAX_SPEECH_POSITIONS)
|
1282 |
+
|
1283 |
+
# Espnet related, including prenet, postnet
|
1284 |
+
args.eprenet_conv_layers = getattr(args, "eprenet_conv_layers", 0)
|
1285 |
+
args.eprenet_conv_filts = getattr(args, "eprenet_conv_filts", 0)
|
1286 |
+
args.eprenet_conv_chans = getattr(args, "eprenet_conv_chans", 0)
|
1287 |
+
args.use_batch_norm = getattr(args, "use_batch_norm", True)
|
1288 |
+
args.eprenet_dropout_rate = getattr(args, "eprenet_dropout_rate", 0.0)
|
1289 |
+
args.enc_use_scaled_pos_enc = getattr(args, "enc_use_scaled_pos_enc", True)
|
1290 |
+
args.dec_use_scaled_pos_enc = getattr(args, "dec_use_scaled_pos_enc", True)
|
1291 |
+
args.postnet_layers = getattr(args, "postnet_layers", 5)
|
1292 |
+
args.postnet_chans = getattr(args, "postnet_chans", 256)
|
1293 |
+
args.postnet_filts = getattr(args, "postnet_filts", 5)
|
1294 |
+
args.postnet_dropout_rate = getattr(args, "postnet_dropout_rate", 0.5)
|
1295 |
+
args.dprenet_dropout_rate = getattr(args, "dprenet_dropout_rate", 0.5)
|
1296 |
+
args.dprenet_layers = getattr(args, "dprenet_layers", 2)
|
1297 |
+
args.dprenet_units = getattr(args, "dprenet_units", 256)
|
1298 |
+
args.initial_encoder_alpha = getattr(args, "initial_encoder_alpha", 1.0)
|
1299 |
+
args.initial_decoder_alpha = getattr(args, "initial_decoder_alpha", 1.0)
|
1300 |
+
args.spk_embed_integration_type = getattr(args, "spk_embed_integration_type", "pre")
|
1301 |
+
args.spk_embed_dim = getattr(args, "spk_embed_dim", 512)
|
1302 |
+
args.encoder_reduction_factor = getattr(args, "encoder_reduction_factor", 1)
|
1303 |
+
args.reduction_factor = getattr(args, "reduction_factor", 2)
|
1304 |
+
args.transformer_enc_positional_dropout_rate = getattr(args, "transformer_enc_positional_dropout_rate", 0.1)
|
1305 |
+
args.transformer_dec_positional_dropout_rate = getattr(args, "transformer_dec_positional_dropout_rate", 0.1)
|
1306 |
+
args.layer_norm_eps = getattr(args, "layer_norm_eps", 1e-5)
|
1307 |
+
args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
|
1308 |
+
# Convolutional subsampler
|
1309 |
+
args.encoder_speech_prenet = getattr(args, "encoder_speech_prenet", "conv")
|
1310 |
+
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
|
1311 |
+
args.conv_channels = getattr(args, "conv_channels", 1024)
|
1312 |
+
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
|
1313 |
+
|
1314 |
+
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
1315 |
+
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
1316 |
+
args.no_token_positional_embeddings = getattr(
|
1317 |
+
args, "no_token_positional_embeddings", False
|
1318 |
+
)
|
1319 |
+
args.adaptive_input = getattr(args, "adaptive_input", False)
|
1320 |
+
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
|
1321 |
+
args.share_input_output_embed = getattr(args, "share_input_output_embed", False)
|
1322 |
+
args.share_ctc_embed = getattr(args, "share_ctc_embed", False)
|
1323 |
+
args.freeze_encoder_updates = getattr(args, "freeze_encoder_updates", 0)
|
1324 |
+
args.freeze_decoder_updates = getattr(args, "freeze_decoder_updates", 0)
|
1325 |
+
args.no_freeze_encoder_layer = getattr(args, "no_freeze_encoder_layer", None)
|
1326 |
+
|
1327 |
+
## sid
|
1328 |
+
args.sid_embed_dim = getattr(args, "sid_embed_dim", 128)
|
1329 |
+
args.sid_pooling_layer = getattr(args, "sid_pooling_layer", "decoder")
|
1330 |
+
args.softmax_scale = getattr(args, "softmax_scale", 1)
|
1331 |
+
args.softmax_margin = getattr(args, "softmax_margin", 0)
|
1332 |
+
args.softmax_easy_margin = getattr(args, "softmax_easy_margin", False)
|
1333 |
+
args.modules_filter = getattr(args, "modules_filter", None)
|
1334 |
+
|
1335 |
+
## Hubert
|
1336 |
+
args.conv_pos = getattr(args, "conv_pos", 128)
|
1337 |
+
args.conv_pos_groups = getattr(args, "conv_pos_groups", 16)
|
1338 |
+
args.target_glu = getattr(args, "target_glu", False)
|
1339 |
+
args.logit_temp = getattr(args, "logit_temp", 0.1)
|
1340 |
+
args.final_dim = getattr(args, "final_dim", 256)
|
1341 |
+
args.untie_final_proj = getattr(args, "untie_final_proj", True)
|
1342 |
+
args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.1)
|
1343 |
+
args.use_sent_enc_layer = getattr(args, "use_sent_enc_layer", True)
|
1344 |
+
# hubert feature extractor
|
1345 |
+
args.extractor_mode = getattr(args, "extractor_mode", "default")
|
1346 |
+
args.conv_feature_layers = getattr(args, "conv_feature_layers", "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2")
|
1347 |
+
args.conv_bias = getattr(args, "conv_bias", False)
|
1348 |
+
# mask
|
1349 |
+
args.hubert_mask_length = getattr(args, "hubert_mask_length", 10)
|
1350 |
+
args.mask_prob = getattr(args, "mask_prob", 0.0)
|
1351 |
+
args.mask_selection = getattr(args, "mask_selection", "static")
|
1352 |
+
args.mask_other = getattr(args, "mask_other", 0)
|
1353 |
+
args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
|
1354 |
+
args.mask_min_space = getattr(args, "mask_min_space", 1)
|
1355 |
+
# channel mask
|
1356 |
+
args.mask_channel_length = getattr(args, "mask_channel_length", 10)
|
1357 |
+
args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.0)
|
1358 |
+
args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
|
1359 |
+
args.mask_channel_other = getattr(args, "mask_channel_other", 0)
|
1360 |
+
args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
|
1361 |
+
args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1)
|
1362 |
+
# loss computation
|
1363 |
+
args.skip_masked = getattr(args, "skip_masked", False)
|
1364 |
+
args.skip_nomask = getattr(args, "skip_nomask", False)
|
1365 |
+
# conv Pos
|
1366 |
+
args.use_conv_pos = getattr(args, "use_conv_pos", False)
|
1367 |
+
args.use_sinc_pos = getattr(args, "use_sinc_pos", False)
|
1368 |
+
|
1369 |
+
# codebook
|
1370 |
+
args.use_codebook = getattr(args, "use_codebook", False)
|
1371 |
+
args.latent_vars = getattr(args, "latent_vars", 100)
|
1372 |
+
args.latent_groups = getattr(args, "latent_groups", 2)
|
1373 |
+
args.latent_dim = getattr(args, "latent_dim", 0)
|
1374 |
+
args.latent_temp = getattr(args, "latent_temp", (2, 0.5, 0.999995))
|
1375 |
+
args.quantizer_depth = getattr(args, "quantizer_depth", 1)
|
1376 |
+
args.quantizer_factor = getattr(args, "quantizer_factor", 3)
|
1377 |
+
args.codebook_prob = getattr(args, "codebook_prob", 0.5)
|
1378 |
+
|
1379 |
+
# Relative pos embed
|
1380 |
+
args.relative_position_embedding = getattr(args, "relative_position_embedding", False)
|
1381 |
+
args.num_buckets = getattr(args, "num_buckets", 320)
|
1382 |
+
args.max_distance = getattr(args, "max_distance", 1280)
|
1383 |
+
args.encoder_max_relative_position = getattr(args, "encoder_max_relative_position", 160)
|
1384 |
+
args.decoder_max_relative_position = getattr(args, "decoder_max_relative_position", 160)
|
1385 |
+
|
1386 |
+
@register_model_architecture("artst_transformer", "artst_transformer_base")
|
1387 |
+
def artst_transformer_base(args):
|
1388 |
+
args.use_conv_pos = getattr(args, "use_conv_pos", True)
|
1389 |
+
args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
|
1390 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
1391 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
1392 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
|
1393 |
+
args.layer_norm_first = getattr(args, "layer_norm_first", False)
|
1394 |
+
args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
|
1395 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
1396 |
+
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
1397 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
1398 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.05)
|
1399 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.05)
|
1400 |
+
args.mask_prob = getattr(args, "mask_prob", 0.80)
|
1401 |
+
base_architecture(args)
|
1402 |
+
|
1403 |
+
@register_model_architecture("artst_transformer", "artst_transformer_large")
|
1404 |
+
def artst_transformer_large(args):
|
1405 |
+
args.use_conv_pos = getattr(args, "use_conv_pos", True)
|
1406 |
+
args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
|
1407 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
1408 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
|
1409 |
+
args.layer_norm_first = getattr(args, "layer_norm_first", True)
|
1410 |
+
args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
|
1411 |
+
args.dropout = getattr(args, "dropout", 0.0)
|
1412 |
+
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
1413 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
|
1414 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
|
1415 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
|
1416 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
1417 |
+
args.encoder_layers = getattr(args, "encoder_layers", 24)
|
1418 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
1419 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
|
1420 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
1421 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
1422 |
+
args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0)
|
1423 |
+
args.extractor_mode = getattr(args, "extractor_mode", "layer_norm")
|
1424 |
+
args.final_dim = getattr(args, "final_dim", 768)
|
1425 |
+
args.mask_prob = getattr(args, "mask_prob", 0.80)
|
1426 |
+
base_architecture(args)
|
1427 |
+
|
1428 |
+
@register_model_architecture("artst_transformer", "artst_transformer_base_asr")
|
1429 |
+
def artst_transformer_base_asr(args):
|
1430 |
+
args.use_conv_pos = getattr(args, "use_conv_pos", True)
|
1431 |
+
args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
|
1432 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
1433 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
|
1434 |
+
args.layer_norm_first = getattr(args, "layer_norm_first", False)
|
1435 |
+
args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
|
1436 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
1437 |
+
args.activation_dropout = getattr(args, "activation_dropout", 0.1)
|
1438 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
1439 |
+
args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.0)
|
1440 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.1)
|
1441 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.1)
|
1442 |
+
args.mask_prob = getattr(args, "mask_prob", 0.75)
|
1443 |
+
args.mask_selection = getattr(args, "mask_selection", "static")
|
1444 |
+
args.mask_channel_length = getattr(args, "mask_channel_length", 64)
|
1445 |
+
args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5)
|
1446 |
+
args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
|
1447 |
+
args.max_text_positions = getattr(args, "max_text_positions", 600)
|
1448 |
+
base_architecture(args)
|
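The presets above fill in defaults whenever a trained checkpoint is loaded. Below is a rough, hedged sketch (not part of this commit) of how a checkpoint built with one of these registered architectures could be loaded through fairseq; the checkpoint path and the `data` override pointing at the dictionary directory are assumptions about this Space's layout, not code from the repo.

import artst.tasks.artst   # assumed import path; registers the ArTST task with fairseq
import artst.models.artst  # registers the artst_transformer* architectures defined above
from fairseq import checkpoint_utils

# Load the bundled ASR checkpoint; "utils/" is assumed to hold dict.txt
# and the sentencepiece model referenced by the task config.
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    ["ckpts/mgb2_asr.pt"],
    arg_overrides={"data": "utils/"},
)
model = models[0].eval()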
artst/models/modules/__init__.py
ADDED
File without changes
|
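The TransformerDecoder added in the next file is driven one step at a time during generation (see forward_decoder and generate_speech in artst.py above). The following is a minimal illustrative sketch of that incremental calling pattern; `model`, `encoder_out`, `prev_output_tokens`, and `max_steps` are placeholders, not names defined by this commit.

import torch

incremental_state = {}                 # per-layer cache reused across decoding steps
tokens = prev_output_tokens[:, :1]     # hypothetical start token(s), shape (B, 1)
for _ in range(max_steps):             # max_steps is a placeholder length bound
    # forward_decoder runs the text decoder prenet, decoder, and postnet
    logits, extra = model.forward_decoder(tokens, encoder_out, incremental_state)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_token], dim=1)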
artst/models/modules/decoder.py
ADDED
@@ -0,0 +1,323 @@
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
|
5 |
+
# Based on speecht5, fairseq and espnet code bases
|
6 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
from typing import Any, Dict, List, Optional
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from fairseq import utils
|
14 |
+
from fairseq.distributed import fsdp_wrap
|
15 |
+
from fairseq.models import (
|
16 |
+
FairseqIncrementalDecoder,
|
17 |
+
)
|
18 |
+
from fairseq.modules import (
|
19 |
+
FairseqDropout,
|
20 |
+
LayerDropModuleList,
|
21 |
+
LayerNorm,
|
22 |
+
)
|
23 |
+
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
|
24 |
+
from torch import Tensor
|
25 |
+
|
26 |
+
from .encoder import RelativePositionalEncoding
|
27 |
+
from .transformer_layer import TransformerDecoderLayer
|
28 |
+
|
29 |
+
DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
|
30 |
+
|
31 |
+
|
32 |
+
class TransformerDecoder(FairseqIncrementalDecoder):
|
33 |
+
"""
|
34 |
+
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
|
35 |
+
is a :class:`TransformerDecoderLayer`.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
args (argparse.Namespace): parsed command-line arguments
|
39 |
+
dictionary (~fairseq.data.Dictionary): decoding dictionary
|
40 |
+
embed_tokens (torch.nn.Embedding): output embedding
|
41 |
+
no_encoder_attn (bool, optional): whether to attend to encoder outputs
|
42 |
+
(default: False).
|
43 |
+
"""
|
44 |
+
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
args,
|
48 |
+
no_encoder_attn=False,
|
49 |
+
):
|
50 |
+
self.args = args
|
51 |
+
super().__init__(None)
|
52 |
+
self.register_buffer("version", torch.Tensor([3]))
|
53 |
+
self._future_mask = torch.empty(0)
|
54 |
+
|
55 |
+
self.dropout_module = FairseqDropout(
|
56 |
+
args.dropout, module_name=self.__class__.__name__
|
57 |
+
)
|
58 |
+
self.decoder_layerdrop = args.decoder_layerdrop
|
59 |
+
# self.max_s_positions = args.max_target_positions
|
60 |
+
export = getattr(args, "export", False)
|
61 |
+
self.cross_self_attention = getattr(args, "cross_self_attention", False)
|
62 |
+
|
63 |
+
if self.decoder_layerdrop > 0.0:
|
64 |
+
self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
|
65 |
+
else:
|
66 |
+
self.layers = nn.ModuleList([])
|
67 |
+
self.layers.extend(
|
68 |
+
[
|
69 |
+
self.build_decoder_layer(args, no_encoder_attn)
|
70 |
+
for _ in range(args.decoder_layers)
|
71 |
+
]
|
72 |
+
)
|
73 |
+
self.num_layers = len(self.layers)
|
74 |
+
|
75 |
+
if args.decoder_normalize_before and not getattr(
|
76 |
+
args, "no_decoder_final_norm", False
|
77 |
+
):
|
78 |
+
self.layer_norm = LayerNorm(args.decoder_embed_dim, eps=args.layer_norm_eps, export=export)
|
79 |
+
else:
|
80 |
+
self.layer_norm = None
|
81 |
+
|
82 |
+
if args.relative_position_embedding:
|
83 |
+
self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim//args.encoder_attention_heads, args.decoder_max_relative_position)
|
84 |
+
|
85 |
+
def build_decoder_layer(self, args, no_encoder_attn=False):
|
86 |
+
layer = TransformerDecoderLayer(args, no_encoder_attn=no_encoder_attn, has_relative_attention_bias=args.relative_position_embedding)
|
87 |
+
checkpoint = getattr(args, "checkpoint_activations", False)
|
88 |
+
if checkpoint:
|
89 |
+
offload_to_cpu = getattr(args, "offload_activations", False)
|
90 |
+
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
|
91 |
+
# if we are checkpointing, enforce that FSDP always wraps the
|
92 |
+
# checkpointed layer, regardless of layer size
|
93 |
+
min_params_to_wrap = (
|
94 |
+
getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
|
95 |
+
if not checkpoint
|
96 |
+
else 0
|
97 |
+
)
|
98 |
+
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
|
99 |
+
return layer
|
100 |
+
|
101 |
+
def forward(
|
102 |
+
self,
|
103 |
+
prev_output_tokens,
|
104 |
+
tgt_mask,
|
105 |
+
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
|
106 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
107 |
+
full_context_alignment: bool = False,
|
108 |
+
alignment_layer: Optional[int] = None,
|
109 |
+
alignment_heads: Optional[int] = None,
|
110 |
+
src_lengths: Optional[Any] = None,
|
111 |
+
return_all_hiddens: bool = False,
|
112 |
+
):
|
113 |
+
"""
|
114 |
+
Args:
|
115 |
+
prev_output_tokens (LongTensor): previous decoder outputs of shape
|
116 |
+
`(batch, tgt_len)`, for teacher forcing
|
117 |
+
encoder_out (optional): output from the encoder, used for
|
118 |
+
encoder-side attention, should be of size T x B x C
|
119 |
+
incremental_state (dict): dictionary used for storing state during
|
120 |
+
:ref:`Incremental decoding`
|
121 |
+
features_only (bool, optional): only return features without
|
122 |
+
applying output layer (default: False).
|
123 |
+
full_context_alignment (bool, optional): don't apply
|
124 |
+
auto-regressive mask to self-attention (default: False).
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
tuple:
|
128 |
+
- the decoder's output of shape `(batch, tgt_len, vocab)`
|
129 |
+
- a dictionary with any model-specific outputs
|
130 |
+
"""
|
131 |
+
|
132 |
+
x, extra = self.extract_features(
|
133 |
+
prev_output_tokens,
|
134 |
+
tgt_mask,
|
135 |
+
encoder_out=encoder_out,
|
136 |
+
incremental_state=incremental_state,
|
137 |
+
full_context_alignment=full_context_alignment,
|
138 |
+
alignment_layer=alignment_layer,
|
139 |
+
alignment_heads=alignment_heads,
|
140 |
+
)
|
141 |
+
|
142 |
+
return x, extra
|
143 |
+
|
144 |
+
def extract_features(
|
145 |
+
self,
|
146 |
+
prev_output_tokens,
|
147 |
+
tgt_mask,
|
148 |
+
encoder_out: Optional[Dict[str, List[Tensor]]],
|
149 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
150 |
+
full_context_alignment: bool = False,
|
151 |
+
alignment_layer: Optional[int] = None,
|
152 |
+
alignment_heads: Optional[int] = None,
|
153 |
+
):
|
154 |
+
return self.extract_features_scriptable(
|
155 |
+
prev_output_tokens,
|
156 |
+
tgt_mask,
|
157 |
+
encoder_out,
|
158 |
+
incremental_state,
|
159 |
+
full_context_alignment,
|
160 |
+
alignment_layer,
|
161 |
+
alignment_heads,
|
162 |
+
)
|
163 |
+
|
164 |
+
"""
|
165 |
+
A scriptable subclass of this class has an extract_features method and calls
|
166 |
+
super().extract_features, but super() is not supported in torchscript. A copy of
|
167 |
+
this function is made to be used in the subclass instead.
|
168 |
+
"""
|
169 |
+
|
170 |
+
def extract_features_scriptable(
|
171 |
+
self,
|
172 |
+
prev_output_tokens,
|
173 |
+
tgt_mask,
|
174 |
+
encoder_out: Optional[Dict[str, List[Tensor]]],
|
175 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
176 |
+
full_context_alignment: bool = False,
|
177 |
+
alignment_layer: Optional[int] = None,
|
178 |
+
alignment_heads: Optional[int] = None,
|
179 |
+
):
|
180 |
+
"""
|
181 |
+
Similar to *forward* but only return features.
|
182 |
+
|
183 |
+
Includes several features from "Jointly Learning to Align and
|
184 |
+
Translate with Transformer Models" (Garg et al., EMNLP 2019).
|
185 |
+
|
186 |
+
Args:
|
187 |
+
full_context_alignment (bool, optional): don't apply
|
188 |
+
auto-regressive mask to self-attention (default: False).
|
189 |
+
alignment_layer (int, optional): return mean alignment over
|
190 |
+
heads at this layer (default: last layer).
|
191 |
+
alignment_heads (int, optional): only average alignment over
|
192 |
+
this many heads (default: all heads).
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
tuple:
|
196 |
+
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
|
197 |
+
- a dictionary with any model-specific outputs
|
198 |
+
"""
|
199 |
+
bs = prev_output_tokens.size(0)
|
200 |
+
if alignment_layer is None:
|
201 |
+
alignment_layer = self.num_layers - 1
|
202 |
+
|
203 |
+
enc: Optional[Tensor] = None
|
204 |
+
padding_mask: Optional[Tensor] = None
|
205 |
+
if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
|
206 |
+
enc = encoder_out["encoder_out"][0]
|
207 |
+
assert (
|
208 |
+
enc.size()[1] == bs
|
209 |
+
), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
|
210 |
+
if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
|
211 |
+
padding_mask = encoder_out["encoder_padding_mask"][0]
|
212 |
+
|
213 |
+
# B x T x C -> T x B x C
|
214 |
+
x = prev_output_tokens.transpose(0, 1)
|
215 |
+
|
216 |
+
self_attn_padding_mask: Optional[Tensor] = None
|
217 |
+
if self.cross_self_attention or tgt_mask is not None:
|
218 |
+
self_attn_padding_mask = tgt_mask
|
219 |
+
|
220 |
+
## relative position embedding
|
221 |
+
if self.args.relative_position_embedding:
|
222 |
+
x_len = x.shape[0]
|
223 |
+
pos_seq = torch.arange(0, x_len).long().to(x.device)
|
224 |
+
pos_seq = pos_seq[:, None] - pos_seq[None, :]
|
225 |
+
pos_k, pos_v = self.pos_emb(pos_seq)
|
226 |
+
else:
|
227 |
+
pos_k = None
|
228 |
+
|
229 |
+
# decoder layers
|
230 |
+
attn_list = []
|
231 |
+
attn: Optional[Tensor] = None
|
232 |
+
inner_states: List[Optional[Tensor]] = [x]
|
233 |
+
for idx, layer in enumerate(self.layers):
|
234 |
+
if incremental_state is None and not full_context_alignment:
|
235 |
+
self_attn_mask = self.buffered_future_mask(x)
|
236 |
+
else:
|
237 |
+
self_attn_mask = None
|
238 |
+
|
239 |
+
x, layer_attn, _ = layer(
|
240 |
+
x,
|
241 |
+
enc,
|
242 |
+
padding_mask,
|
243 |
+
incremental_state,
|
244 |
+
self_attn_mask=self_attn_mask,
|
245 |
+
self_attn_padding_mask=self_attn_padding_mask,
|
246 |
+
need_attn=bool((idx == alignment_layer or alignment_layer == -1)),
|
247 |
+
need_head_weights=bool((idx == alignment_layer or alignment_layer == -1)),
|
248 |
+
pos_bias=pos_k,
|
249 |
+
)
|
250 |
+
inner_states.append(x)
|
251 |
+
if layer_attn is not None and (idx == alignment_layer or alignment_layer == -1):
|
252 |
+
attn = layer_attn.float().to(x)
|
253 |
+
attn_list.append(attn.transpose(0, 1))
|
254 |
+
|
255 |
+
if attn is not None and len(attn_list) == 1:
|
256 |
+
if alignment_heads is not None:
|
257 |
+
attn = attn[:alignment_heads]
|
258 |
+
|
259 |
+
# average probabilities over heads
|
260 |
+
attn = attn.mean(dim=0)
|
261 |
+
|
262 |
+
if self.layer_norm is not None:
|
263 |
+
x = self.layer_norm(x)
|
264 |
+
|
265 |
+
# T x B x C -> B x T x C
|
266 |
+
x = x.transpose(0, 1)
|
267 |
+
|
268 |
+
return x, {"attn": [attn if len(attn_list) <= 1 else attn_list], "inner_states": inner_states}
|
269 |
+
|
270 |
+
# def max_positions(self):
|
271 |
+
# """Maximum output length supported by the decoder."""
|
272 |
+
# return self.max_target_positions
|
273 |
+
|
274 |
+
def buffered_future_mask(self, tensor):
|
275 |
+
dim = tensor.size(0)
|
276 |
+
# self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
|
277 |
+
if (
|
278 |
+
self._future_mask.size(0) == 0
|
279 |
+
or (not self._future_mask.device == tensor.device)
|
280 |
+
or self._future_mask.size(0) < dim
|
281 |
+
):
|
282 |
+
self._future_mask = torch.triu(
|
283 |
+
utils.fill_with_neg_inf(torch.zeros([dim, dim], device=tensor.device)), 1,
|
284 |
+
)
|
285 |
+
else:
|
286 |
+
self._future_mask = self._future_mask.to(tensor)
|
287 |
+
return self._future_mask[:dim, :dim]
|
288 |
+
|
289 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
290 |
+
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
|
291 |
+
for i in range(self.num_layers):
|
292 |
+
# update layer norms
|
293 |
+
layer_norm_map = {
|
294 |
+
"0": "self_attn_layer_norm",
|
295 |
+
"1": "encoder_attn_layer_norm",
|
296 |
+
"2": "final_layer_norm",
|
297 |
+
}
|
298 |
+
for old, new in layer_norm_map.items():
|
299 |
+
for m in ("weight", "bias"):
|
300 |
+
k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
|
301 |
+
if k in state_dict:
|
302 |
+
state_dict[
|
303 |
+
"{}.layers.{}.{}.{}".format(name, i, new, m)
|
304 |
+
] = state_dict[k]
|
305 |
+
del state_dict[k]
|
306 |
+
|
307 |
+
version_key = "{}.version".format(name)
|
308 |
+
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
|
309 |
+
# earlier checkpoints did not normalize after the stack of layers
|
310 |
+
self.layer_norm = None
|
311 |
+
self.normalize = False
|
312 |
+
state_dict[version_key] = torch.Tensor([1])
|
313 |
+
|
314 |
+
return state_dict
|
315 |
+
|
316 |
+
def set_num_updates(self, num_updates):
|
317 |
+
"""State from trainer to pass along to model at every update."""
|
318 |
+
|
319 |
+
def _apply(m):
|
320 |
+
if hasattr(m, "set_num_updates") and m != self:
|
321 |
+
m.set_num_updates(num_updates)
|
322 |
+
|
323 |
+
self.apply(_apply)
|
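Note: the decoder above rebuilds its causal mask in buffered_future_mask by filling the strict upper triangle of a square matrix with -inf and caching it. The snippet below is an illustrative sketch of that idea, not part of this commit; the function name future_mask and the toy size are assumptions for the demo.

import torch

def future_mask(dim: int, device=None) -> torch.Tensor:
    # zeros on and below the diagonal, -inf strictly above it,
    # so position i can only attend to positions <= i (illustrative sketch)
    mask = torch.zeros(dim, dim, device=device)
    upper = torch.triu(torch.ones(dim, dim, dtype=torch.bool, device=device), diagonal=1)
    return mask.masked_fill(upper, float("-inf"))

print(future_mask(4))
# row 0 masks columns 1-3; row 3 masks nothing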
artst/models/modules/encoder.py
ADDED
@@ -0,0 +1,380 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
|
5 |
+
# Based on speecht5, fairseq and espnet code bases
|
6 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
from typing import Dict, List
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import contextlib
|
15 |
+
from fairseq import utils
|
16 |
+
from fairseq.models import (
|
17 |
+
FairseqEncoder,
|
18 |
+
)
|
19 |
+
from fairseq.modules import (
|
20 |
+
FairseqDropout,
|
21 |
+
LayerNorm,
|
22 |
+
TransformerEncoderLayer,
|
23 |
+
)
|
24 |
+
from torch import Tensor
|
25 |
+
from .transformer_layer import TransformerSentenceEncoderLayer
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
|
30 |
+
|
31 |
+
def Linear(in_features, out_features, bias=True):
|
32 |
+
m = nn.Linear(in_features, out_features, bias)
|
33 |
+
nn.init.xavier_uniform_(m.weight)
|
34 |
+
if bias:
|
35 |
+
nn.init.constant_(m.bias, 0.0)
|
36 |
+
return m
|
37 |
+
|
38 |
+
|
39 |
+
class RelativePositionalEncoding(torch.nn.Module):
|
40 |
+
def __init__(self, d_model, maxlen=1000, embed_v=False):
|
41 |
+
super(RelativePositionalEncoding, self).__init__()
|
42 |
+
|
43 |
+
self.d_model = d_model
|
44 |
+
self.maxlen = maxlen
|
45 |
+
self.pe_k = torch.nn.Embedding(2*maxlen, d_model)
|
46 |
+
if embed_v:
|
47 |
+
self.pe_v = torch.nn.Embedding(2*maxlen, d_model)
|
48 |
+
self.embed_v = embed_v
|
49 |
+
|
50 |
+
|
51 |
+
def forward(self, pos_seq):
|
52 |
+
pos_seq[pos_seq < -self.maxlen] = -self.maxlen
|
53 |
+
pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
|
54 |
+
pos_seq = pos_seq + self.maxlen
|
55 |
+
if self.embed_v:
|
56 |
+
return self.pe_k(pos_seq), self.pe_v(pos_seq)
|
57 |
+
else:
|
58 |
+
return self.pe_k(pos_seq), None
|
59 |
+
|
60 |
+
class TransformerEncoder(FairseqEncoder):
|
61 |
+
"""
|
62 |
+
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
|
63 |
+
is a :class:`TransformerEncoderLayer`.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
args (argparse.Namespace): parsed command-line arguments
|
67 |
+
dictionary (~fairseq.data.Dictionary): encoding dictionary
|
68 |
+
embed_tokens (torch.nn.Embedding): input embedding
|
69 |
+
"""
|
70 |
+
|
71 |
+
def __init__(self, args, tgt_dict=None, embed_tokens=None):
|
72 |
+
self.args = args
|
73 |
+
super().__init__(None)
|
74 |
+
self.register_buffer("version", torch.Tensor([3]))
|
75 |
+
|
76 |
+
self.dropout_module = FairseqDropout(
|
77 |
+
args.dropout, module_name=self.__class__.__name__
|
78 |
+
)
|
79 |
+
self.encoder_layerdrop = args.encoder_layerdrop
|
80 |
+
self.freeze_encoder_updates = args.freeze_encoder_updates
|
81 |
+
if args.no_freeze_encoder_layer is not None:
|
82 |
+
self.no_freeze_encoder_layer = eval(args.no_freeze_encoder_layer)
|
83 |
+
else:
|
84 |
+
self.no_freeze_encoder_layer = None
|
85 |
+
self.num_updates = 0
|
86 |
+
export = getattr(args, "export", False)
|
87 |
+
|
88 |
+
self.layers = nn.ModuleList([])
|
89 |
+
self.layers.extend(
|
90 |
+
[self.build_encoder_layer(args) for i in range(args.encoder_layers)]
|
91 |
+
)
|
92 |
+
self.num_layers = len(self.layers)
|
93 |
+
|
94 |
+
self.use_sent_enc_layer = args.use_sent_enc_layer
|
95 |
+
self.unb_enc_layer = getattr(args, "unb_enc_layer", -1)
|
96 |
+
|
97 |
+
self.layer_norm_first = args.layer_norm_first
|
98 |
+
self.layer_norm = LayerNorm(args.encoder_embed_dim, eps=args.layer_norm_eps, export=export)
|
99 |
+
|
100 |
+
if args.share_ctc_embed and embed_tokens is not None:
|
101 |
+
self.proj = nn.Linear(
|
102 |
+
embed_tokens.weight.shape[1],
|
103 |
+
embed_tokens.weight.shape[0],
|
104 |
+
bias=False,
|
105 |
+
)
|
106 |
+
self.proj.weight = embed_tokens.weight
|
107 |
+
elif tgt_dict is not None:
|
108 |
+
self.proj = Linear(args.encoder_embed_dim, len(tgt_dict))
|
109 |
+
else:
|
110 |
+
self.proj = None
|
111 |
+
|
112 |
+
if args.relative_position_embedding:
|
113 |
+
self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim//args.encoder_attention_heads, args.encoder_max_relative_position)
|
114 |
+
|
115 |
+
|
116 |
+
def build_encoder_layer(self, args):
|
117 |
+
if args.use_sent_enc_layer:
|
118 |
+
layer = TransformerSentenceEncoderLayer(
|
119 |
+
embedding_dim=args.encoder_embed_dim,
|
120 |
+
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
121 |
+
num_attention_heads=args.encoder_attention_heads,
|
122 |
+
dropout=args.dropout,
|
123 |
+
attention_dropout=args.attention_dropout,
|
124 |
+
activation_dropout=args.activation_dropout,
|
125 |
+
activation_fn=args.activation_fn,
|
126 |
+
layer_norm_first=args.layer_norm_first,
|
127 |
+
has_relative_attention_bias=args.relative_position_embedding,
|
128 |
+
)
|
129 |
+
else:
|
130 |
+
layer = TransformerEncoderLayer(args)
|
131 |
+
return layer
|
132 |
+
|
133 |
+
def forward(
|
134 |
+
self,
|
135 |
+
encoder_in,
|
136 |
+
encoder_padding_mask,
|
137 |
+
return_all_hiddens: bool = False,
|
138 |
+
tgt_layer=None,
|
139 |
+
):
|
140 |
+
"""
|
141 |
+
Args:
|
142 |
+
src_tokens (LongTensor): tokens in the source language of shape
|
143 |
+
`(batch, src_len)`
|
144 |
+
src_lengths (torch.LongTensor): lengths of each source sentence of
|
145 |
+
shape `(batch)`
|
146 |
+
return_all_hiddens (bool, optional): also return all of the
|
147 |
+
intermediate hidden states (default: False).
|
148 |
+
token_embeddings (torch.Tensor, optional): precomputed embeddings
|
149 |
+
default `None` will recompute embeddings
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
dict:
|
153 |
+
- **encoder_out** (Tensor): the last encoder layer's output of
|
154 |
+
shape `(src_len, batch, embed_dim)`
|
155 |
+
- **encoder_padding_mask** (ByteTensor): the positions of
|
156 |
+
padding elements of shape `(batch, src_len)`
|
157 |
+
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
|
158 |
+
of shape `(batch, src_len, embed_dim)`
|
159 |
+
- **encoder_states** (List[Tensor]): all intermediate
|
160 |
+
hidden states of shape `(src_len, batch, embed_dim)`.
|
161 |
+
Only populated if *return_all_hiddens* is True.
|
162 |
+
"""
|
163 |
+
if self.no_freeze_encoder_layer is None:
|
164 |
+
ft = self.freeze_encoder_updates <= self.num_updates
|
165 |
+
else:
|
166 |
+
ft = True
|
167 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
168 |
+
encoder_out = self.forward_scriptable(
|
169 |
+
encoder_in, encoder_padding_mask, return_all_hiddens, tgt_layer=tgt_layer,
|
170 |
+
)
|
171 |
+
|
172 |
+
# CTC and bert
|
173 |
+
if self.proj:
|
174 |
+
x_for_ctc = self.proj(self.dropout_module(encoder_out["encoder_out"][0]))
|
175 |
+
else:
|
176 |
+
x_for_ctc = None
|
177 |
+
|
178 |
+
encoder_out["encoder_out_for_ctc"] = [x_for_ctc] # T x B x C
|
179 |
+
|
180 |
+
return encoder_out
|
181 |
+
|
182 |
+
# TorchScript does not support super(), so a scriptable subclass
|
183 |
+
# cannot access the base-class implementation in TorchScript.
|
184 |
+
# The current workaround is to add a helper function with a different name
|
185 |
+
# and call that helper from the scriptable subclass.
|
186 |
+
def forward_scriptable(
|
187 |
+
self,
|
188 |
+
encoder_in,
|
189 |
+
encoder_padding_mask,
|
190 |
+
return_all_hiddens: bool = False,
|
191 |
+
tgt_layer=None,
|
192 |
+
):
|
193 |
+
"""
|
194 |
+
Args:
|
195 |
+
src_tokens (LongTensor): tokens in the source language of shape
|
196 |
+
`(batch, src_len)`
|
197 |
+
src_lengths (torch.LongTensor): lengths of each source sentence of
|
198 |
+
shape `(batch)`
|
199 |
+
return_all_hiddens (bool, optional): also return all of the
|
200 |
+
intermediate hidden states (default: False).
|
201 |
+
token_embeddings (torch.Tensor, optional): precomputed embeddings
|
202 |
+
default `None` will recompute embeddings
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
dict:
|
206 |
+
- **encoder_out** (Tensor): the last encoder layer's output of
|
207 |
+
shape `(src_len, batch, embed_dim)`
|
208 |
+
- **encoder_padding_mask** (ByteTensor): the positions of
|
209 |
+
padding elements of shape `(batch, src_len)`
|
210 |
+
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
|
211 |
+
of shape `(batch, src_len, embed_dim)`
|
212 |
+
- **encoder_states** (List[Tensor]): all intermediate
|
213 |
+
hidden states of shape `(src_len, batch, embed_dim)`.
|
214 |
+
Only populated if *return_all_hiddens* is True.
|
215 |
+
"""
|
216 |
+
if self.no_freeze_encoder_layer is not None:
|
217 |
+
ft = self.freeze_encoder_updates <= self.num_updates
|
218 |
+
else:
|
219 |
+
ft = True
|
220 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
221 |
+
# compute padding mask
|
222 |
+
if not self.use_sent_enc_layer:
|
223 |
+
has_pads = encoder_in.device.type == "xla" or encoder_padding_mask.any()
|
224 |
+
|
225 |
+
if not self.layer_norm_first:
|
226 |
+
encoder_in = self.layer_norm(encoder_in)
|
227 |
+
|
228 |
+
encoder_in = self.dropout_module(encoder_in)
|
229 |
+
|
230 |
+
# B x T x C -> T x B x C
|
231 |
+
x = encoder_in.transpose(0, 1)
|
232 |
+
|
233 |
+
encoder_states = []
|
234 |
+
|
235 |
+
if return_all_hiddens:
|
236 |
+
encoder_states.append(x)
|
237 |
+
|
238 |
+
## relative position embedding
|
239 |
+
if self.args.relative_position_embedding:
|
240 |
+
x_len = x.shape[0]
|
241 |
+
pos_seq = torch.arange(0, x_len).long().to(x.device)
|
242 |
+
pos_seq = pos_seq[:, None] - pos_seq[None, :]
|
243 |
+
pos_k, pos_v = self.pos_emb(pos_seq)
|
244 |
+
else:
|
245 |
+
pos_k = None
|
246 |
+
|
247 |
+
# encoder layers
|
248 |
+
r = None
|
249 |
+
d = None
|
250 |
+
for i, layer in enumerate(self.layers):
|
251 |
+
dropout_probability = np.random.random()
|
252 |
+
|
253 |
+
with torch.no_grad() if (not ft) and i not in self.no_freeze_encoder_layer else contextlib.ExitStack():
|
254 |
+
if not self.training or (dropout_probability > self.encoder_layerdrop) or i == self.unb_enc_layer:
|
255 |
+
if self.use_sent_enc_layer:
|
256 |
+
x, _ = layer(x, self_attn_padding_mask=encoder_padding_mask, self_attn_mask=None, need_weights=False, pos_bias=pos_k)
|
257 |
+
# x, _ = layer(x, self_attn_padding_mask=encoder_padding_mask, need_weights=False, pos_bias=pos_k)
|
258 |
+
else:
|
259 |
+
x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None, attn_mask=None)
|
260 |
+
# x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None)
|
261 |
+
if i == self.unb_enc_layer:
|
262 |
+
d = x
|
263 |
+
|
264 |
+
if i == tgt_layer:
|
265 |
+
r = x
|
266 |
+
break
|
267 |
+
|
268 |
+
if return_all_hiddens:
|
269 |
+
assert encoder_states is not None
|
270 |
+
encoder_states.append(x)
|
271 |
+
|
272 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
273 |
+
# Finally T x B x C
|
274 |
+
if self.layer_norm_first:
|
275 |
+
x = self.layer_norm(x.transpose(0, 1)).transpose(0, 1)
|
276 |
+
|
277 |
+
if r is not None:
|
278 |
+
x = r
|
279 |
+
|
280 |
+
# The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
|
281 |
+
# `forward` so we use a dictionary instead.
|
282 |
+
# TorchScript does not support mixed values so the values are all lists.
|
283 |
+
# The empty list is equivalent to None.
|
284 |
+
return {
|
285 |
+
"encoder_out": [x], # T x B x C
|
286 |
+
"encoder_padding_mask": [encoder_padding_mask], # B x T
|
287 |
+
"encoder_states": encoder_states, # List[T x B x C]
|
288 |
+
"src_tokens": [],
|
289 |
+
"decoder_input": [d],
|
290 |
+
}
|
291 |
+
|
292 |
+
@torch.jit.export
|
293 |
+
def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
|
294 |
+
"""
|
295 |
+
Reorder encoder output according to *new_order*.
|
296 |
+
|
297 |
+
Args:
|
298 |
+
encoder_out: output from the ``forward()`` method
|
299 |
+
new_order (LongTensor): desired order
|
300 |
+
|
301 |
+
Returns:
|
302 |
+
*encoder_out* rearranged according to *new_order*
|
303 |
+
"""
|
304 |
+
if len(encoder_out["encoder_out"]) == 0:
|
305 |
+
new_encoder_out = []
|
306 |
+
else:
|
307 |
+
new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
|
308 |
+
|
309 |
+
if len(encoder_out["encoder_out_for_ctc"]) == 0:
|
310 |
+
new_x_for_ctc = []
|
311 |
+
else:
|
312 |
+
new_x_for_ctc = [encoder_out["encoder_out_for_ctc"][0].index_select(1, new_order)]
|
313 |
+
|
314 |
+
if len(encoder_out["encoder_padding_mask"]) == 0:
|
315 |
+
new_encoder_padding_mask = []
|
316 |
+
else:
|
317 |
+
new_encoder_padding_mask = [
|
318 |
+
encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
|
319 |
+
]
|
320 |
+
|
321 |
+
if len(encoder_out["src_tokens"]) == 0:
|
322 |
+
src_tokens = []
|
323 |
+
else:
|
324 |
+
src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
|
325 |
+
|
326 |
+
if len(encoder_out["decoder_input"]) == 0 or encoder_out["decoder_input"][0] is None:
|
327 |
+
new_decoder_input = []
|
328 |
+
else:
|
329 |
+
new_decoder_input = [
|
330 |
+
encoder_out["decoder_input"][0].index_select(0, new_order)
|
331 |
+
]
|
332 |
+
|
333 |
+
encoder_states = encoder_out["encoder_states"]
|
334 |
+
if len(encoder_states) > 0:
|
335 |
+
for idx, state in enumerate(encoder_states):
|
336 |
+
encoder_states[idx] = state.index_select(1, new_order)
|
337 |
+
|
338 |
+
return {
|
339 |
+
"encoder_out": new_encoder_out, # T x B x C
|
340 |
+
"encoder_padding_mask": new_encoder_padding_mask, # B x T
|
341 |
+
"encoder_states": encoder_states, # List[T x B x C]
|
342 |
+
"src_tokens": src_tokens, # B x T
|
343 |
+
"encoder_out_for_ctc": new_x_for_ctc, # T x B x C
|
344 |
+
"decoder_input": new_decoder_input,
|
345 |
+
}
|
346 |
+
|
347 |
+
# def max_positions(self):
|
348 |
+
# """Maximum input length supported by the encoder."""
|
349 |
+
# return self.max_source_positions
|
350 |
+
|
351 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
352 |
+
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
|
353 |
+
# if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
|
354 |
+
# weights_key = "{}.embed_positions.weights".format(name)
|
355 |
+
# if weights_key in state_dict:
|
356 |
+
# print("deleting {0}".format(weights_key))
|
357 |
+
# del state_dict[weights_key]
|
358 |
+
# state_dict[
|
359 |
+
# "{}.embed_positions._float_tensor".format(name)
|
360 |
+
# ] = torch.FloatTensor(1)
|
361 |
+
for i in range(self.num_layers):
|
362 |
+
# update layer norms
|
363 |
+
if not isinstance(self.layers[i], TransformerSentenceEncoderLayer):
|
364 |
+
self.layers[i].upgrade_state_dict_named(
|
365 |
+
state_dict, "{}.layers.{}".format(name, i)
|
366 |
+
)
|
367 |
+
|
368 |
+
version_key = "{}.version".format(name)
|
369 |
+
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
|
370 |
+
# earlier checkpoints did not normalize after the stack of layers
|
371 |
+
self.layer_norm = None
|
372 |
+
self.normalize = False
|
373 |
+
state_dict[version_key] = torch.Tensor([1])
|
374 |
+
return state_dict
|
375 |
+
|
376 |
+
def set_num_updates(self, num_updates):
|
377 |
+
"""Set the number of parameters updates."""
|
378 |
+
super().set_num_updates(num_updates)
|
379 |
+
self.num_updates = num_updates
|
380 |
+
|
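Note: an illustrative sketch, not part of this commit, of how RelativePositionalEncoding above maps a sequence to a matrix of clipped relative offsets and looks up one key-bias vector per (query, key) pair; the toy sizes below are assumptions for the demo.

import torch

maxlen, head_dim, seq_len = 8, 4, 5                    # toy sizes (assumed)
pe_k = torch.nn.Embedding(2 * maxlen, head_dim)        # one vector per clipped offset

pos_seq = torch.arange(seq_len)
rel = pos_seq[:, None] - pos_seq[None, :]              # relative offsets in [-(seq_len-1), seq_len-1]
rel = rel.clamp(-maxlen, maxlen - 1) + maxlen          # shift into [0, 2*maxlen) for the embedding table
pos_k = pe_k(rel)                                      # (seq_len, seq_len, head_dim) key biases
print(pos_k.shape)                                     # torch.Size([5, 5, 4])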
artst/models/modules/multihead_attention.py
ADDED
@@ -0,0 +1,525 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
|
5 |
+
# Based on speecht5, fairseq and espnet code bases
|
6 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import math
|
10 |
+
from typing import Dict, Optional, Tuple
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from fairseq import utils
|
15 |
+
from fairseq.incremental_decoding_utils import with_incremental_state
|
16 |
+
from fairseq.modules.fairseq_dropout import FairseqDropout
|
17 |
+
from fairseq.modules.quant_noise import quant_noise
|
18 |
+
from torch import Tensor, nn
|
19 |
+
from torch.nn import Parameter
|
20 |
+
|
21 |
+
|
22 |
+
@with_incremental_state
|
23 |
+
class MultiheadAttention(nn.Module):
|
24 |
+
"""Multi-headed attention.
|
25 |
+
|
26 |
+
See "Attention Is All You Need" for more details.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
embed_dim,
|
32 |
+
num_heads,
|
33 |
+
kdim=None,
|
34 |
+
vdim=None,
|
35 |
+
dropout=0.0,
|
36 |
+
bias=True,
|
37 |
+
add_bias_kv=False,
|
38 |
+
add_zero_attn=False,
|
39 |
+
self_attention=False,
|
40 |
+
encoder_decoder_attention=False,
|
41 |
+
q_noise=0.0,
|
42 |
+
qn_block_size=8,
|
43 |
+
has_relative_attention_bias=False,
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
self.embed_dim = embed_dim
|
47 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
48 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
49 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
50 |
+
|
51 |
+
self.num_heads = num_heads
|
52 |
+
self.dropout_module = FairseqDropout(
|
53 |
+
dropout, module_name=self.__class__.__name__
|
54 |
+
)
|
55 |
+
|
56 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
57 |
+
self.head_dim = embed_dim // num_heads
|
58 |
+
assert (
|
59 |
+
self.head_dim * num_heads == self.embed_dim
|
60 |
+
), "embed_dim must be divisible by num_heads"
|
61 |
+
self.scaling = self.head_dim ** -0.5
|
62 |
+
|
63 |
+
self.self_attention = self_attention
|
64 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
65 |
+
|
66 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
67 |
+
"Self-attention requires query, key and " "value to be of the same size"
|
68 |
+
)
|
69 |
+
|
70 |
+
self.k_proj = quant_noise(
|
71 |
+
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
|
72 |
+
)
|
73 |
+
self.v_proj = quant_noise(
|
74 |
+
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
75 |
+
)
|
76 |
+
self.q_proj = quant_noise(
|
77 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
78 |
+
)
|
79 |
+
|
80 |
+
self.out_proj = quant_noise(
|
81 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
82 |
+
)
|
83 |
+
|
84 |
+
if add_bias_kv:
|
85 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
86 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
87 |
+
else:
|
88 |
+
self.bias_k = self.bias_v = None
|
89 |
+
|
90 |
+
self.add_zero_attn = add_zero_attn
|
91 |
+
|
92 |
+
self.reset_parameters()
|
93 |
+
|
94 |
+
self.onnx_trace = False
|
95 |
+
|
96 |
+
def prepare_for_onnx_export_(self):
|
97 |
+
self.onnx_trace = True
|
98 |
+
|
99 |
+
def reset_parameters(self):
|
100 |
+
if self.qkv_same_dim:
|
101 |
+
# Empirically observed the convergence to be much better with
|
102 |
+
# the scaled initialization
|
103 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
104 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
105 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
106 |
+
else:
|
107 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
108 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
109 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
110 |
+
|
111 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
112 |
+
if self.out_proj.bias is not None:
|
113 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
114 |
+
if self.bias_k is not None:
|
115 |
+
nn.init.xavier_normal_(self.bias_k)
|
116 |
+
if self.bias_v is not None:
|
117 |
+
nn.init.xavier_normal_(self.bias_v)
|
118 |
+
|
119 |
+
def forward(
|
120 |
+
self,
|
121 |
+
query,
|
122 |
+
key: Optional[Tensor],
|
123 |
+
value: Optional[Tensor],
|
124 |
+
key_padding_mask: Optional[Tensor] = None,
|
125 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
126 |
+
need_weights: bool = True,
|
127 |
+
static_kv: bool = False,
|
128 |
+
attn_mask: Optional[Tensor] = None,
|
129 |
+
before_softmax: bool = False,
|
130 |
+
need_head_weights: bool = False,
|
131 |
+
position_bias: Optional[Tensor] = None
|
132 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
133 |
+
"""Input shape: Time x Batch x Channel
|
134 |
+
|
135 |
+
Args:
|
136 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
137 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
138 |
+
padding elements are indicated by 1s.
|
139 |
+
need_weights (bool, optional): return the attention weights,
|
140 |
+
averaged over heads (default: False).
|
141 |
+
attn_mask (ByteTensor, optional): typically used to
|
142 |
+
implement causal attention, where the mask prevents the
|
143 |
+
attention from looking forward in time (default: None).
|
144 |
+
before_softmax (bool, optional): return the raw attention
|
145 |
+
weights and values before the attention softmax.
|
146 |
+
need_head_weights (bool, optional): return the attention
|
147 |
+
weights for each head. Implies *need_weights*. Default:
|
148 |
+
return the average attention weights over all heads.
|
149 |
+
"""
|
150 |
+
if need_head_weights:
|
151 |
+
need_weights = True
|
152 |
+
|
153 |
+
is_tpu = query.device.type == "xla"
|
154 |
+
|
155 |
+
tgt_len, bsz, embed_dim = query.size()
|
156 |
+
src_len = tgt_len
|
157 |
+
assert embed_dim == self.embed_dim
|
158 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
159 |
+
if key is not None:
|
160 |
+
src_len, key_bsz, _ = key.size()
|
161 |
+
if not torch.jit.is_scripting():
|
162 |
+
assert key_bsz == bsz
|
163 |
+
assert value is not None
|
164 |
+
assert src_len, bsz == value.shape[:2]
|
165 |
+
|
166 |
+
if (
|
167 |
+
not self.onnx_trace
|
168 |
+
and not is_tpu # don't use PyTorch version on TPUs
|
169 |
+
and incremental_state is None
|
170 |
+
and not static_kv
|
171 |
+
# A workaround for quantization to work. Otherwise JIT compilation
|
172 |
+
# treats bias in linear module as method.
|
173 |
+
and not torch.jit.is_scripting()
|
174 |
+
and not self.has_relative_attention_bias
|
175 |
+
):
|
176 |
+
assert key is not None and value is not None
|
177 |
+
# Hawau:
|
178 |
+
if query.dtype != attn_mask.dtype:
|
179 |
+
attn_mask = attn_mask.type(query.dtype)
|
180 |
+
# My code ends here
|
181 |
+
return F.multi_head_attention_forward(
|
182 |
+
query,
|
183 |
+
key,
|
184 |
+
value,
|
185 |
+
self.embed_dim,
|
186 |
+
self.num_heads,
|
187 |
+
torch.empty([0]),
|
188 |
+
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
|
189 |
+
self.bias_k,
|
190 |
+
self.bias_v,
|
191 |
+
self.add_zero_attn,
|
192 |
+
self.dropout_module.p,
|
193 |
+
self.out_proj.weight,
|
194 |
+
self.out_proj.bias,
|
195 |
+
self.training or self.dropout_module.apply_during_inference,
|
196 |
+
key_padding_mask,
|
197 |
+
need_weights,
|
198 |
+
attn_mask,
|
199 |
+
use_separate_proj_weight=True,
|
200 |
+
q_proj_weight=self.q_proj.weight,
|
201 |
+
k_proj_weight=self.k_proj.weight,
|
202 |
+
v_proj_weight=self.v_proj.weight,
|
203 |
+
)
|
204 |
+
|
205 |
+
if incremental_state is not None:
|
206 |
+
saved_state = self._get_input_buffer(incremental_state)
|
207 |
+
if saved_state is not None and "prev_key" in saved_state:
|
208 |
+
# previous time steps are cached - no need to recompute
|
209 |
+
# key and value if they are static
|
210 |
+
if static_kv:
|
211 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
212 |
+
key = value = None
|
213 |
+
else:
|
214 |
+
saved_state = None
|
215 |
+
|
216 |
+
if self.self_attention:
|
217 |
+
q = self.q_proj(query)
|
218 |
+
k = self.k_proj(query)
|
219 |
+
v = self.v_proj(query)
|
220 |
+
elif self.encoder_decoder_attention:
|
221 |
+
# encoder-decoder attention
|
222 |
+
q = self.q_proj(query)
|
223 |
+
if key is None:
|
224 |
+
assert value is None
|
225 |
+
k = v = None
|
226 |
+
else:
|
227 |
+
k = self.k_proj(key)
|
228 |
+
v = self.v_proj(key)
|
229 |
+
|
230 |
+
else:
|
231 |
+
assert key is not None and value is not None
|
232 |
+
q = self.q_proj(query)
|
233 |
+
k = self.k_proj(key)
|
234 |
+
v = self.v_proj(value)
|
235 |
+
q *= self.scaling
|
236 |
+
|
237 |
+
if self.bias_k is not None:
|
238 |
+
assert self.bias_v is not None
|
239 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
240 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
241 |
+
if attn_mask is not None:
|
242 |
+
attn_mask = torch.cat(
|
243 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
244 |
+
)
|
245 |
+
if key_padding_mask is not None:
|
246 |
+
key_padding_mask = torch.cat(
|
247 |
+
[
|
248 |
+
key_padding_mask,
|
249 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
250 |
+
],
|
251 |
+
dim=1,
|
252 |
+
)
|
253 |
+
|
254 |
+
q = (
|
255 |
+
q.contiguous()
|
256 |
+
.view(tgt_len, bsz * self.num_heads, self.head_dim)
|
257 |
+
.transpose(0, 1)
|
258 |
+
)
|
259 |
+
if k is not None:
|
260 |
+
k = (
|
261 |
+
k.contiguous()
|
262 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
263 |
+
.transpose(0, 1)
|
264 |
+
)
|
265 |
+
if v is not None:
|
266 |
+
v = (
|
267 |
+
v.contiguous()
|
268 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
269 |
+
.transpose(0, 1)
|
270 |
+
)
|
271 |
+
|
272 |
+
if saved_state is not None:
|
273 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
274 |
+
if "prev_key" in saved_state:
|
275 |
+
_prev_key = saved_state["prev_key"]
|
276 |
+
assert _prev_key is not None
|
277 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
278 |
+
if static_kv:
|
279 |
+
k = prev_key
|
280 |
+
else:
|
281 |
+
assert k is not None
|
282 |
+
k = torch.cat([prev_key, k], dim=1)
|
283 |
+
src_len = k.size(1)
|
284 |
+
if "prev_value" in saved_state:
|
285 |
+
_prev_value = saved_state["prev_value"]
|
286 |
+
assert _prev_value is not None
|
287 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
288 |
+
if static_kv:
|
289 |
+
v = prev_value
|
290 |
+
else:
|
291 |
+
assert v is not None
|
292 |
+
v = torch.cat([prev_value, v], dim=1)
|
293 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
294 |
+
if "prev_key_padding_mask" in saved_state:
|
295 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
296 |
+
assert k is not None and v is not None
|
297 |
+
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
298 |
+
key_padding_mask=key_padding_mask,
|
299 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
300 |
+
batch_size=bsz,
|
301 |
+
src_len=k.size(1),
|
302 |
+
static_kv=static_kv,
|
303 |
+
)
|
304 |
+
|
305 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
306 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
307 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
308 |
+
# In this branch incremental_state is never None
|
309 |
+
assert incremental_state is not None
|
310 |
+
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
311 |
+
assert k is not None
|
312 |
+
assert k.size(1) == src_len
|
313 |
+
|
314 |
+
# This is part of a workaround to get around fork/join parallelism
|
315 |
+
# not supporting Optional types.
|
316 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
317 |
+
key_padding_mask = None
|
318 |
+
|
319 |
+
if key_padding_mask is not None:
|
320 |
+
assert key_padding_mask.size(0) == bsz
|
321 |
+
assert key_padding_mask.size(1) == src_len
|
322 |
+
|
323 |
+
if self.add_zero_attn:
|
324 |
+
assert v is not None
|
325 |
+
src_len += 1
|
326 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
327 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
328 |
+
if attn_mask is not None:
|
329 |
+
attn_mask = torch.cat(
|
330 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
331 |
+
)
|
332 |
+
if key_padding_mask is not None:
|
333 |
+
key_padding_mask = torch.cat(
|
334 |
+
[
|
335 |
+
key_padding_mask,
|
336 |
+
torch.zeros(key_padding_mask.size(0), 1).type_as(
|
337 |
+
key_padding_mask
|
338 |
+
),
|
339 |
+
],
|
340 |
+
dim=1,
|
341 |
+
)
|
342 |
+
|
343 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
344 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
345 |
+
|
346 |
+
if position_bias is not None and self.has_relative_attention_bias: ## first order
|
347 |
+
## position_bias: [241, 241, 64]
|
348 |
+
#print ("attn_weights: ", attn_weights.size()) # [492, 241, 241]
|
349 |
+
reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0,1) #[241, 492, 64]
|
350 |
+
#print ("reshape_q: ", reshape_q.size())
|
351 |
+
B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
|
352 |
+
#print ("B: ", B.size()) ## [241, 492, 241]
|
353 |
+
#B = B.transpose(0, 1).view(bsz, self.num_heads, position_bias.size(0), position_bias.size(1))
|
354 |
+
B = B.transpose(0, 1).view(bsz*self.num_heads, position_bias.size(0), position_bias.size(1))
|
355 |
+
#print ("B 2: ", B.size())
|
356 |
+
attn_weights += B
|
357 |
+
else:
|
358 |
+
position_bias = None
|
359 |
+
|
360 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
361 |
+
|
362 |
+
if attn_mask is not None:
|
363 |
+
attn_mask = attn_mask.unsqueeze(0)
|
364 |
+
if self.onnx_trace:
|
365 |
+
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
|
366 |
+
attn_weights += attn_mask
|
367 |
+
|
368 |
+
if key_padding_mask is not None:
|
369 |
+
# don't attend to padding symbols
|
370 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
371 |
+
if not is_tpu:
|
372 |
+
attn_weights = attn_weights.masked_fill(
|
373 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
374 |
+
float("-inf"),
|
375 |
+
)
|
376 |
+
else:
|
377 |
+
attn_weights = attn_weights.transpose(0, 2)
|
378 |
+
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
379 |
+
attn_weights = attn_weights.transpose(0, 2)
|
380 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
381 |
+
|
382 |
+
if before_softmax:
|
383 |
+
return attn_weights, v
|
384 |
+
|
385 |
+
attn_weights_float = utils.softmax(
|
386 |
+
attn_weights, dim=-1, onnx_trace=self.onnx_trace
|
387 |
+
)
|
388 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
389 |
+
attn_probs = self.dropout_module(attn_weights)
|
390 |
+
|
391 |
+
assert v is not None
|
392 |
+
attn = torch.bmm(attn_probs, v)
|
393 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
394 |
+
if self.onnx_trace and attn.size(1) == 1:
|
395 |
+
# when ONNX tracing a single decoder step (sequence length == 1)
|
396 |
+
# the transpose is a no-op copy before view, thus unnecessary
|
397 |
+
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
|
398 |
+
else:
|
399 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
400 |
+
attn = self.out_proj(attn)
|
401 |
+
attn_weights: Optional[Tensor] = None
|
402 |
+
if need_weights:
|
403 |
+
attn_weights = attn_weights_float.view(
|
404 |
+
bsz, self.num_heads, tgt_len, src_len
|
405 |
+
).transpose(1, 0)
|
406 |
+
if not need_head_weights:
|
407 |
+
# average attention weights over heads
|
408 |
+
attn_weights = attn_weights.mean(dim=0)
|
409 |
+
|
410 |
+
return attn, attn_weights
|
411 |
+
|
412 |
+
@staticmethod
|
413 |
+
def _append_prev_key_padding_mask(
|
414 |
+
key_padding_mask: Optional[Tensor],
|
415 |
+
prev_key_padding_mask: Optional[Tensor],
|
416 |
+
batch_size: int,
|
417 |
+
src_len: int,
|
418 |
+
static_kv: bool,
|
419 |
+
) -> Optional[Tensor]:
|
420 |
+
# saved key padding masks have shape (bsz, seq_len)
|
421 |
+
if prev_key_padding_mask is not None and static_kv:
|
422 |
+
new_key_padding_mask = prev_key_padding_mask
|
423 |
+
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
424 |
+
new_key_padding_mask = torch.cat(
|
425 |
+
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
426 |
+
)
|
427 |
+
# During incremental decoding, as the padding token enters and
|
428 |
+
# leaves the frame, there will be a time when prev or current
|
429 |
+
# is None
|
430 |
+
elif prev_key_padding_mask is not None:
|
431 |
+
if src_len > prev_key_padding_mask.size(1):
|
432 |
+
filler = torch.zeros(
|
433 |
+
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
434 |
+
device=prev_key_padding_mask.device,
|
435 |
+
)
|
436 |
+
new_key_padding_mask = torch.cat(
|
437 |
+
[prev_key_padding_mask.float(), filler.float()], dim=1
|
438 |
+
)
|
439 |
+
else:
|
440 |
+
new_key_padding_mask = prev_key_padding_mask.float()
|
441 |
+
elif key_padding_mask is not None:
|
442 |
+
if src_len > key_padding_mask.size(1):
|
443 |
+
filler = torch.zeros(
|
444 |
+
(batch_size, src_len - key_padding_mask.size(1)),
|
445 |
+
device=key_padding_mask.device,
|
446 |
+
)
|
447 |
+
new_key_padding_mask = torch.cat(
|
448 |
+
[filler.float(), key_padding_mask.float()], dim=1
|
449 |
+
)
|
450 |
+
else:
|
451 |
+
new_key_padding_mask = key_padding_mask.float()
|
452 |
+
else:
|
453 |
+
new_key_padding_mask = prev_key_padding_mask
|
454 |
+
return new_key_padding_mask
|
455 |
+
|
456 |
+
@torch.jit.export
|
457 |
+
def reorder_incremental_state(
|
458 |
+
self,
|
459 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
460 |
+
new_order: Tensor,
|
461 |
+
):
|
462 |
+
"""Reorder buffered internal state (for incremental generation)."""
|
463 |
+
input_buffer = self._get_input_buffer(incremental_state)
|
464 |
+
if input_buffer is not None:
|
465 |
+
for k in input_buffer.keys():
|
466 |
+
input_buffer_k = input_buffer[k]
|
467 |
+
if input_buffer_k is not None:
|
468 |
+
if self.encoder_decoder_attention and input_buffer_k.size(
|
469 |
+
0
|
470 |
+
) == new_order.size(0):
|
471 |
+
break
|
472 |
+
input_buffer[k] = input_buffer_k.index_select(0, new_order)
|
473 |
+
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
|
474 |
+
return incremental_state
|
475 |
+
|
476 |
+
def _get_input_buffer(
|
477 |
+
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
478 |
+
) -> Dict[str, Optional[Tensor]]:
|
479 |
+
result = self.get_incremental_state(incremental_state, "attn_state")
|
480 |
+
if result is not None:
|
481 |
+
return result
|
482 |
+
else:
|
483 |
+
empty_result: Dict[str, Optional[Tensor]] = {}
|
484 |
+
return empty_result
|
485 |
+
|
486 |
+
def _set_input_buffer(
|
487 |
+
self,
|
488 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
489 |
+
buffer: Dict[str, Optional[Tensor]],
|
490 |
+
):
|
491 |
+
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
492 |
+
|
493 |
+
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
494 |
+
return attn_weights
|
495 |
+
|
496 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
497 |
+
prefix = name + "." if name != "" else ""
|
498 |
+
items_to_add = {}
|
499 |
+
keys_to_remove = []
|
500 |
+
for k in state_dict.keys():
|
501 |
+
if k.endswith(prefix + "in_proj_weight"):
|
502 |
+
# in_proj_weight used to be q + k + v with same dimensions
|
503 |
+
dim = int(state_dict[k].shape[0] / 3)
|
504 |
+
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
|
505 |
+
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
|
506 |
+
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
|
507 |
+
|
508 |
+
keys_to_remove.append(k)
|
509 |
+
|
510 |
+
k_bias = prefix + "in_proj_bias"
|
511 |
+
if k_bias in state_dict.keys():
|
512 |
+
dim = int(state_dict[k].shape[0] / 3)
|
513 |
+
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
|
514 |
+
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
|
515 |
+
dim : 2 * dim
|
516 |
+
]
|
517 |
+
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
|
518 |
+
|
519 |
+
keys_to_remove.append(prefix + "in_proj_bias")
|
520 |
+
|
521 |
+
for k in keys_to_remove:
|
522 |
+
del state_dict[k]
|
523 |
+
|
524 |
+
for key, value in items_to_add.items():
|
525 |
+
state_dict[key] = value
|
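Note: an illustrative sketch, not part of this commit, of the relative-attention-bias term that MultiheadAttention.forward above adds to attn_weights when position_bias is given; the toy sizes and random tensors are assumptions for the demo.

import torch

bsz, heads, tgt_len, head_dim = 2, 4, 5, 16              # toy sizes (assumed)
q = torch.randn(bsz * heads, tgt_len, head_dim)          # per-head queries (already scaled)
k = torch.randn(bsz * heads, tgt_len, head_dim)          # per-head keys; self-attention so src_len == tgt_len
pos_bias = torch.randn(tgt_len, tgt_len, head_dim)       # pos_k from RelativePositionalEncoding

content = torch.bmm(q, k.transpose(1, 2))                # (bsz*heads, tgt_len, tgt_len) content scores

reshape_q = q.transpose(0, 1)                            # (tgt_len, bsz*heads, head_dim)
B = torch.matmul(reshape_q, pos_bias.transpose(-2, -1))  # (tgt_len, bsz*heads, tgt_len)
B = B.transpose(0, 1)                                    # (bsz*heads, tgt_len, tgt_len)

attn_weights = content + B                               # scores = content term + relative bias
print(attn_weights.shape)                                # torch.Size([8, 5, 5])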
artst/models/modules/speaker_decoder_postnet.py
ADDED
@@ -0,0 +1,196 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
|
5 |
+
# Based on speecht5, fairseq and espnet code bases
|
6 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import torch.nn as nn
|
10 |
+
import math
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
|
15 |
+
class AngularMargin(nn.Module):
|
16 |
+
"""
|
17 |
+
An implementation of Angular Margin (AM) proposed in the following
|
18 |
+
paper: '''Margin Matters: Towards More Discriminative Deep Neural Network
|
19 |
+
Embeddings for Speaker Recognition''' (https://arxiv.org/abs/1906.07317)
|
20 |
+
|
21 |
+
Arguments
|
22 |
+
---------
|
23 |
+
margin : float
|
24 |
+
The margin for cosine similarity
|
25 |
+
scale : float
|
26 |
+
The scale for cosine similarity
|
27 |
+
|
28 |
+
Return
|
29 |
+
---------
|
30 |
+
predictions : torch.Tensor
|
31 |
+
|
32 |
+
Example
|
33 |
+
-------
|
34 |
+
>>> pred = AngularMargin()
|
35 |
+
>>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
|
36 |
+
>>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ])
|
37 |
+
>>> predictions = pred(outputs, targets)
|
38 |
+
>>> predictions[:,0] > predictions[:,1]
|
39 |
+
tensor([ True, False, True, False])
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self, margin=0.0, scale=1.0):
|
43 |
+
super(AngularMargin, self).__init__()
|
44 |
+
self.margin = margin
|
45 |
+
self.scale = scale
|
46 |
+
|
47 |
+
def forward(self, outputs, targets):
|
48 |
+
"""Compute AM between two tensors
|
49 |
+
|
50 |
+
Arguments
|
51 |
+
---------
|
52 |
+
outputs : torch.Tensor
|
53 |
+
The outputs of shape [N, C], cosine similarity is required.
|
54 |
+
targets : torch.Tensor
|
55 |
+
The targets of shape [N, C], where the margin is applied for.
|
56 |
+
|
57 |
+
Return
|
58 |
+
---------
|
59 |
+
predictions : torch.Tensor
|
60 |
+
"""
|
61 |
+
outputs = outputs - self.margin * targets
|
62 |
+
return self.scale * outputs
|
63 |
+
|
64 |
+
|
65 |
+
class AdditiveAngularMargin(AngularMargin):
|
66 |
+
"""
|
67 |
+
An implementation of Additive Angular Margin (AAM) proposed
|
68 |
+
in the following paper: '''Margin Matters: Towards More Discriminative Deep
|
69 |
+
Neural Network Embeddings for Speaker Recognition'''
|
70 |
+
(https://arxiv.org/abs/1906.07317)
|
71 |
+
|
72 |
+
Arguments
|
73 |
+
---------
|
74 |
+
margin : float
|
75 |
+
The margin for cosine similiarity, usually 0.2.
|
76 |
+
scale: float
|
77 |
+
The scale for cosine similiarity, usually 30.
|
78 |
+
|
79 |
+
Returns
|
80 |
+
-------
|
81 |
+
predictions : torch.Tensor
|
82 |
+
Tensor.
|
83 |
+
Example
|
84 |
+
-------
|
85 |
+
>>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
|
86 |
+
>>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ])
|
87 |
+
>>> pred = AdditiveAngularMargin()
|
88 |
+
>>> predictions = pred(outputs, targets)
|
89 |
+
>>> predictions[:,0] > predictions[:,1]
|
90 |
+
tensor([ True, False, True, False])
|
91 |
+
"""
|
92 |
+
|
93 |
+
def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
|
94 |
+
super(AdditiveAngularMargin, self).__init__(margin, scale)
|
95 |
+
self.easy_margin = easy_margin
|
96 |
+
|
97 |
+
self.cos_m = math.cos(self.margin)
|
98 |
+
self.sin_m = math.sin(self.margin)
|
99 |
+
self.th = math.cos(math.pi - self.margin)
|
100 |
+
self.mm = math.sin(math.pi - self.margin) * self.margin
|
101 |
+
|
102 |
+
def forward(self, outputs, targets):
|
103 |
+
"""
|
104 |
+
Compute AAM between two tensors
|
105 |
+
|
106 |
+
Arguments
|
107 |
+
---------
|
108 |
+
outputs : torch.Tensor
|
109 |
+
The outputs of shape [N, C]; cosine similarity is required.
|
110 |
+
targets : torch.Tensor
|
111 |
+
The targets of shape [N, C], to which the margin is applied.
|
112 |
+
|
113 |
+
Return
|
114 |
+
---------
|
115 |
+
predictions : torch.Tensor
|
116 |
+
"""
|
117 |
+
cosine = outputs.float()
|
118 |
+
sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
|
119 |
+
phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m)
|
120 |
+
if self.easy_margin:
|
121 |
+
phi = torch.where(cosine > 0, phi, cosine)
|
122 |
+
else:
|
123 |
+
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
|
124 |
+
outputs = (targets * phi) + ((1.0 - targets) * cosine)
|
125 |
+
return self.scale * outputs
|
126 |
+
|
127 |
+
|
128 |
+
class SpeakerDecoderPostnet(nn.Module):
|
129 |
+
"""Speaker Identification Postnet.
|
130 |
+
|
131 |
+
Arguments
|
132 |
+
---------
|
133 |
+
embed_dim : int
|
134 |
+
The size of embedding.
|
135 |
+
class_num: int
|
136 |
+
The number of classes.
|
137 |
+
args : Namespace
|
138 |
+
|
139 |
+
Return
|
140 |
+
---------
|
141 |
+
embed : torch.Tensor
|
142 |
+
output : torch.Tensor
|
143 |
+
"""
|
144 |
+
|
145 |
+
def __init__(self, embed_dim, class_num, args):
|
146 |
+
super(SpeakerDecoderPostnet, self).__init__()
|
147 |
+
self.embed_dim = embed_dim
|
148 |
+
self.class_num = class_num
|
149 |
+
self.no_pooling_bn = getattr(args, "sid_no_pooling_bn", False)
|
150 |
+
self.no_embed_postnet = getattr(args, "sid_no_embed_postnet", False)
|
151 |
+
self.normalize_postnet = getattr(args, "sid_normalize_postnet", False)
|
152 |
+
self.softmax_head = getattr(args, "sid_softmax_type", "softmax")
|
153 |
+
if not self.no_pooling_bn:
|
154 |
+
self.bn_pooling = nn.BatchNorm1d(args.decoder_output_dim)
|
155 |
+
else:
|
156 |
+
self.bn_pooling = None
|
157 |
+
if not self.no_embed_postnet:
|
158 |
+
self.output_embedding = nn.Linear(args.decoder_output_dim, embed_dim, bias=False)
|
159 |
+
self.bn_embedding = nn.BatchNorm1d(embed_dim)
|
160 |
+
else:
|
161 |
+
self.output_embedding = None
|
162 |
+
self.bn_embedding = None
|
163 |
+
self.embed_dim = args.decoder_output_dim
|
164 |
+
self.output_projection = nn.Linear(self.embed_dim, class_num, bias=False)
|
165 |
+
if self.softmax_head == "amsoftmax":
|
166 |
+
self.output_layer = AngularMargin(args.softmax_margin, args.softmax_scale)
|
167 |
+
elif self.softmax_head == "aamsoftmax":
|
168 |
+
self.output_layer = AdditiveAngularMargin(args.softmax_margin, args.softmax_scale, args.softmax_easy_margin)
|
169 |
+
else:
|
170 |
+
self.output_layer = None
|
171 |
+
if self.output_embedding is not None:
|
172 |
+
nn.init.normal_(self.output_embedding.weight, mean=0, std=embed_dim ** -0.5)
|
173 |
+
nn.init.normal_(self.output_projection.weight, mean=0, std=class_num ** -0.5)
|
174 |
+
|
175 |
+
def forward(self, x, target=None):
|
176 |
+
"""
|
177 |
+
Parameters
|
178 |
+
----------
|
179 |
+
x : torch.Tensor of shape [batch, channel] or [batch, time, channel]
|
180 |
+
target : torch.Tensor of shape [batch, channel]
|
181 |
+
"""
|
182 |
+
if self.bn_pooling is not None:
|
183 |
+
x = self.bn_pooling(x)
|
184 |
+
if self.output_embedding is not None and self.bn_embedding is not None:
|
185 |
+
embed = self.bn_embedding(self.output_embedding(x))
|
186 |
+
else:
|
187 |
+
embed = x
|
188 |
+
if self.output_layer is not None or self.normalize_postnet:
|
189 |
+
x_norm = F.normalize(embed, p=2, dim=1)
|
190 |
+
w_norm = F.normalize(self.output_projection.weight, p=2, dim=1) # [out_dim, in_dim]
|
191 |
+
output = F.linear(x_norm, w_norm)
|
192 |
+
if self.training and target is not None and self.output_layer is not None:
|
193 |
+
output = self.output_layer(output, target)
|
194 |
+
else:
|
195 |
+
output = self.output_projection(embed)
|
196 |
+
return output, embed
|
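A hypothetical end-to-end use of SpeakerDecoderPostnet, assuming the classes defined in this file are in scope. A SimpleNamespace stands in for the usual argparse args; the field names follow the attribute lookups in __init__ above, and all values are illustrative only:

from types import SimpleNamespace
import torch
import torch.nn.functional as F

args = SimpleNamespace(
    decoder_output_dim=768, sid_no_pooling_bn=False, sid_no_embed_postnet=False,
    sid_normalize_postnet=False, sid_softmax_type="amsoftmax",
    softmax_margin=0.2, softmax_scale=30.0,
)
postnet = SpeakerDecoderPostnet(embed_dim=512, class_num=1000, args=args)
pooled = torch.randn(8, 768)                                  # pooled decoder states [batch, channel]
targets = F.one_hot(torch.randint(0, 1000, (8,)), 1000).float()
postnet.train()                                               # margin head is only applied in training
logits, embedding = postnet(pooled, target=targets)
print(logits.shape, embedding.shape)                          # torch.Size([8, 1000]) torch.Size([8, 512])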
artst/models/modules/speech_decoder_postnet.py
ADDED
@@ -0,0 +1,75 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
|
5 |
+
# Based on speecht5, fairseq and espnet code bases
|
6 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import contextlib
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet
|
14 |
+
|
15 |
+
|
16 |
+
class SpeechDecoderPostnet(nn.Module):
|
17 |
+
"""
|
18 |
+
|
19 |
+
Args:
|
20 |
+
in_channels (int): the number of input channels
|
21 |
+
mid_channels (int): the number of intermediate channels
|
22 |
+
out_channels (int): the number of output channels
|
23 |
+
kernel_sizes (List[int]): the kernel size for each convolutional layer
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
odim,
|
29 |
+
args,
|
30 |
+
):
|
31 |
+
super(SpeechDecoderPostnet, self).__init__()
|
32 |
+
# define decoder postnet
|
33 |
+
# define final projection
|
34 |
+
self.feat_out = torch.nn.Linear(args.decoder_embed_dim, odim * args.reduction_factor)
|
35 |
+
self.prob_out = torch.nn.Linear(args.decoder_embed_dim, args.reduction_factor)
|
36 |
+
|
37 |
+
# define postnet
|
38 |
+
self.postnet = (
|
39 |
+
None
|
40 |
+
if args.postnet_layers == 0
|
41 |
+
else Postnet(
|
42 |
+
idim=0,
|
43 |
+
odim=odim,
|
44 |
+
n_layers=args.postnet_layers,
|
45 |
+
n_chans=args.postnet_chans,
|
46 |
+
n_filts=args.postnet_filts,
|
47 |
+
use_batch_norm=args.use_batch_norm,
|
48 |
+
dropout_rate=args.postnet_dropout_rate,
|
49 |
+
)
|
50 |
+
)
|
51 |
+
|
52 |
+
self.odim = odim
|
53 |
+
self.num_updates = 0
|
54 |
+
self.freeze_decoder_updates = args.freeze_decoder_updates
|
55 |
+
|
56 |
+
def forward(self, zs):
|
57 |
+
ft = self.freeze_decoder_updates <= self.num_updates
|
58 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
59 |
+
# (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
|
60 |
+
before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
|
61 |
+
# (B, Lmax//r, r) -> (B, Lmax//r * r)
|
62 |
+
logits = self.prob_out(zs).view(zs.size(0), -1)
|
63 |
+
# postnet -> (B, Lmax//r * r, odim)
|
64 |
+
if self.postnet is None:
|
65 |
+
after_outs = before_outs
|
66 |
+
else:
|
67 |
+
after_outs = before_outs + self.postnet(
|
68 |
+
before_outs.transpose(1, 2)
|
69 |
+
).transpose(1, 2)
|
70 |
+
|
71 |
+
return before_outs, after_outs, logits
|
72 |
+
|
73 |
+
def set_num_updates(self, num_updates):
|
74 |
+
"""Set the number of parameters updates."""
|
75 |
+
self.num_updates = num_updates
|
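The two projections above fold the reduction factor into the feature dimension and then unfold it along time; a small shape-only sketch of that view, with toy sizes:

import torch

B, steps, embed_dim, odim, r = 2, 5, 16, 4, 2        # r is the reduction factor
zs = torch.randn(B, steps, embed_dim)                # decoder states, one per reduced step
feat_out = torch.nn.Linear(embed_dim, odim * r)      # plays the role of self.feat_out
prob_out = torch.nn.Linear(embed_dim, r)             # plays the role of self.prob_out
before_outs = feat_out(zs).view(B, -1, odim)         # (B, steps*r, odim) predicted frames
logits = prob_out(zs).view(B, -1)                    # (B, steps*r) stop-token logits
print(before_outs.shape, logits.shape)               # torch.Size([2, 10, 4]) torch.Size([2, 10])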
artst/models/modules/speech_decoder_prenet.py
ADDED
@@ -0,0 +1,109 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
|
5 |
+
# Based on speecht5, fairseq and espnet code bases
|
6 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import contextlib
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as TacotronDecoderPrenet
|
15 |
+
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
|
16 |
+
from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding
|
17 |
+
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
|
18 |
+
|
19 |
+
|
20 |
+
class SpeechDecoderPrenet(nn.Module):
|
21 |
+
"""
|
22 |
+
|
23 |
+
Args:
|
24 |
+
in_channels (int): the number of input channels
|
25 |
+
mid_channels (int): the number of intermediate channels
|
26 |
+
out_channels (int): the number of output channels
|
27 |
+
kernel_sizes (List[int]): the kernel size for each convolutional layer
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
odim,
|
33 |
+
args,
|
34 |
+
):
|
35 |
+
super(SpeechDecoderPrenet, self).__init__()
|
36 |
+
# define decoder prenet
|
37 |
+
if args.dprenet_layers != 0:
|
38 |
+
# decoder prenet
|
39 |
+
decoder_input_layer = torch.nn.Sequential(
|
40 |
+
TacotronDecoderPrenet(
|
41 |
+
idim=odim,
|
42 |
+
n_layers=args.dprenet_layers,
|
43 |
+
n_units=args.dprenet_units,
|
44 |
+
dropout_rate=args.dprenet_dropout_rate,
|
45 |
+
),
|
46 |
+
torch.nn.Linear(args.dprenet_units, args.decoder_embed_dim),
|
47 |
+
)
|
48 |
+
else:
|
49 |
+
decoder_input_layer = "linear"
|
50 |
+
|
51 |
+
pos_enc_class = (
|
52 |
+
ScaledPositionalEncoding if args.dec_use_scaled_pos_enc else PositionalEncoding
|
53 |
+
)
|
54 |
+
|
55 |
+
if decoder_input_layer == "linear":
|
56 |
+
self.decoder_prenet = torch.nn.Sequential(
|
57 |
+
torch.nn.Linear(odim, args.decoder_embed_dim),
|
58 |
+
torch.nn.LayerNorm(args.decoder_embed_dim),
|
59 |
+
torch.nn.Dropout(args.transformer_dec_dropout_rate),
|
60 |
+
torch.nn.ReLU(),
|
61 |
+
pos_enc_class(args.decoder_embed_dim, args.transformer_dec_positional_dropout_rate),
|
62 |
+
)
|
63 |
+
elif isinstance(decoder_input_layer, torch.nn.Module):
|
64 |
+
self.decoder_prenet = torch.nn.Sequential(
|
65 |
+
decoder_input_layer, pos_enc_class(args.decoder_embed_dim, args.transformer_dec_positional_dropout_rate, max_len=args.max_speech_positions)
|
66 |
+
)
|
67 |
+
|
68 |
+
if args.spk_embed_integration_type == 'pre':
|
69 |
+
self.spkembs_layer = torch.nn.Sequential(
|
70 |
+
torch.nn.Linear(args.spk_embed_dim + args.decoder_embed_dim, args.decoder_embed_dim), torch.nn.ReLU()
|
71 |
+
)
|
72 |
+
self.num_updates = 0
|
73 |
+
self.freeze_decoder_updates = args.freeze_decoder_updates
|
74 |
+
|
75 |
+
def forward(self, prev_output_tokens, tgt_lengths_in=None, spkembs=None):
|
76 |
+
ft = self.freeze_decoder_updates <= self.num_updates
|
77 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
78 |
+
prev_output_tokens = self.decoder_prenet(prev_output_tokens)
|
79 |
+
|
80 |
+
if spkembs is not None:
|
81 |
+
spkembs = F.normalize(spkembs).unsqueeze(1).expand(-1, prev_output_tokens.size(1), -1)
|
82 |
+
prev_output_tokens = self.spkembs_layer(torch.cat([prev_output_tokens, spkembs], dim=-1))
|
83 |
+
|
84 |
+
if tgt_lengths_in is not None:
|
85 |
+
tgt_frames_mask = ~(self._source_mask(tgt_lengths_in).squeeze(1))
|
86 |
+
else:
|
87 |
+
tgt_frames_mask = None
|
88 |
+
return prev_output_tokens, tgt_frames_mask
|
89 |
+
|
90 |
+
def _source_mask(self, ilens):
|
91 |
+
"""Make masks for self-attention.
|
92 |
+
Args:
|
93 |
+
ilens (LongTensor or List): Batch of lengths (B,).
|
94 |
+
Returns:
|
95 |
+
Tensor: Mask tensor for self-attention.
|
96 |
+
dtype=torch.uint8 in PyTorch 1.2-
|
97 |
+
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
98 |
+
Examples:
|
99 |
+
>>> ilens = [5, 3]
|
100 |
+
>>> self._source_mask(ilens)
|
101 |
+
tensor([[[1, 1, 1, 1, 1]],
|
102 |
+
[[1, 1, 1, 0, 0]]], dtype=torch.uint8)
|
103 |
+
"""
|
104 |
+
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
|
105 |
+
return x_masks.unsqueeze(-2)
|
106 |
+
|
107 |
+
def set_num_updates(self, num_updates):
|
108 |
+
"""Set the number of parameters updates."""
|
109 |
+
self.num_updates = num_updates
|
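_source_mask above delegates to espnet's make_non_pad_mask; an equivalent-in-spirit construction in plain torch, shown only to make the docstring example concrete:

import torch

def non_pad_mask(lengths):
    # True for real positions, False for padding
    lengths = torch.as_tensor(lengths)
    return torch.arange(int(lengths.max())).unsqueeze(0) < lengths.unsqueeze(1)

mask = non_pad_mask([5, 3]).unsqueeze(-2)            # (B, 1, Tmax), as returned by _source_mask
print(mask.int())
# tensor([[[1, 1, 1, 1, 1]],
#         [[1, 1, 1, 0, 0]]])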
artst/models/modules/speech_encoder_postnet.py
ADDED
@@ -0,0 +1,123 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
|
5 |
+
# Based on speecht5, fairseq and espnet code bases
|
6 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch
|
12 |
+
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
class SpeechEncoderPostnet(nn.Module):
|
17 |
+
"""
|
18 |
+
|
19 |
+
Args:
|
20 |
+
in_channels (int): the number of input channels
|
21 |
+
mid_channels (int): the number of intermediate channels
|
22 |
+
out_channels (int): the number of output channels
|
23 |
+
kernel_sizes (List[int]): the kernel size for each convolutional layer
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, dictionaries, args):
|
27 |
+
super(SpeechEncoderPostnet, self).__init__()
|
28 |
+
# modules below are not needed during fine-tuning
|
29 |
+
self.target_glu = args.target_glu
|
30 |
+
self.skip_masked = args.skip_masked
|
31 |
+
self.skip_nomask = args.skip_nomask
|
32 |
+
self.logit_temp = args.logit_temp
|
33 |
+
|
34 |
+
final_dim = (
|
35 |
+
args.final_dim if args.final_dim > 0 else args.encoder_embed_dim
|
36 |
+
)
|
37 |
+
if any([d is None for d in dictionaries]):
|
38 |
+
logger.info(
|
39 |
+
"cannot find dictionary. assume will be used for fine-tuning"
|
40 |
+
)
|
41 |
+
else:
|
42 |
+
self.num_classes = [len(d) for d in dictionaries]
|
43 |
+
self.label_embs_concat = nn.Parameter(
|
44 |
+
torch.FloatTensor(sum(self.num_classes), final_dim)
|
45 |
+
)
|
46 |
+
nn.init.uniform_(self.label_embs_concat)
|
47 |
+
self.untie_final_proj = args.untie_final_proj
|
48 |
+
if self.untie_final_proj:
|
49 |
+
self.final_proj = nn.Linear(
|
50 |
+
args.encoder_embed_dim, final_dim * len(dictionaries)
|
51 |
+
)
|
52 |
+
else:
|
53 |
+
self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)
|
54 |
+
|
55 |
+
def compute_nce(self, x, pos, negs):
|
56 |
+
neg_is_pos = (pos == negs).all(-1)
|
57 |
+
pos = pos.unsqueeze(0)
|
58 |
+
targets = torch.cat([pos, negs], dim=0)
|
59 |
+
|
60 |
+
logits = torch.cosine_similarity(
|
61 |
+
x.float(), targets.float(), dim=-1
|
62 |
+
).type_as(x)
|
63 |
+
logits /= self.logit_temp
|
64 |
+
if neg_is_pos.any():
|
65 |
+
logits[1:][neg_is_pos] = float("-inf")
|
66 |
+
logits = logits.transpose(0, 1) # (num_x, num_cls+1)
|
67 |
+
return logits
|
68 |
+
|
69 |
+
def forward(self, x, padding_mask, mask_indices, target_list):
|
70 |
+
def compute_pred(proj_x, target, label_embs):
|
71 |
+
# compute logits for the i-th label set
|
72 |
+
y = torch.index_select(label_embs, 0, target.long())
|
73 |
+
negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
|
74 |
+
if self.target_glu:
|
75 |
+
y = self.target_glu(y)
|
76 |
+
negs = self.target_glu(negs)
|
77 |
+
# proj_x: (S, D)
|
78 |
+
# y: (S, D)
|
79 |
+
# negs: (Neg, S, D)
|
80 |
+
return self.compute_nce(proj_x, y, negs)
|
81 |
+
|
82 |
+
label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
|
83 |
+
|
84 |
+
if not self.skip_masked:
|
85 |
+
masked_indices = torch.logical_and(~padding_mask, mask_indices)
|
86 |
+
proj_x_m = self.final_proj(x[masked_indices])
|
87 |
+
if self.untie_final_proj:
|
88 |
+
proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
|
89 |
+
else:
|
90 |
+
proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
|
91 |
+
logit_m_list = [
|
92 |
+
compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
|
93 |
+
for i, (proj_x_m, t) in enumerate(
|
94 |
+
zip(proj_x_m_list, target_list)
|
95 |
+
)
|
96 |
+
]
|
97 |
+
else:
|
98 |
+
logit_m_list = [None for _ in target_list]
|
99 |
+
|
100 |
+
if not self.skip_nomask:
|
101 |
+
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
|
102 |
+
proj_x_u = self.final_proj(x[nomask_indices])
|
103 |
+
if self.untie_final_proj:
|
104 |
+
proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
|
105 |
+
else:
|
106 |
+
proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
|
107 |
+
|
108 |
+
logit_u_list = [
|
109 |
+
compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
|
110 |
+
for i, (proj_x_u, t) in enumerate(
|
111 |
+
zip(proj_x_u_list, target_list)
|
112 |
+
)
|
113 |
+
]
|
114 |
+
else:
|
115 |
+
logit_u_list = [None for _ in target_list]
|
116 |
+
|
117 |
+
result = {
|
118 |
+
"logit_m_list": logit_m_list,
|
119 |
+
"logit_u_list": logit_u_list,
|
120 |
+
"padding_mask": padding_mask,
|
121 |
+
}
|
122 |
+
|
123 |
+
return result
|
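compute_nce above scores each frame against one positive and a stack of negatives with temperature-scaled cosine similarity; a toy run with hypothetical shapes, where index 0 of each row is the positive:

import torch

S, D, num_neg, logit_temp = 6, 8, 4, 0.1
x = torch.randn(S, D)                                # projected encoder states
pos = torch.randn(S, D)                              # positive label embedding per frame
negs = torch.randn(num_neg, S, D)                    # sampled negative embeddings
targets = torch.cat([pos.unsqueeze(0), negs], dim=0) # (1 + num_neg, S, D)
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1) / logit_temp
logits = logits.transpose(0, 1)                      # (S, 1 + num_neg)
print(logits.shape)                                  # torch.Size([6, 5])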
artst/models/modules/speech_encoder_prenet.py
ADDED
@@ -0,0 +1,373 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
|
5 |
+
# Based on speecht5, fairseq and espnet code bases
|
6 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import math
|
11 |
+
import torch
|
12 |
+
import contextlib
|
13 |
+
from typing import List, Tuple
|
14 |
+
import torch.nn as nn
|
15 |
+
|
16 |
+
from fairseq.data.data_utils import lengths_to_padding_mask
|
17 |
+
from fairseq.data.data_utils import compute_mask_indices
|
18 |
+
from fairseq.modules import (
|
19 |
+
PositionalEmbedding,
|
20 |
+
Fp32GroupNorm,
|
21 |
+
FairseqDropout,
|
22 |
+
SamePad,
|
23 |
+
GradMultiply,
|
24 |
+
LayerNorm,
|
25 |
+
Fp32LayerNorm,
|
26 |
+
TransposeLast,
|
27 |
+
)
|
28 |
+
import numpy as np
|
29 |
+
|
30 |
+
logger = logging.getLogger(__name__)
|
31 |
+
|
32 |
+
|
33 |
+
class LinearLayer(nn.Module):
|
34 |
+
def __init__(self, idim, odim, dropout=0):
|
35 |
+
super(LinearLayer, self).__init__()
|
36 |
+
self.linear = nn.Sequential(
|
37 |
+
nn.Linear(idim, odim),
|
38 |
+
nn.LayerNorm(odim),
|
39 |
+
nn.Dropout(dropout),
|
40 |
+
nn.ReLU(),
|
41 |
+
)
|
42 |
+
|
43 |
+
def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
|
44 |
+
out = in_seq_lens_tensor.clone()
|
45 |
+
return out
|
46 |
+
|
47 |
+
def forward(self, src_tokens, src_lengths):
|
48 |
+
"""
|
49 |
+
src_tokens: [B, T, C]
|
50 |
+
src_lengths: [B]
|
51 |
+
"""
|
52 |
+
x = self.linear(src_tokens)
|
53 |
+
x = x.transpose(0, 1).contiguous() # -> T x B x C
|
54 |
+
return x, src_lengths
|
55 |
+
|
56 |
+
|
57 |
+
class SpeechEncoderPrenet(nn.Module):
|
58 |
+
"""
|
59 |
+
|
60 |
+
Args:
|
61 |
+
in_channels (int): the number of input channels
|
62 |
+
mid_channels (int): the number of intermediate channels
|
63 |
+
out_channels (int): the number of output channels
|
64 |
+
kernel_sizes (List[int]): the kernel size for each convolutional layer
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, args):
|
68 |
+
super(SpeechEncoderPrenet, self).__init__()
|
69 |
+
self.dropout_module = FairseqDropout(
|
70 |
+
p=args.dropout, module_name=self.__class__.__name__
|
71 |
+
)
|
72 |
+
self.embed_scale = math.sqrt(args.encoder_embed_dim)
|
73 |
+
if args.no_scale_embedding:
|
74 |
+
self.embed_scale = 1.0
|
75 |
+
self.padding_idx = 1
|
76 |
+
self.freeze_encoder_updates = args.freeze_encoder_updates
|
77 |
+
self.num_updates = 0
|
78 |
+
assert args.encoder_speech_prenet in ["conv", "linear"], args.encoder_speech_prenet
|
79 |
+
feature_enc_layers = eval(args.conv_feature_layers) # noqa
|
80 |
+
self.embed = feature_enc_layers[-1][0]
|
81 |
+
|
82 |
+
self.feature_extractor = ConvFeatureExtractionModel(
|
83 |
+
conv_layers=feature_enc_layers,
|
84 |
+
dropout=0.0,
|
85 |
+
mode=args.extractor_mode,
|
86 |
+
conv_bias=args.conv_bias,
|
87 |
+
)
|
88 |
+
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
|
89 |
+
self.feat2tar_ratio = (
|
90 |
+
args.label_rates * feature_ds_rate / args.sample_rate
|
91 |
+
)
|
92 |
+
|
93 |
+
self.post_extract_proj = (
|
94 |
+
nn.Linear(self.embed, args.encoder_embed_dim)
|
95 |
+
if self.embed != args.encoder_embed_dim
|
96 |
+
else None
|
97 |
+
)
|
98 |
+
|
99 |
+
self.use_conv_pos = args.use_conv_pos
|
100 |
+
self.use_sinc_pos = args.use_sinc_pos
|
101 |
+
self.use_abs_pos = getattr(args, "use_abs_pos", False)
|
102 |
+
|
103 |
+
self.feature_grad_mult = args.feature_grad_mult
|
104 |
+
if self.use_conv_pos:
|
105 |
+
self.layer_norm = LayerNorm(self.embed)
|
106 |
+
self.pos_conv = nn.Conv1d(
|
107 |
+
args.encoder_embed_dim,
|
108 |
+
args.encoder_embed_dim,
|
109 |
+
kernel_size=args.conv_pos,
|
110 |
+
padding=args.conv_pos // 2,
|
111 |
+
groups=args.conv_pos_groups,
|
112 |
+
)
|
113 |
+
dropout = 0
|
114 |
+
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * args.encoder_embed_dim))
|
115 |
+
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
116 |
+
nn.init.constant_(self.pos_conv.bias, 0)
|
117 |
+
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
118 |
+
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
119 |
+
|
120 |
+
assert not (self.use_sinc_pos and self.use_abs_pos), f"sinc pos: {self.use_sinc_pos} abs pos: {self.use_abs_pos}"
|
121 |
+
if self.use_sinc_pos:
|
122 |
+
self.embed_positions = PositionalEmbedding(
|
123 |
+
args.max_speech_positions, args.encoder_embed_dim, self.padding_idx
|
124 |
+
)
|
125 |
+
if self.use_abs_pos:
|
126 |
+
self.embed_positions = PositionalEmbedding(
|
127 |
+
args.max_speech_positions, args.encoder_embed_dim, self.padding_idx, learned=True
|
128 |
+
)
|
129 |
+
|
130 |
+
# Hubert
|
131 |
+
self.mask_prob = args.mask_prob
|
132 |
+
self.mask_selection = args.mask_selection
|
133 |
+
self.mask_other = args.mask_other
|
134 |
+
self.hubert_mask_length = args.hubert_mask_length
|
135 |
+
self.no_mask_overlap = args.no_mask_overlap
|
136 |
+
self.mask_min_space = args.mask_min_space
|
137 |
+
|
138 |
+
self.mask_channel_prob = args.mask_channel_prob
|
139 |
+
self.mask_channel_selection = args.mask_channel_selection
|
140 |
+
self.mask_channel_other = args.mask_channel_other
|
141 |
+
self.mask_channel_length = args.mask_channel_length
|
142 |
+
self.no_mask_channel_overlap = args.no_mask_channel_overlap
|
143 |
+
self.mask_channel_min_space = args.mask_channel_min_space
|
144 |
+
|
145 |
+
self.mask_emb = nn.Parameter(
|
146 |
+
torch.FloatTensor(args.encoder_embed_dim).uniform_()
|
147 |
+
)
|
148 |
+
|
149 |
+
def forward(self, src_tokens, require_feat_pen=False, target_list=None, padding_mask=None, mask=True):
|
150 |
+
ft = self.freeze_encoder_updates <= self.num_updates
|
151 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
152 |
+
return self._forward(src_tokens, require_feat_pen, target_list, padding_mask, mask)
|
153 |
+
|
154 |
+
def _forward(self, src_tokens, require_feat_pen=False, target_list=None, padding_mask=None, mask=True):
|
155 |
+
if self.feature_grad_mult > 0:
|
156 |
+
x = self.feature_extractor(src_tokens)
|
157 |
+
x = x.transpose(1, 2).transpose(0, 1) # [length, batch, hidden_size]
|
158 |
+
if self.feature_grad_mult != 1.0:
|
159 |
+
x = GradMultiply.apply(x, self.feature_grad_mult)
|
160 |
+
else:
|
161 |
+
with torch.no_grad():
|
162 |
+
x = self.feature_extractor(src_tokens)
|
163 |
+
x = x.transpose(1, 2).transpose(0, 1) # [length, batch, hidden_size]
|
164 |
+
x = x.transpose(0, 1) # [batch, length, hidden_size]
|
165 |
+
|
166 |
+
encoder_padding_mask = padding_mask
|
167 |
+
|
168 |
+
x = x.transpose(1, 2) # [batch, hidden_size, length]
|
169 |
+
if target_list is not None:
|
170 |
+
x, target_list = self.forward_targets(x, target_list)
|
171 |
+
features_pen = x.float().pow(2).mean()
|
172 |
+
x = x.transpose(1, 2) # [batch, length, hidden_size]
|
173 |
+
x = self.layer_norm(x)
|
174 |
+
encoder_padding_mask = self.forward_padding_mask(x, encoder_padding_mask)
|
175 |
+
if self.post_extract_proj is not None:
|
176 |
+
x = self.post_extract_proj(x)
|
177 |
+
x = self.dropout_module(x)
|
178 |
+
if mask:
|
179 |
+
x, mask_indices = self.apply_hubert_mask(
|
180 |
+
x, encoder_padding_mask
|
181 |
+
)
|
182 |
+
else:
|
183 |
+
x = x
|
184 |
+
mask_indices = None
|
185 |
+
|
186 |
+
if self.use_conv_pos:
|
187 |
+
positions = self.pos_conv(x.transpose(1, 2))
|
188 |
+
positions = positions.transpose(1, 2)
|
189 |
+
#else:
|
190 |
+
# positions = self.embed_positions(encoder_padding_mask)
|
191 |
+
x = x + positions
|
192 |
+
|
193 |
+
if self.use_sinc_pos:
|
194 |
+
positions = self.embed_positions(encoder_padding_mask)
|
195 |
+
x = x + positions
|
196 |
+
|
197 |
+
# x = self.dropout_module(x)
|
198 |
+
|
199 |
+
if require_feat_pen:
|
200 |
+
return (x, features_pen, mask_indices, target_list), encoder_padding_mask
|
201 |
+
else:
|
202 |
+
# For consistency with the encoder
|
203 |
+
return x, encoder_padding_mask
|
204 |
+
|
205 |
+
def forward_targets(
|
206 |
+
self, features: torch.Tensor, target_list: List[torch.Tensor],
|
207 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
208 |
+
# Trim features to ensure labels exist and then get aligned labels
|
209 |
+
feat_tsz = features.size(2)
|
210 |
+
targ_tsz = min([t.size(1) for t in target_list])
|
211 |
+
if self.feat2tar_ratio * feat_tsz > targ_tsz:
|
212 |
+
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
|
213 |
+
features = features[..., :feat_tsz]
|
214 |
+
target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
|
215 |
+
target_list = [t[:, target_inds.long()] for t in target_list]
|
216 |
+
return features, target_list
|
217 |
+
|
218 |
+
def forward_padding_mask(
|
219 |
+
self, features: torch.Tensor, padding_mask: torch.Tensor,
|
220 |
+
) -> torch.Tensor:
|
221 |
+
extra = padding_mask.size(1) % features.size(1)
|
222 |
+
if extra > 0:
|
223 |
+
padding_mask = padding_mask[:, :-extra]
|
224 |
+
padding_mask = padding_mask.view(
|
225 |
+
padding_mask.size(0), features.size(1), -1
|
226 |
+
)
|
227 |
+
padding_mask = padding_mask.all(-1)
|
228 |
+
return padding_mask
|
229 |
+
|
230 |
+
def get_src_lengths(self, src_lengths):
|
231 |
+
return self.feature_extractor.get_out_seq_lens_tensor(src_lengths)
|
232 |
+
|
233 |
+
def apply_hubert_mask(self, x, padding_mask):
|
234 |
+
B, T, C = x.shape
|
235 |
+
if self.mask_prob > 0:
|
236 |
+
mask_indices = compute_mask_indices(
|
237 |
+
(B, T),
|
238 |
+
padding_mask,
|
239 |
+
self.mask_prob,
|
240 |
+
self.hubert_mask_length,
|
241 |
+
self.mask_selection,
|
242 |
+
self.mask_other,
|
243 |
+
min_masks=2,
|
244 |
+
no_overlap=self.no_mask_overlap,
|
245 |
+
min_space=self.mask_min_space,
|
246 |
+
)
|
247 |
+
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
248 |
+
x[mask_indices] = self.mask_emb
|
249 |
+
else:
|
250 |
+
mask_indices = None
|
251 |
+
|
252 |
+
if self.mask_channel_prob > 0:
|
253 |
+
mask_channel_indices = compute_mask_indices(
|
254 |
+
(B, C),
|
255 |
+
None,
|
256 |
+
self.mask_channel_prob,
|
257 |
+
self.mask_channel_length,
|
258 |
+
self.mask_channel_selection,
|
259 |
+
self.mask_channel_other,
|
260 |
+
no_overlap=self.no_mask_channel_overlap,
|
261 |
+
min_space=self.mask_channel_min_space,
|
262 |
+
)
|
263 |
+
mask_channel_indices = (
|
264 |
+
torch.from_numpy(mask_channel_indices)
|
265 |
+
.to(x.device)
|
266 |
+
.unsqueeze(1)
|
267 |
+
.expand(-1, T, -1)
|
268 |
+
)
|
269 |
+
x[mask_channel_indices] = 0
|
270 |
+
|
271 |
+
return x, mask_indices
|
272 |
+
|
273 |
+
def set_num_updates(self, num_updates):
|
274 |
+
"""Set the number of parameters updates."""
|
275 |
+
self.num_updates = num_updates
|
276 |
+
|
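forward_targets above trims the convolutional features so every frame has a label and resamples the label sequence by feat2tar_ratio; a sketch with hypothetical rates (50 Hz labels, 16 kHz audio, 320x conv downsampling):

import torch

label_rate, sample_rate, ds_rate = 50, 16000, 320
feat2tar_ratio = label_rate * ds_rate / sample_rate          # 1.0 with these numbers
features = torch.randn(2, 768, 103)                          # (B, C, T_feat)
target_list = [torch.randint(0, 500, (2, 100))]              # (B, T_label)

feat_tsz = features.size(2)
targ_tsz = min(t.size(1) for t in target_list)
if feat2tar_ratio * feat_tsz > targ_tsz:                     # drop frames without labels
    feat_tsz = int(targ_tsz / feat2tar_ratio)
    features = features[..., :feat_tsz]
target_inds = (torch.arange(feat_tsz).float() * feat2tar_ratio).long()
target_list = [t[:, target_inds] for t in target_list]       # labels aligned to frames
print(features.shape, target_list[0].shape)                  # torch.Size([2, 768, 100]) torch.Size([2, 100])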
277 |
+
class ConvFeatureExtractionModel(nn.Module):
|
278 |
+
def __init__(
|
279 |
+
self,
|
280 |
+
conv_layers: List[Tuple[int, int, int]],
|
281 |
+
dropout: float = 0.0,
|
282 |
+
mode: str = "default",
|
283 |
+
conv_bias: bool = False,
|
284 |
+
):
|
285 |
+
super().__init__()
|
286 |
+
|
287 |
+
assert mode in {"default", "layer_norm"}
|
288 |
+
|
289 |
+
def block(
|
290 |
+
n_in,
|
291 |
+
n_out,
|
292 |
+
k,
|
293 |
+
stride,
|
294 |
+
is_layer_norm=False,
|
295 |
+
is_group_norm=False,
|
296 |
+
conv_bias=False,
|
297 |
+
):
|
298 |
+
def make_conv():
|
299 |
+
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
300 |
+
nn.init.kaiming_normal_(conv.weight)
|
301 |
+
return conv
|
302 |
+
|
303 |
+
assert (
|
304 |
+
is_layer_norm and is_group_norm
|
305 |
+
) == False, "layer norm and group norm are exclusive"
|
306 |
+
|
307 |
+
if is_layer_norm:
|
308 |
+
return nn.Sequential(
|
309 |
+
make_conv(),
|
310 |
+
nn.Dropout(p=dropout),
|
311 |
+
nn.Sequential(
|
312 |
+
TransposeLast(),
|
313 |
+
Fp32LayerNorm(dim, elementwise_affine=True),
|
314 |
+
TransposeLast(),
|
315 |
+
),
|
316 |
+
nn.GELU(),
|
317 |
+
)
|
318 |
+
elif is_group_norm:
|
319 |
+
return nn.Sequential(
|
320 |
+
make_conv(),
|
321 |
+
nn.Dropout(p=dropout),
|
322 |
+
Fp32GroupNorm(dim, dim, affine=True),
|
323 |
+
nn.GELU(),
|
324 |
+
)
|
325 |
+
else:
|
326 |
+
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
327 |
+
|
328 |
+
in_d = 1
|
329 |
+
self.conv_layers = nn.ModuleList()
|
330 |
+
self.conv_layers_infos = conv_layers
|
331 |
+
for i, cl in enumerate(conv_layers):
|
332 |
+
assert len(cl) == 3, "invalid conv definition: " + str(cl)
|
333 |
+
(dim, k, stride) = cl
|
334 |
+
|
335 |
+
self.conv_layers.append(
|
336 |
+
block(
|
337 |
+
in_d,
|
338 |
+
dim,
|
339 |
+
k,
|
340 |
+
stride,
|
341 |
+
is_layer_norm=mode == "layer_norm",
|
342 |
+
is_group_norm=mode == "default" and i == 0,
|
343 |
+
conv_bias=conv_bias,
|
344 |
+
)
|
345 |
+
)
|
346 |
+
in_d = dim
|
347 |
+
|
348 |
+
def forward(self, x):
|
349 |
+
# BxT -> BxCxT
|
350 |
+
x = x.unsqueeze(1)
|
351 |
+
for conv in self.conv_layers:
|
352 |
+
x = conv(x)
|
353 |
+
return x
|
354 |
+
|
355 |
+
def get_out_seq_lens_nonmask_after_a_layer(self, in_seq_lens_tensor, i):
|
356 |
+
"""Returns the out_seq_lens_nonmask 0/1 tensor after a layer.
|
357 |
+
|
358 |
+
Args:
|
359 |
+
in_seq_lens_tensor (LongTensor): length
|
360 |
+
|
361 |
+
Returns:
|
362 |
+
LongTensor: length
|
363 |
+
"""
|
364 |
+
out_lengths = in_seq_lens_tensor.clone()
|
365 |
+
out_lengths = ((out_lengths.float() - (self.conv_layers_infos[i][1] - 1) - 1) / self.conv_layers_infos[i][-1] + 1).floor().long()
|
366 |
+
out_nonmask = (~lengths_to_padding_mask(out_lengths)).float()
|
367 |
+
return out_nonmask, out_lengths
|
368 |
+
|
369 |
+
def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
|
370 |
+
out = in_seq_lens_tensor.clone()
|
371 |
+
for i in range(len(self.conv_layers)):
|
372 |
+
out = ((out.float() - (self.conv_layers_infos[i][1] - 1) - 1) / self.conv_layers_infos[i][-1] + 1).floor().long()
|
373 |
+
return out
|
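get_out_seq_lens_tensor above applies the standard unpadded Conv1d length recurrence layer by layer; a self-contained helper with a hypothetical (dim, kernel, stride) spec in the same format as conv_feature_layers:

import torch

def conv_out_lengths(in_lengths, conv_layers):
    # floor((L - (kernel - 1) - 1) / stride + 1) applied per layer
    out = in_lengths.clone().float()
    for _, kernel, stride in conv_layers:
        out = ((out - (kernel - 1) - 1) / stride + 1).floor()
    return out.long()

layers = [(512, 10, 5), (512, 3, 2), (512, 3, 2), (512, 3, 2),
          (512, 3, 2), (512, 2, 2), (512, 2, 2)]               # illustrative spec only
print(conv_out_lengths(torch.tensor([16000, 32000]), layers))  # tensor([49, 99])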
artst/models/modules/text_decoder_postnet.py
ADDED
@@ -0,0 +1,92 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
|
5 |
+
# Based on speecht5, fairseq and espnet code bases
|
6 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch
|
11 |
+
import contextlib
|
12 |
+
|
13 |
+
from fairseq import utils
|
14 |
+
from fairseq.modules import (
|
15 |
+
AdaptiveSoftmax,
|
16 |
+
)
|
17 |
+
|
18 |
+
class TextDecoderPostnet(nn.Module):
|
19 |
+
"""
|
20 |
+
|
21 |
+
Args:
|
22 |
+
in_channels (int): the number of input channels
|
23 |
+
mid_channels (int): the number of intermediate channels
|
24 |
+
out_channels (int): the number of output channels
|
25 |
+
kernel_sizes (List[int]): the kernel size for each convolutional layer
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, embed_tokens, dictionary, args, output_projection=None,):
|
29 |
+
super(TextDecoderPostnet, self).__init__()
|
30 |
+
self.output_embed_dim = args.decoder_output_dim
|
31 |
+
self.output_projection = output_projection
|
32 |
+
self.adaptive_softmax = None
|
33 |
+
self.share_input_output_embed = args.share_input_output_embed
|
34 |
+
if self.output_projection is None:
|
35 |
+
self.build_output_projection(args, dictionary, embed_tokens)
|
36 |
+
self.freeze_decoder_updates = args.freeze_decoder_updates
|
37 |
+
self.num_updates = 0
|
38 |
+
|
39 |
+
def output_layer(self, features):
|
40 |
+
"""Project features to the vocabulary size."""
|
41 |
+
if self.adaptive_softmax is None:
|
42 |
+
# project back to size of vocabulary
|
43 |
+
return self.output_projection(features)
|
44 |
+
else:
|
45 |
+
return features
|
46 |
+
|
47 |
+
def build_output_projection(self, args, dictionary, embed_tokens):
|
48 |
+
if args.adaptive_softmax_cutoff is not None:
|
49 |
+
self.adaptive_softmax = AdaptiveSoftmax(
|
50 |
+
len(dictionary),
|
51 |
+
self.output_embed_dim,
|
52 |
+
utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
|
53 |
+
dropout=args.adaptive_softmax_dropout,
|
54 |
+
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
|
55 |
+
factor=args.adaptive_softmax_factor,
|
56 |
+
tie_proj=args.tie_adaptive_proj,
|
57 |
+
)
|
58 |
+
elif self.share_input_output_embed:
|
59 |
+
self.output_projection = nn.Linear(
|
60 |
+
embed_tokens.weight.shape[1],
|
61 |
+
embed_tokens.weight.shape[0],
|
62 |
+
bias=False,
|
63 |
+
)
|
64 |
+
self.output_projection.weight = embed_tokens.weight
|
65 |
+
else:
|
66 |
+
self.output_projection = nn.Linear(
|
67 |
+
self.output_embed_dim, len(dictionary), bias=False
|
68 |
+
)
|
69 |
+
nn.init.normal_(
|
70 |
+
self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
|
71 |
+
)
|
72 |
+
# num_base_layers = getattr(args, "base_layers", 0)
|
73 |
+
# for i in range(num_base_layers):
|
74 |
+
# self.layers.insert(
|
75 |
+
# ((i + 1) * args.decoder_layers) // (num_base_layers + 1),
|
76 |
+
# BaseLayer(args),
|
77 |
+
# )
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
ft = self.freeze_decoder_updates <= self.num_updates
|
81 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
82 |
+
return self._forward(x)
|
83 |
+
|
84 |
+
def _forward(self, x):
|
85 |
+
# embed positions
|
86 |
+
x = self.output_layer(x)
|
87 |
+
|
88 |
+
return x
|
89 |
+
|
90 |
+
def set_num_updates(self, num_updates):
|
91 |
+
"""Set the number of parameters updates."""
|
92 |
+
self.num_updates = num_updates
|
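When share_input_output_embed is set, build_output_projection above ties the output projection to the token embedding matrix; a minimal sketch of that tying:

import torch
import torch.nn as nn

vocab, dim = 100, 16
embed_tokens = nn.Embedding(vocab, dim, padding_idx=1)
output_projection = nn.Linear(dim, vocab, bias=False)
output_projection.weight = embed_tokens.weight          # same Parameter object, no copy
decoder_features = torch.randn(2, 7, dim)
logits = output_projection(decoder_features)            # (2, 7, vocab)
print(logits.shape, output_projection.weight is embed_tokens.weight)
# torch.Size([2, 7, 100]) True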
artst/models/modules/text_decoder_prenet.py
ADDED
@@ -0,0 +1,128 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
|
5 |
+
# Based on speecht5, fairseq and espnet code bases
|
6 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import math
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch
|
12 |
+
import contextlib
|
13 |
+
|
14 |
+
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
|
15 |
+
from fairseq.models.transformer import Linear #,LayerNorm
|
16 |
+
from fairseq.modules import (
|
17 |
+
PositionalEmbedding,
|
18 |
+
FairseqDropout,
|
19 |
+
LayerNorm
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
class TextDecoderPrenet(nn.Module):
|
24 |
+
"""
|
25 |
+
|
26 |
+
Args:
|
27 |
+
in_channels (int): the number of input channels
|
28 |
+
mid_channels (int): the number of intermediate channels
|
29 |
+
out_channels (int): the number of output channels
|
30 |
+
kernel_sizes (List[int]): the kernel size for each convolutional layer
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self, embed_tokens, args):
|
34 |
+
super(TextDecoderPrenet, self).__init__()
|
35 |
+
self.dropout_module = FairseqDropout(
|
36 |
+
args.dropout, module_name=self.__class__.__name__
|
37 |
+
)
|
38 |
+
self.decoder_layerdrop = args.decoder_layerdrop
|
39 |
+
self.num_updates = 0
|
40 |
+
|
41 |
+
input_embed_dim = embed_tokens.embedding_dim
|
42 |
+
embed_dim = args.decoder_embed_dim
|
43 |
+
self.embed_dim = embed_dim
|
44 |
+
self.output_embed_dim = args.decoder_output_dim
|
45 |
+
|
46 |
+
self.padding_idx = embed_tokens.padding_idx
|
47 |
+
|
48 |
+
self.embed_tokens = embed_tokens
|
49 |
+
|
50 |
+
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
|
51 |
+
|
52 |
+
if not args.adaptive_input and args.quant_noise_pq > 0:
|
53 |
+
self.quant_noise = apply_quant_noise_(
|
54 |
+
nn.Linear(embed_dim, embed_dim, bias=False),
|
55 |
+
args.quant_noise_pq,
|
56 |
+
args.quant_noise_pq_block_size,
|
57 |
+
)
|
58 |
+
else:
|
59 |
+
self.quant_noise = None
|
60 |
+
|
61 |
+
self.project_in_dim = (
|
62 |
+
Linear(input_embed_dim, embed_dim, bias=False)
|
63 |
+
if embed_dim != input_embed_dim
|
64 |
+
else None
|
65 |
+
)
|
66 |
+
self.embed_positions = (
|
67 |
+
PositionalEmbedding(
|
68 |
+
args.max_text_positions,
|
69 |
+
embed_dim,
|
70 |
+
self.padding_idx,
|
71 |
+
learned=args.decoder_learned_pos,
|
72 |
+
)
|
73 |
+
if not args.no_token_positional_embeddings
|
74 |
+
else None
|
75 |
+
)
|
76 |
+
export = getattr(args, "export", False)
|
77 |
+
if getattr(args, "layernorm_embedding", False):
|
78 |
+
self.layernorm_embedding = LayerNorm(embed_dim, export=export)
|
79 |
+
else:
|
80 |
+
self.layernorm_embedding = None
|
81 |
+
|
82 |
+
self.freeze_decoder_updates = args.freeze_decoder_updates
|
83 |
+
|
84 |
+
def forward(self, prev_output_tokens, incremental_state=None):
|
85 |
+
ft = self.freeze_decoder_updates <= self.num_updates
|
86 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
87 |
+
return self._forward(prev_output_tokens, incremental_state)
|
88 |
+
|
89 |
+
def _forward(self, prev_output_tokens, incremental_state=None):
|
90 |
+
if prev_output_tokens.eq(self.padding_idx).any():
|
91 |
+
x_mask = prev_output_tokens.eq(self.padding_idx)
|
92 |
+
else:
|
93 |
+
x_mask = None
|
94 |
+
|
95 |
+
# embed positions
|
96 |
+
positions = None
|
97 |
+
if self.embed_positions is not None:
|
98 |
+
positions = self.embed_positions(
|
99 |
+
prev_output_tokens, incremental_state=incremental_state
|
100 |
+
)
|
101 |
+
|
102 |
+
if incremental_state is not None:
|
103 |
+
prev_output_tokens = prev_output_tokens[:, -1:]
|
104 |
+
if positions is not None:
|
105 |
+
positions = positions[:, -1:]
|
106 |
+
|
107 |
+
# embed tokens and positions
|
108 |
+
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
|
109 |
+
|
110 |
+
if self.quant_noise is not None:
|
111 |
+
x = self.quant_noise(x)
|
112 |
+
|
113 |
+
if self.project_in_dim is not None:
|
114 |
+
x = self.project_in_dim(x)
|
115 |
+
|
116 |
+
if positions is not None:
|
117 |
+
x += positions
|
118 |
+
|
119 |
+
if self.layernorm_embedding is not None:
|
120 |
+
x = self.layernorm_embedding(x)
|
121 |
+
|
122 |
+
x = self.dropout_module(x)
|
123 |
+
|
124 |
+
return x, x_mask, incremental_state
|
125 |
+
|
126 |
+
def set_num_updates(self, num_updates):
|
127 |
+
"""Set the number of parameters updates."""
|
128 |
+
self.num_updates = num_updates
|
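In incremental (step-by-step) decoding, _forward above keeps only the newest token and its position after computing positional embeddings; a toy illustration of that slicing, with made-up position values:

import torch

prev_output_tokens = torch.tensor([[2, 45, 7, 19]])                         # growing prefix (B, T)
positions = torch.arange(2, 2 + prev_output_tokens.size(1)).unsqueeze(0)    # stand-in positions
incremental_state = {}                                   # any non-None state triggers the slice
if incremental_state is not None:
    prev_output_tokens = prev_output_tokens[:, -1:]      # embed only the newest token
    positions = positions[:, -1:]
print(prev_output_tokens, positions)                     # tensor([[19]]) tensor([[5]])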
artst/models/modules/text_encoder_prenet.py
ADDED
@@ -0,0 +1,44 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
|
5 |
+
# Based on speecht5, fairseq and espnet code bases
|
6 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
|
12 |
+
from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding
|
13 |
+
|
14 |
+
|
15 |
+
class TextEncoderPrenet(nn.Module):
|
16 |
+
"""
|
17 |
+
|
18 |
+
Args:
|
19 |
+
in_channels (int): the number of input channels
|
20 |
+
mid_channels (int): the number of intermediate channels
|
21 |
+
out_channels (int): the number of output channels
|
22 |
+
kernel_sizes (List[int]): the kernel size for each convolutional layer
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
embed_tokens,
|
28 |
+
args,
|
29 |
+
):
|
30 |
+
super(TextEncoderPrenet, self).__init__()
|
31 |
+
self.padding_idx = embed_tokens.padding_idx
|
32 |
+
# define encoder prenet
|
33 |
+
# get positional encoding class
|
34 |
+
pos_enc_class = (
|
35 |
+
ScaledPositionalEncoding if args.enc_use_scaled_pos_enc else PositionalEncoding
|
36 |
+
)
|
37 |
+
|
38 |
+
self.encoder_prenet = nn.Sequential(
|
39 |
+
embed_tokens,
|
40 |
+
pos_enc_class(args.encoder_embed_dim, args.transformer_enc_positional_dropout_rate, max_len=args.max_text_positions),
|
41 |
+
)
|
42 |
+
|
43 |
+
def forward(self, src_tokens):
|
44 |
+
return self.encoder_prenet(src_tokens), src_tokens.eq(self.padding_idx)
|
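TextEncoderPrenet above is just token embedding plus a (scaled) positional encoding, returning the embedded sequence together with a padding mask; a plain-torch sketch with the positional term omitted to stay self-contained:

import torch
import torch.nn as nn

vocab, dim, pad = 100, 16, 1
embed_tokens = nn.Embedding(vocab, dim, padding_idx=pad)
src_tokens = torch.tensor([[5, 9, 23, pad, pad]])        # one padded sentence
embedded = embed_tokens(src_tokens)                      # positional encoding would be added here
padding_mask = src_tokens.eq(pad)
print(embedded.shape, padding_mask)
# torch.Size([1, 5, 16]) tensor([[False, False, False,  True,  True]])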
artst/models/modules/transformer_layer.py
ADDED
@@ -0,0 +1,410 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
|
5 |
+
# Based on speecht5, fairseq and espnet code bases
|
6 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
from typing import Dict, List, Optional
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import contextlib
|
14 |
+
from fairseq import utils
|
15 |
+
from fairseq.modules import LayerNorm
|
16 |
+
from .multihead_attention import MultiheadAttention
|
17 |
+
from fairseq.modules.fairseq_dropout import FairseqDropout
|
18 |
+
from fairseq.modules.quant_noise import quant_noise
|
19 |
+
from torch import Tensor
|
20 |
+
|
21 |
+
|
22 |
+
class TransformerSentenceEncoderLayer(nn.Module):
|
23 |
+
"""
|
24 |
+
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
|
25 |
+
models.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
embedding_dim: float = 768,
|
31 |
+
ffn_embedding_dim: float = 3072,
|
32 |
+
num_attention_heads: float = 8,
|
33 |
+
dropout: float = 0.1,
|
34 |
+
attention_dropout: float = 0.1,
|
35 |
+
activation_dropout: float = 0.1,
|
36 |
+
activation_fn: str = "relu",
|
37 |
+
layer_norm_first: bool = False,
|
38 |
+
has_relative_attention_bias: bool = False,
|
39 |
+
) -> None:
|
40 |
+
|
41 |
+
super().__init__()
|
42 |
+
# Initialize parameters
|
43 |
+
self.embedding_dim = embedding_dim
|
44 |
+
self.dropout = dropout
|
45 |
+
self.activation_dropout = activation_dropout
|
46 |
+
|
47 |
+
# Initialize blocks
|
48 |
+
self.activation_fn = utils.get_activation_fn(activation_fn)
|
49 |
+
self.self_attn = MultiheadAttention(
|
50 |
+
self.embedding_dim,
|
51 |
+
num_attention_heads,
|
52 |
+
dropout=attention_dropout,
|
53 |
+
self_attention=True,
|
54 |
+
has_relative_attention_bias=has_relative_attention_bias,
|
55 |
+
)
|
56 |
+
|
57 |
+
self.dropout1 = nn.Dropout(dropout)
|
58 |
+
self.dropout2 = nn.Dropout(self.activation_dropout)
|
59 |
+
self.dropout3 = nn.Dropout(dropout)
|
60 |
+
|
61 |
+
self.layer_norm_first = layer_norm_first
|
62 |
+
|
63 |
+
# layer norm associated with the self attention layer
|
64 |
+
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
65 |
+
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
66 |
+
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
67 |
+
|
68 |
+
# layer norm associated with the position wise feed-forward NN
|
69 |
+
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
70 |
+
|
71 |
+
if has_relative_attention_bias:
|
72 |
+
self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads)
|
73 |
+
|
74 |
+
def forward(
|
75 |
+
self,
|
76 |
+
x: torch.Tensor,
|
77 |
+
self_attn_mask: torch.Tensor = None,
|
78 |
+
self_attn_padding_mask: torch.Tensor = None,
|
79 |
+
need_weights: bool = False,
|
80 |
+
att_args=None,
|
81 |
+
pos_bias=None,
|
82 |
+
):
|
83 |
+
"""
|
84 |
+
LayerNorm is applied either before or after the self-attention/ffn
|
85 |
+
modules, similar to the original Transformer implementation.
|
86 |
+
"""
|
87 |
+
residual = x
|
88 |
+
|
89 |
+
if self.layer_norm_first:
|
90 |
+
x = self.self_attn_layer_norm(x)
|
91 |
+
if pos_bias is not None:
|
92 |
+
pos_bias = self.norm_k(pos_bias)
|
93 |
+
x, attn = self.self_attn(
|
94 |
+
query=x,
|
95 |
+
key=x,
|
96 |
+
value=x,
|
97 |
+
key_padding_mask=self_attn_padding_mask,
|
98 |
+
attn_mask=self_attn_mask,
|
99 |
+
position_bias=pos_bias,
|
100 |
+
)
|
101 |
+
x = self.dropout1(x)
|
102 |
+
x = residual + x
|
103 |
+
|
104 |
+
residual = x
|
105 |
+
x = self.final_layer_norm(x)
|
106 |
+
x = self.activation_fn(self.fc1(x))
|
107 |
+
x = self.dropout2(x)
|
108 |
+
x = self.fc2(x)
|
109 |
+
x = self.dropout3(x)
|
110 |
+
x = residual + x
|
111 |
+
else:
|
112 |
+
x, attn = self.self_attn(
|
113 |
+
query=x,
|
114 |
+
key=x,
|
115 |
+
value=x,
|
116 |
+
key_padding_mask=self_attn_padding_mask,
|
117 |
+
position_bias=pos_bias,
|
118 |
+
)
|
119 |
+
|
120 |
+
x = self.dropout1(x)
|
121 |
+
x = residual + x
|
122 |
+
|
123 |
+
x = self.self_attn_layer_norm(x)
|
124 |
+
|
125 |
+
residual = x
|
126 |
+
x = self.activation_fn(self.fc1(x))
|
127 |
+
x = self.dropout2(x)
|
128 |
+
x = self.fc2(x)
|
129 |
+
x = self.dropout3(x)
|
130 |
+
x = residual + x
|
131 |
+
x = self.final_layer_norm(x)
|
132 |
+
|
133 |
+
return x, attn
|
134 |
+
|
135 |
+
|
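The layer_norm_first flag above selects between pre-norm and post-norm residual blocks; a condensed sketch of the two orderings on a single feed-forward sublayer:

import torch
import torch.nn as nn

dim = 16
norm, ffn = nn.LayerNorm(dim), nn.Linear(dim, dim)
x = torch.randn(3, dim)

pre_norm = x + ffn(norm(x))              # layer_norm_first=True: normalize, transform, add residual
post_norm = norm(x + ffn(x))             # layer_norm_first=False: transform, add residual, normalize
print(pre_norm.shape, post_norm.shape)   # torch.Size([3, 16]) torch.Size([3, 16])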
136 |
+
class TransformerDecoderLayer(nn.Module):
|
137 |
+
"""Decoder layer block.
|
138 |
+
|
139 |
+
In the original paper each operation (multi-head attention, encoder
|
140 |
+
attention or FFN) is postprocessed with: `dropout -> add residual ->
|
141 |
+
layernorm`. In the tensor2tensor code they suggest that learning is more
|
142 |
+
robust when preprocessing each layer with layernorm and postprocessing with:
|
143 |
+
`dropout -> add residual`. We default to the approach in the paper, but the
|
144 |
+
tensor2tensor approach can be enabled by setting
|
145 |
+
*args.decoder_normalize_before* to ``True``.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
args (argparse.Namespace): parsed command-line arguments
|
149 |
+
no_encoder_attn (bool, optional): whether to attend to encoder outputs
|
150 |
+
(default: False).
|
151 |
+
"""
|
152 |
+
|
153 |
+
def __init__(
|
154 |
+
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, has_relative_attention_bias=False
|
155 |
+
):
|
156 |
+
super().__init__()
|
157 |
+
self.embed_dim = args.decoder_embed_dim
|
158 |
+
self.num_updates = 0
|
159 |
+
self.dropout_module = FairseqDropout(
|
160 |
+
args.dropout, module_name=self.__class__.__name__
|
161 |
+
)
|
162 |
+
self.quant_noise = getattr(args, "quant_noise_pq", 0)
|
163 |
+
self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)
|
164 |
+
|
165 |
+
self.cross_self_attention = getattr(args, "cross_self_attention", False)
|
166 |
+
|
167 |
+
self.freeze_decoder_updates = getattr(args, "freeze_decoder_updates", 0)
|
168 |
+
|
169 |
+
self.self_attn = self.build_self_attention(
|
170 |
+
self.embed_dim,
|
171 |
+
args,
|
172 |
+
add_bias_kv=add_bias_kv,
|
173 |
+
add_zero_attn=add_zero_attn,
|
174 |
+
)
|
175 |
+
|
176 |
+
self.activation_fn = utils.get_activation_fn(
|
177 |
+
activation=str(args.activation_fn)
|
178 |
+
if getattr(args, "activation_fn", None) is not None
|
179 |
+
else "relu"
|
180 |
+
)
|
181 |
+
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
|
182 |
+
if activation_dropout_p == 0:
|
183 |
+
# for backwards compatibility with models that use args.relu_dropout
|
184 |
+
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
|
185 |
+
self.activation_dropout_module = FairseqDropout(
|
186 |
+
float(activation_dropout_p), module_name=self.__class__.__name__
|
187 |
+
)
|
188 |
+
self.normalize_before = args.decoder_normalize_before
|
189 |
+
|
190 |
+
export = getattr(args, "export", False)
|
191 |
+
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
|
192 |
+
|
193 |
+
if no_encoder_attn:
|
194 |
+
self.encoder_attn = None
|
195 |
+
self.encoder_attn_layer_norm = None
|
196 |
+
else:
|
197 |
+
self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
|
198 |
+
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
|
199 |
+
|
200 |
+
self.fc1 = self.build_fc1(
|
201 |
+
self.embed_dim,
|
202 |
+
args.decoder_ffn_embed_dim,
|
203 |
+
self.quant_noise,
|
204 |
+
self.quant_noise_block_size,
|
205 |
+
)
|
206 |
+
self.fc2 = self.build_fc2(
|
207 |
+
args.decoder_ffn_embed_dim,
|
208 |
+
self.embed_dim,
|
209 |
+
self.quant_noise,
|
210 |
+
self.quant_noise_block_size,
|
211 |
+
)
|
212 |
+
|
213 |
+
self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
|
214 |
+
self.need_attn = True
|
215 |
+
|
216 |
+
self.onnx_trace = False
|
217 |
+
|
218 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
219 |
+
if self.has_relative_attention_bias:
|
220 |
+
self.norm_k = LayerNorm(self.embed_dim//args.decoder_attention_heads)
|
221 |
+
|
222 |
+
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
|
223 |
+
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
|
224 |
+
|
225 |
+
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
|
226 |
+
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
|
227 |
+
|
228 |
+
def build_self_attention(
|
229 |
+
self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
|
230 |
+
):
|
231 |
+
return MultiheadAttention(
|
232 |
+
embed_dim,
|
233 |
+
args.decoder_attention_heads,
|
234 |
+
dropout=args.attention_dropout,
|
235 |
+
add_bias_kv=add_bias_kv,
|
236 |
+
add_zero_attn=add_zero_attn,
|
237 |
+
self_attention=not getattr(args, "cross_self_attention", False),
|
238 |
+
q_noise=self.quant_noise,
|
239 |
+
qn_block_size=self.quant_noise_block_size,
|
240 |
+
#has_relative_attention_bias=args.has_relative_attention_bias,
|
241 |
+
)
|
242 |
+
|
243 |
+
def build_encoder_attention(self, embed_dim, args):
|
244 |
+
return MultiheadAttention(
|
245 |
+
embed_dim,
|
246 |
+
args.decoder_attention_heads,
|
247 |
+
kdim=getattr(args, "encoder_embed_dim", None),
|
248 |
+
vdim=getattr(args, "encoder_embed_dim", None),
|
249 |
+
dropout=args.attention_dropout,
|
250 |
+
encoder_decoder_attention=True,
|
251 |
+
q_noise=self.quant_noise,
|
252 |
+
qn_block_size=self.quant_noise_block_size,
|
253 |
+
)
|
254 |
+
|
255 |
+
def prepare_for_onnx_export_(self):
|
256 |
+
self.onnx_trace = True
|
257 |
+
|
258 |
+
def residual_connection(self, x, residual):
|
259 |
+
return residual + x
|
260 |
+
|
261 |
+
def forward(
|
262 |
+
self,
|
263 |
+
x,
|
264 |
+
encoder_out: Optional[torch.Tensor] = None,
|
265 |
+
encoder_padding_mask: Optional[torch.Tensor] = None,
|
266 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
267 |
+
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
|
268 |
+
prev_attn_state: Optional[List[torch.Tensor]] = None,
|
269 |
+
self_attn_mask: Optional[torch.Tensor] = None,
|
270 |
+
self_attn_padding_mask: Optional[torch.Tensor] = None,
|
271 |
+
need_attn: bool = False,
|
272 |
+
need_head_weights: bool = False,
|
273 |
+
pos_bias=None,
|
274 |
+
):
|
275 |
+
"""
|
276 |
+
Args:
|
277 |
+
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
278 |
+
encoder_padding_mask (ByteTensor, optional): binary
|
279 |
+
ByteTensor of shape `(batch, src_len)` where padding
|
280 |
+
elements are indicated by ``1``.
|
281 |
+
need_attn (bool, optional): return attention weights
|
282 |
+
need_head_weights (bool, optional): return attention weights
|
283 |
+
for each head (default: return average over heads).
|
284 |
+
|
285 |
+
Returns:
|
286 |
+
encoded output of shape `(seq_len, batch, embed_dim)`
|
287 |
+
"""
|
288 |
+
ft = self.freeze_decoder_updates <= self.num_updates
|
289 |
+
|
290 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
291 |
+
if need_head_weights:
|
292 |
+
need_attn = True
|
293 |
+
|
294 |
+
residual = x
|
295 |
+
if self.normalize_before:
|
296 |
+
x = self.self_attn_layer_norm(x)
|
297 |
+
if pos_bias is not None:
|
298 |
+
pos_bias = self.norm_k(pos_bias)
|
299 |
+
if prev_self_attn_state is not None:
|
300 |
+
prev_key, prev_value = prev_self_attn_state[:2]
|
301 |
+
saved_state: Dict[str, Optional[Tensor]] = {
|
302 |
+
"prev_key": prev_key,
|
303 |
+
"prev_value": prev_value,
|
304 |
+
}
|
305 |
+
if len(prev_self_attn_state) >= 3:
|
306 |
+
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
|
307 |
+
assert incremental_state is not None
|
308 |
+
self.self_attn._set_input_buffer(incremental_state, saved_state)
|
309 |
+
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
|
310 |
+
if self.cross_self_attention and not (
|
311 |
+
incremental_state is not None
|
312 |
+
and _self_attn_input_buffer is not None
|
313 |
+
and "prev_key" in _self_attn_input_buffer
|
314 |
+
):
|
315 |
+
if self_attn_mask is not None:
|
316 |
+
assert encoder_out is not None
|
317 |
+
self_attn_mask = torch.cat(
|
318 |
+
(x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
|
319 |
+
)
|
320 |
+
if self_attn_padding_mask is not None:
|
321 |
+
if encoder_padding_mask is None:
|
322 |
+
assert encoder_out is not None
|
323 |
+
encoder_padding_mask = self_attn_padding_mask.new_zeros(
|
324 |
+
encoder_out.size(1), encoder_out.size(0)
|
325 |
+
)
|
326 |
+
self_attn_padding_mask = torch.cat(
|
327 |
+
(encoder_padding_mask, self_attn_padding_mask), dim=1
|
328 |
+
)
|
329 |
+
assert encoder_out is not None
|
330 |
+
y = torch.cat((encoder_out, x), dim=0)
|
331 |
+
else:
|
332 |
+
y = x
|
333 |
+
|
334 |
+
x, attn = self.self_attn(
|
335 |
+
query=x,
|
336 |
+
key=y,
|
337 |
+
value=y,
|
338 |
+
key_padding_mask=self_attn_padding_mask,
|
339 |
+
incremental_state=incremental_state,
|
340 |
+
need_weights=False,
|
341 |
+
attn_mask=self_attn_mask,
|
342 |
+
position_bias=pos_bias,
|
343 |
+
)
|
344 |
+
x = self.dropout_module(x)
|
345 |
+
x = self.residual_connection(x, residual)
|
346 |
+
if not self.normalize_before:
|
347 |
+
x = self.self_attn_layer_norm(x)
|
348 |
+
|
349 |
+
if self.encoder_attn is not None and encoder_out is not None:
|
350 |
+
residual = x
|
351 |
+
if self.normalize_before:
|
352 |
+
x = self.encoder_attn_layer_norm(x)
|
353 |
+
if prev_attn_state is not None:
|
354 |
+
prev_key, prev_value = prev_attn_state[:2]
|
355 |
+
saved_state: Dict[str, Optional[Tensor]] = {
|
356 |
+
"prev_key": prev_key,
|
357 |
+
"prev_value": prev_value,
|
358 |
+
}
|
359 |
+
if len(prev_attn_state) >= 3:
|
360 |
+
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
|
361 |
+
assert incremental_state is not None
|
362 |
+
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
|
363 |
+
|
364 |
+
x, attn = self.encoder_attn(
|
365 |
+
query=x,
|
366 |
+
key=encoder_out,
|
367 |
+
value=encoder_out,
|
368 |
+
key_padding_mask=encoder_padding_mask,
|
369 |
+
incremental_state=incremental_state,
|
370 |
+
static_kv=True,
|
371 |
+
need_weights=need_attn or (not self.training and self.need_attn),
|
372 |
+
need_head_weights=need_head_weights,
|
373 |
+
)
|
374 |
+
x = self.dropout_module(x)
|
375 |
+
x = self.residual_connection(x, residual)
|
376 |
+
if not self.normalize_before:
|
377 |
+
x = self.encoder_attn_layer_norm(x)
|
378 |
+
|
379 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
380 |
+
residual = x
|
381 |
+
if self.normalize_before:
|
382 |
+
x = self.final_layer_norm(x)
|
383 |
+
|
384 |
+
x = self.activation_fn(self.fc1(x))
|
385 |
+
x = self.activation_dropout_module(x)
|
386 |
+
x = self.fc2(x)
|
387 |
+
x = self.dropout_module(x)
|
388 |
+
x = self.residual_connection(x, residual)
|
389 |
+
if not self.normalize_before:
|
390 |
+
x = self.final_layer_norm(x)
|
391 |
+
if self.onnx_trace and incremental_state is not None:
|
392 |
+
saved_state = self.self_attn._get_input_buffer(incremental_state)
|
393 |
+
assert saved_state is not None
|
394 |
+
if self_attn_padding_mask is not None:
|
395 |
+
self_attn_state = [
|
396 |
+
saved_state["prev_key"],
|
397 |
+
saved_state["prev_value"],
|
398 |
+
saved_state["prev_key_padding_mask"],
|
399 |
+
]
|
400 |
+
else:
|
401 |
+
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
|
402 |
+
return x, attn, self_attn_state
|
403 |
+
return x, attn, None
|
404 |
+
|
405 |
+
def make_generation_fast_(self, need_attn: bool = False, **kwargs):
|
406 |
+
self.need_attn = need_attn
|
407 |
+
|
408 |
+
def set_num_updates(self, num_updates):
|
409 |
+
"""Set the number of parameters updates."""
|
410 |
+
self.num_updates = num_updates
|
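The decoder layer's forward above wraps its attention and feed-forward blocks in torch.no_grad() until freeze_decoder_updates steps have elapsed, and in an empty contextlib.ExitStack() afterwards, which lets the same code path run with or without gradient tracking. A minimal, self-contained sketch of that idiom only; the helper name maybe_frozen and the toy linear layer are illustrative and not part of this repository:

import contextlib

import torch
import torch.nn as nn

def maybe_frozen(frozen: bool):
    # Same idiom as `torch.no_grad() if not ft else contextlib.ExitStack()`:
    # ExitStack() is an empty context manager, so autograd behaves normally
    # once the layer has trained past `freeze_decoder_updates` steps.
    return torch.no_grad() if frozen else contextlib.ExitStack()

layer = nn.Linear(4, 4)
x = torch.randn(2, 4)

with maybe_frozen(frozen=True):
    y_frozen = layer(x)   # no autograd graph is recorded here
with maybe_frozen(frozen=False):
    y_train = layer(x)    # gradients can flow as usual

print(y_frozen.requires_grad, y_train.requires_grad)  # False True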
artst/models/t5_transformer_lm.py
ADDED
@@ -0,0 +1,23 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
from fairseq.models import (
|
9 |
+
register_model_architecture,
|
10 |
+
)
|
11 |
+
from fairseq.models.transformer_lm import base_lm_architecture
|
12 |
+
|
13 |
+
|
14 |
+
# @register_model_architecture(model_name="transformer_lm", arch_name="transformer_lm_t5")
|
15 |
+
def transformer_lm_t5(args):
|
16 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
|
17 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144)
|
18 |
+
args.decoder_layers = getattr(args, "decoder_layers", 20)
|
19 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
20 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
21 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
22 |
+
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
23 |
+
base_lm_architecture(args)
|
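transformer_lm_t5 above follows fairseq's usual architecture-function pattern: every hyperparameter is filled in only if the user has not already set it, and base_lm_architecture then completes the remaining defaults. Below is a standalone sketch of that getattr-with-fallback pattern, using a made-up helper name and no fairseq dependency:

from argparse import Namespace

def fill_t5_lm_defaults(args):
    # Mirrors the getattr(args, name, default) pattern of transformer_lm_t5:
    # values the user has already set are kept, everything else gets a default.
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144)
    args.decoder_layers = getattr(args, "decoder_layers", 20)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    return args

args = fill_t5_lm_defaults(Namespace(decoder_layers=12))
print(args.decoder_layers, args.decoder_embed_dim)  # 12 1280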
artst/sequence_generator.py
ADDED
@@ -0,0 +1,1080 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import math
|
9 |
+
from typing import Dict, List, Optional
|
10 |
+
import sys
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from fairseq import search, utils
|
15 |
+
from fairseq.data import data_utils
|
16 |
+
from fairseq.models import FairseqIncrementalDecoder
|
17 |
+
from torch import Tensor
|
18 |
+
from fairseq.ngram_repeat_block import NGramRepeatBlock
|
19 |
+
from espnet.nets.ctc_prefix_score import CTCPrefixScore
|
20 |
+
import numpy
|
21 |
+
|
22 |
+
CTC_SCORING_RATIO = 7.0
|
23 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
24 |
+
|
25 |
+
class SequenceGenerator(nn.Module):
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
models,
|
29 |
+
tgt_dict,
|
30 |
+
beam_size=1,
|
31 |
+
max_len_a=0,
|
32 |
+
max_len_b=200,
|
33 |
+
max_len=0,
|
34 |
+
min_len=1,
|
35 |
+
normalize_scores=True,
|
36 |
+
len_penalty=1.0,
|
37 |
+
unk_penalty=0.0,
|
38 |
+
temperature=1.0,
|
39 |
+
match_source_len=False,
|
40 |
+
no_repeat_ngram_size=0,
|
41 |
+
search_strategy=None,
|
42 |
+
eos=None,
|
43 |
+
symbols_to_strip_from_output=None,
|
44 |
+
lm_model=None,
|
45 |
+
lm_weight=1.0,
|
46 |
+
ctc_weight=0.0,
|
47 |
+
):
|
48 |
+
"""Generates translations of a given source sentence.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models,
|
52 |
+
currently support fairseq.models.TransformerModel for scripting
|
53 |
+
beam_size (int, optional): beam width (default: 1)
|
54 |
+
max_len_a/b (int, optional): generate sequences of maximum length
|
55 |
+
ax + b, where x is the source length
|
56 |
+
max_len (int, optional): the maximum length of the generated output
|
57 |
+
(not including end-of-sentence)
|
58 |
+
min_len (int, optional): the minimum length of the generated output
|
59 |
+
(not including end-of-sentence)
|
60 |
+
normalize_scores (bool, optional): normalize scores by the length
|
61 |
+
of the output (default: True)
|
62 |
+
len_penalty (float, optional): length penalty, where <1.0 favors
|
63 |
+
shorter, >1.0 favors longer sentences (default: 1.0)
|
64 |
+
unk_penalty (float, optional): unknown word penalty, where <0
|
65 |
+
produces more unks, >0 produces fewer (default: 0.0)
|
66 |
+
temperature (float, optional): temperature, where values
|
67 |
+
>1.0 produce more uniform samples and values <1.0 produce
|
68 |
+
sharper samples (default: 1.0)
|
69 |
+
match_source_len (bool, optional): outputs should match the source
|
70 |
+
length (default: False)
|
71 |
+
"""
|
72 |
+
super().__init__()
|
73 |
+
if isinstance(models, EnsembleModel):
|
74 |
+
self.model = models
|
75 |
+
else:
|
76 |
+
self.model = EnsembleModel(models)
|
77 |
+
self.tgt_dict = tgt_dict
|
78 |
+
self.pad = tgt_dict.pad()
|
79 |
+
self.unk = tgt_dict.unk()
|
80 |
+
self.eos = tgt_dict.eos() if eos is None else eos
|
81 |
+
self.blank = self.tgt_dict.index("<ctc_blank>")
|
82 |
+
self.mask = self.tgt_dict.index("<mask>")
|
83 |
+
self.mask_idxs = []
|
84 |
+
if self.tgt_dict.index("<mask>0") != self.unk:
|
85 |
+
count = 0
|
86 |
+
while self.tgt_dict.index("<mask>" + str(count)) != self.unk:
|
87 |
+
self.mask_idxs.append(self.tgt_dict.index("<mask>" + str(count)))
|
88 |
+
count += 1
|
89 |
+
self.mask_idxs = torch.tensor(self.mask_idxs)
|
90 |
+
self.symbols_to_strip_from_output = (
|
91 |
+
symbols_to_strip_from_output.union({self.eos})
|
92 |
+
if symbols_to_strip_from_output is not None
|
93 |
+
else {self.eos}
|
94 |
+
)
|
95 |
+
self.vocab_size = len(tgt_dict)
|
96 |
+
self.beam_size = beam_size
|
97 |
+
# the max beam size is the dictionary size - 1, since we never select pad
|
98 |
+
self.beam_size = min(beam_size, self.vocab_size - 1)
|
99 |
+
self.max_len_a = max_len_a
|
100 |
+
self.max_len_b = max_len_b
|
101 |
+
self.min_len = min_len
|
102 |
+
self.max_len = max_len or self.model.max_decoder_positions()
|
103 |
+
|
104 |
+
self.normalize_scores = normalize_scores
|
105 |
+
self.len_penalty = len_penalty
|
106 |
+
self.unk_penalty = unk_penalty
|
107 |
+
self.temperature = temperature
|
108 |
+
self.match_source_len = match_source_len
|
109 |
+
|
110 |
+
if no_repeat_ngram_size > 0:
|
111 |
+
self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
|
112 |
+
else:
|
113 |
+
self.repeat_ngram_blocker = None
|
114 |
+
|
115 |
+
assert temperature > 0, "--temperature must be greater than 0"
|
116 |
+
|
117 |
+
self.search = (
|
118 |
+
search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
|
119 |
+
)
|
120 |
+
# We only need to set src_lengths in LengthConstrainedBeamSearch.
|
121 |
+
# As a module attribute, setting it would break in multithread
|
122 |
+
# settings when the model is shared.
|
123 |
+
self.should_set_src_lengths = (
|
124 |
+
hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
|
125 |
+
)
|
126 |
+
|
127 |
+
self.model.eval()
|
128 |
+
|
129 |
+
self.lm_model = lm_model
|
130 |
+
self.lm_weight = lm_weight
|
131 |
+
self.ctc_weight = ctc_weight
|
132 |
+
if self.lm_model is not None:
|
133 |
+
self.lm_model.eval()
|
134 |
+
|
135 |
+
def cuda(self):
|
136 |
+
self.model.cuda()
|
137 |
+
return self
|
138 |
+
|
139 |
+
@torch.no_grad()
|
140 |
+
def forward(
|
141 |
+
self,
|
142 |
+
sample: Dict[str, Dict[str, Tensor]],
|
143 |
+
prefix_tokens: Optional[Tensor] = None,
|
144 |
+
bos_token: Optional[int] = None,
|
145 |
+
):
|
146 |
+
"""Generate a batch of translations.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
sample (dict): batch
|
150 |
+
prefix_tokens (torch.LongTensor, optional): force decoder to begin
|
151 |
+
with these tokens
|
152 |
+
bos_token (int, optional): beginning of sentence token
|
153 |
+
(default: self.eos)
|
154 |
+
"""
|
155 |
+
return self._generate(sample, prefix_tokens, bos_token=bos_token)
|
156 |
+
|
157 |
+
# TODO(myleott): unused, deprecate after pytorch-translate migration
|
158 |
+
def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
|
159 |
+
"""Iterate over a batched dataset and yield individual translations.
|
160 |
+
Args:
|
161 |
+
cuda (bool, optional): use GPU for generation
|
162 |
+
timer (StopwatchMeter, optional): time generations
|
163 |
+
"""
|
164 |
+
for sample in data_itr:
|
165 |
+
s = utils.move_to_cuda(sample) if cuda else sample
|
166 |
+
if "net_input" not in s:
|
167 |
+
continue
|
168 |
+
input = s["net_input"]
|
169 |
+
# model.forward normally channels prev_output_tokens into the decoder
|
170 |
+
# separately, but SequenceGenerator directly calls model.encoder
|
171 |
+
encoder_input = {
|
172 |
+
k: v for k, v in input.items() if k != "prev_output_tokens"
|
173 |
+
}
|
174 |
+
if timer is not None:
|
175 |
+
timer.start()
|
176 |
+
with torch.no_grad():
|
177 |
+
hypos = self.generate(encoder_input)
|
178 |
+
if timer is not None:
|
179 |
+
timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
|
180 |
+
for i, id in enumerate(s["id"].data):
|
181 |
+
# remove padding
|
182 |
+
src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
|
183 |
+
ref = (
|
184 |
+
utils.strip_pad(s["target"].data[i, :], self.pad)
|
185 |
+
if s["target"] is not None
|
186 |
+
else None
|
187 |
+
)
|
188 |
+
yield id, src, ref, hypos[i]
|
189 |
+
|
190 |
+
@torch.no_grad()
|
191 |
+
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
|
192 |
+
"""Generate translations. Match the api of other fairseq generators.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
models (List[~fairseq.models.FairseqModel]): ensemble of models
|
196 |
+
sample (dict): batch
|
197 |
+
prefix_tokens (torch.LongTensor, optional): force decoder to begin
|
198 |
+
with these tokens
|
199 |
+
constraints (torch.LongTensor, optional): force decoder to include
|
200 |
+
the list of constraints
|
201 |
+
bos_token (int, optional): beginning of sentence token
|
202 |
+
(default: self.eos)
|
203 |
+
"""
|
204 |
+
return self._generate(sample, **kwargs)
|
205 |
+
|
206 |
+
def _generate(
|
207 |
+
self,
|
208 |
+
sample: Dict[str, Dict[str, Tensor]],
|
209 |
+
prefix_tokens: Optional[Tensor] = None,
|
210 |
+
constraints: Optional[Tensor] = None,
|
211 |
+
bos_token: Optional[int] = None,
|
212 |
+
):
|
213 |
+
incremental_states = torch.jit.annotate(
|
214 |
+
List[Dict[str, Dict[str, Optional[Tensor]]]],
|
215 |
+
[
|
216 |
+
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
|
217 |
+
for i in range(self.model.models_size)
|
218 |
+
],
|
219 |
+
)
|
220 |
+
net_input = sample["net_input"]
|
221 |
+
|
222 |
+
if "src_tokens" in net_input:
|
223 |
+
src_tokens = net_input["src_tokens"]
|
224 |
+
# source text length: number of tokens, excluding EOS and padding
|
225 |
+
src_lengths = (
|
226 |
+
(src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
|
227 |
+
)
|
228 |
+
elif "source" in net_input:
|
229 |
+
src_tokens = net_input["source"]
|
230 |
+
src_lengths = (
|
231 |
+
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
|
232 |
+
if net_input["padding_mask"] is not None
|
233 |
+
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
|
234 |
+
)
|
235 |
+
elif "features" in net_input:
|
236 |
+
src_tokens = net_input["features"]
|
237 |
+
src_lengths = (
|
238 |
+
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
|
239 |
+
if net_input["padding_mask"] is not None
|
240 |
+
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
|
241 |
+
)
|
242 |
+
else:
|
243 |
+
raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))
|
244 |
+
|
245 |
+
# bsz: total number of sentences in beam
|
246 |
+
# Note that src_tokens may have more than 2 dimensions (i.e. audio features)
|
247 |
+
bsz, src_len = src_tokens.size()[:2]
|
248 |
+
beam_size = self.beam_size
|
249 |
+
|
250 |
+
if constraints is not None and not self.search.supports_constraints:
|
251 |
+
raise NotImplementedError(
|
252 |
+
"Target-side constraints were provided, but search method doesn't support them"
|
253 |
+
)
|
254 |
+
|
255 |
+
# Initialize constraints, when active
|
256 |
+
self.search.init_constraints(constraints, beam_size)
|
257 |
+
|
258 |
+
max_len: int = -1
|
259 |
+
if self.match_source_len:
|
260 |
+
max_len = src_lengths.max().item()
|
261 |
+
else:
|
262 |
+
max_len = min(
|
263 |
+
int(self.max_len_a * src_len + self.max_len_b),
|
264 |
+
self.max_len - 1,
|
265 |
+
)
|
266 |
+
assert (
|
267 |
+
self.min_len <= max_len
|
268 |
+
), "min_len cannot be larger than max_len, please adjust these!"
|
269 |
+
# compute the encoder output for each beam
|
270 |
+
encoder_outs = self.model.forward_encoder(net_input)
|
271 |
+
|
272 |
+
# Get CTC lprobs and prep ctc_scorer
|
273 |
+
if self.ctc_weight > 0:
|
274 |
+
ctc_lprobs = self.model.models[0].get_normalized_probs_for_ctc(
|
275 |
+
encoder_outs[0], log_probs=True
|
276 |
+
).contiguous().transpose(0, 1) # (B, T, C) from the encoder
|
277 |
+
|
278 |
+
hyp = {}
|
279 |
+
ctc_prefix_score = CTCPrefixScore(ctc_lprobs[0].detach().cpu().numpy(), self.blank, self.eos, numpy)
|
280 |
+
hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
|
281 |
+
hyp["ctc_score_prev"] = 0.0
|
282 |
+
ctc_beam = min(ctc_lprobs.shape[-1] - self.mask_idxs.size(-1), int(beam_size * CTC_SCORING_RATIO))
|
283 |
+
ctc_hyps = {str(self.eos): hyp}
|
284 |
+
|
285 |
+
# placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
|
286 |
+
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
|
287 |
+
new_order = new_order.to(src_tokens.device).long()
|
288 |
+
encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
|
289 |
+
# ensure encoder_outs is a List.
|
290 |
+
assert encoder_outs is not None
|
291 |
+
|
292 |
+
# initialize buffers
|
293 |
+
scores = (
|
294 |
+
torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
|
295 |
+
) # +1 for eos; pad is never chosen for scoring
|
296 |
+
tokens = (
|
297 |
+
torch.zeros(bsz * beam_size, max_len + 2)
|
298 |
+
.to(src_tokens)
|
299 |
+
.long()
|
300 |
+
.fill_(self.pad)
|
301 |
+
) # +2 for eos and pad
|
302 |
+
tokens[:, 0] = self.eos if bos_token is None else bos_token
|
303 |
+
attn: Optional[Tensor] = None
|
304 |
+
|
305 |
+
# A list that indicates candidates that should be ignored.
|
306 |
+
# For example, suppose we're sampling and have already finalized 2/5
|
307 |
+
# samples. Then cands_to_ignore would mark 2 positions as being ignored,
|
308 |
+
# so that we only finalize the remaining 3 samples.
|
309 |
+
cands_to_ignore = (
|
310 |
+
torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
|
311 |
+
) # forward and backward-compatible False mask
|
312 |
+
|
313 |
+
# list of completed sentences
|
314 |
+
finalized = torch.jit.annotate(
|
315 |
+
List[List[Dict[str, Tensor]]],
|
316 |
+
[torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
|
317 |
+
) # contains lists of dictionaries of information about the hypothesis being finalized at each step
|
318 |
+
|
319 |
+
# a boolean array indicating if the sentence at the index is finished or not
|
320 |
+
finished = [False for i in range(bsz)]
|
321 |
+
num_remaining_sent = bsz # number of sentences remaining
|
322 |
+
|
323 |
+
# number of candidate hypos per step
|
324 |
+
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
|
325 |
+
|
326 |
+
# offset arrays for converting between different indexing schemes
|
327 |
+
bbsz_offsets = (
|
328 |
+
(torch.arange(0, bsz) * beam_size)
|
329 |
+
.unsqueeze(1)
|
330 |
+
.type_as(tokens)
|
331 |
+
.to(src_tokens.device)
|
332 |
+
)
|
333 |
+
cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
|
334 |
+
|
335 |
+
reorder_state: Optional[Tensor] = None
|
336 |
+
ctc_state = None
|
337 |
+
batch_idxs: Optional[Tensor] = None
|
338 |
+
|
339 |
+
original_batch_idxs: Optional[Tensor] = None
|
340 |
+
if "id" in sample and isinstance(sample["id"], Tensor):
|
341 |
+
original_batch_idxs = sample["id"]
|
342 |
+
else:
|
343 |
+
original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
|
344 |
+
|
345 |
+
for step in range(max_len + 1): # one extra step for EOS marker
|
346 |
+
# reorder decoder internal states based on the prev choice of beams
|
347 |
+
if reorder_state is not None:
|
348 |
+
if batch_idxs is not None:
|
349 |
+
# update beam indices to take into account removed sentences
|
350 |
+
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
|
351 |
+
batch_idxs
|
352 |
+
)
|
353 |
+
reorder_state.view(-1, beam_size).add_(
|
354 |
+
corr.unsqueeze(-1) * beam_size
|
355 |
+
)
|
356 |
+
original_batch_idxs = original_batch_idxs[batch_idxs]
|
357 |
+
self.model.reorder_incremental_state(incremental_states, reorder_state)
|
358 |
+
encoder_outs = self.model.reorder_encoder_out(
|
359 |
+
encoder_outs, reorder_state
|
360 |
+
)
|
361 |
+
|
362 |
+
lprobs, avg_attn_scores = self.model.forward_decoder(
|
363 |
+
tokens[:, : step + 1],
|
364 |
+
encoder_outs,
|
365 |
+
incremental_states,
|
366 |
+
self.temperature,
|
367 |
+
)
|
368 |
+
|
369 |
+
if self.ctc_weight > 0 and step != 0:
|
370 |
+
# lprobs[:, self.blank] = -math.inf # never select blank
|
371 |
+
ctc_lprobs = lprobs.clone()
|
372 |
+
ctc_lprobs[:, self.blank] = -math.inf # never select blank
|
373 |
+
if self.mask != self.unk:
|
374 |
+
ctc_lprobs[:, self.mask] = -math.inf # never select mask
|
375 |
+
if self.mask_idxs.size(0) != 0:
|
376 |
+
ctc_lprobs[:, self.mask_idxs] = -math.inf # never select mask
|
377 |
+
local_best_scores, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)
|
378 |
+
for b in range(tokens.size(0)):
|
379 |
+
hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist())
|
380 |
+
|
381 |
+
ctc_scores, ctc_states = ctc_prefix_score(
|
382 |
+
tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"]
|
383 |
+
)
|
384 |
+
lprobs[b] = lprobs[b]
|
385 |
+
lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy(
|
386 |
+
ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"]
|
387 |
+
).to(device=device)
|
388 |
+
for j in range(len(local_best_ids[b])):
|
389 |
+
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {}
|
390 |
+
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j]
|
391 |
+
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j]
|
392 |
+
|
393 |
+
# local_ctc_scores, ctc_state = ctc_scorer(
|
394 |
+
# tokens[:, : step + 1], ctc_state, part_ids
|
395 |
+
# )
|
396 |
+
# lprobs += local_ctc_scores * self.ctc_weight
|
397 |
+
elif self.ctc_weight > 0 and step == 0:
|
398 |
+
ctc_lprobs = lprobs.clone()
|
399 |
+
ctc_lprobs[:, self.blank] = -math.inf # never select blank
|
400 |
+
if self.mask != self.unk:
|
401 |
+
ctc_lprobs[:, self.mask] = -math.inf # never select mask
|
402 |
+
if self.mask_idxs.size(0) != 0:
|
403 |
+
ctc_lprobs[:, self.mask_idxs] = -math.inf # never select mask
|
404 |
+
local_best_scores, local_best_ids = torch.topk(ctc_lprobs, ctc_beam, dim=-1)
|
405 |
+
for b in range(tokens.size(0)):
|
406 |
+
hyp_key = " ".join(str(x) for x in tokens[b, : step + 1].tolist())
|
407 |
+
ctc_scores, ctc_states = ctc_prefix_score(
|
408 |
+
tokens[b, : step + 1].cpu(), local_best_ids[b].cpu(), ctc_hyps[hyp_key]["ctc_state_prev"]
|
409 |
+
)
|
410 |
+
lprobs[b] = lprobs[b]
|
411 |
+
lprobs[b, local_best_ids[b]] = (1 - self.ctc_weight) * (lprobs[b, local_best_ids[b]]) + self.ctc_weight * torch.from_numpy(
|
412 |
+
ctc_scores - ctc_hyps[hyp_key]["ctc_score_prev"]
|
413 |
+
).to(device=device)
|
414 |
+
for j in range(len(local_best_ids[b])):
|
415 |
+
if b == 0:
|
416 |
+
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())] = {}
|
417 |
+
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_score_prev"] = ctc_scores[j]
|
418 |
+
ctc_hyps[hyp_key + " " + str(local_best_ids[b][j].item())]["ctc_state_prev"] = ctc_states[j]
|
419 |
+
|
420 |
+
if self.lm_model is not None:
|
421 |
+
lm_out = self.lm_model(tokens[:, : step + 1])
|
422 |
+
probs = self.lm_model.get_normalized_probs(
|
423 |
+
lm_out, log_probs=True, sample=None
|
424 |
+
)
|
425 |
+
probs = probs[:, -1, :] * self.lm_weight
|
426 |
+
lprobs[:, :probs.size(1)] += probs
|
427 |
+
|
428 |
+
# handle prefix tokens (possibly with different lengths)
|
429 |
+
if (
|
430 |
+
prefix_tokens is not None
|
431 |
+
and step < prefix_tokens.size(1)
|
432 |
+
and step < max_len
|
433 |
+
):
|
434 |
+
lprobs, tokens, scores = self._prefix_tokens(
|
435 |
+
step, lprobs, scores, tokens, prefix_tokens, beam_size
|
436 |
+
)
|
437 |
+
elif step < self.min_len:
|
438 |
+
# minimum length constraint (does not apply if using prefix_tokens)
|
439 |
+
lprobs[:, self.eos] = -math.inf
|
440 |
+
|
441 |
+
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
|
442 |
+
|
443 |
+
lprobs[:, self.pad] = -math.inf # never select pad
|
444 |
+
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
|
445 |
+
lprobs[:, self.blank] = -math.inf # never select blank
|
446 |
+
if self.mask != self.unk:
|
447 |
+
lprobs[:, self.mask] = -math.inf # never select mask
|
448 |
+
if self.mask_idxs.size(0) != 0:
|
449 |
+
lprobs[:, self.mask_idxs] = -math.inf # never select mask
|
450 |
+
|
451 |
+
# handle max length constraint
|
452 |
+
if step >= max_len:
|
453 |
+
lprobs[:, : self.eos] = -math.inf
|
454 |
+
lprobs[:, self.eos + 1 :] = -math.inf
|
455 |
+
|
456 |
+
# Record attention scores, only support avg_attn_scores is a Tensor
|
457 |
+
if avg_attn_scores is not None:
|
458 |
+
if attn is None:
|
459 |
+
attn = torch.empty(
|
460 |
+
bsz * beam_size, avg_attn_scores.size(1), max_len + 2
|
461 |
+
).to(scores)
|
462 |
+
attn[:, :, step + 1].copy_(avg_attn_scores)
|
463 |
+
|
464 |
+
scores = scores.type_as(lprobs)
|
465 |
+
eos_bbsz_idx = torch.empty(0).to(
|
466 |
+
tokens
|
467 |
+
) # indices of hypothesis ending with eos (finished sentences)
|
468 |
+
eos_scores = torch.empty(0).to(
|
469 |
+
scores
|
470 |
+
) # scores of hypothesis ending with eos (finished sentences)
|
471 |
+
|
472 |
+
if self.should_set_src_lengths:
|
473 |
+
self.search.set_src_lengths(src_lengths)
|
474 |
+
|
475 |
+
if self.repeat_ngram_blocker is not None:
|
476 |
+
lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
|
477 |
+
|
478 |
+
# Shape: (batch, cand_size)
|
479 |
+
cand_scores, cand_indices, cand_beams = self.search.step(
|
480 |
+
step,
|
481 |
+
lprobs.view(bsz, -1, self.vocab_size),
|
482 |
+
scores.view(bsz, beam_size, -1)[:, :, :step],
|
483 |
+
tokens[:, : step + 1],
|
484 |
+
original_batch_idxs,
|
485 |
+
)
|
486 |
+
|
487 |
+
# cand_bbsz_idx contains beam indices for the top candidate
|
488 |
+
# hypotheses, with a range of values: [0, bsz*beam_size),
|
489 |
+
# and dimensions: [bsz, cand_size]
|
490 |
+
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
491 |
+
|
492 |
+
# finalize hypotheses that end in eos
|
493 |
+
# Shape of eos_mask: (batch size, beam size)
|
494 |
+
eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
|
495 |
+
eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
|
496 |
+
|
497 |
+
# only consider eos when it's among the top beam_size indices
|
498 |
+
# Now we know what beam item(s) to finish
|
499 |
+
# Shape: 1d list of absolute-numbered
|
500 |
+
eos_bbsz_idx = torch.masked_select(
|
501 |
+
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
|
502 |
+
)
|
503 |
+
|
504 |
+
finalized_sents: List[int] = []
|
505 |
+
if eos_bbsz_idx.numel() > 0:
|
506 |
+
eos_scores = torch.masked_select(
|
507 |
+
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
|
508 |
+
)
|
509 |
+
|
510 |
+
finalized_sents = self.finalize_hypos(
|
511 |
+
step,
|
512 |
+
eos_bbsz_idx,
|
513 |
+
eos_scores,
|
514 |
+
tokens,
|
515 |
+
scores,
|
516 |
+
finalized,
|
517 |
+
finished,
|
518 |
+
beam_size,
|
519 |
+
attn,
|
520 |
+
src_lengths,
|
521 |
+
max_len,
|
522 |
+
)
|
523 |
+
num_remaining_sent -= len(finalized_sents)
|
524 |
+
|
525 |
+
assert num_remaining_sent >= 0
|
526 |
+
if num_remaining_sent == 0:
|
527 |
+
break
|
528 |
+
if self.search.stop_on_max_len and step >= max_len:
|
529 |
+
break
|
530 |
+
assert step < max_len, f"{step} < {max_len}"
|
531 |
+
|
532 |
+
# Remove finalized sentences (ones for which {beam_size}
|
533 |
+
# finished hypotheses have been generated) from the batch.
|
534 |
+
if len(finalized_sents) > 0:
|
535 |
+
new_bsz = bsz - len(finalized_sents)
|
536 |
+
|
537 |
+
# construct batch_idxs which holds indices of batches to keep for the next pass
|
538 |
+
batch_mask = torch.ones(
|
539 |
+
bsz, dtype=torch.bool, device=cand_indices.device
|
540 |
+
)
|
541 |
+
batch_mask[finalized_sents] = False
|
542 |
+
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
|
543 |
+
batch_idxs = torch.arange(
|
544 |
+
bsz, device=cand_indices.device
|
545 |
+
).masked_select(batch_mask)
|
546 |
+
|
547 |
+
# Choose the subset of the hypothesized constraints that will continue
|
548 |
+
self.search.prune_sentences(batch_idxs)
|
549 |
+
|
550 |
+
eos_mask = eos_mask[batch_idxs]
|
551 |
+
cand_beams = cand_beams[batch_idxs]
|
552 |
+
bbsz_offsets.resize_(new_bsz, 1)
|
553 |
+
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
554 |
+
cand_scores = cand_scores[batch_idxs]
|
555 |
+
cand_indices = cand_indices[batch_idxs]
|
556 |
+
|
557 |
+
if prefix_tokens is not None:
|
558 |
+
prefix_tokens = prefix_tokens[batch_idxs]
|
559 |
+
src_lengths = src_lengths[batch_idxs]
|
560 |
+
cands_to_ignore = cands_to_ignore[batch_idxs]
|
561 |
+
|
562 |
+
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
563 |
+
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
564 |
+
if attn is not None:
|
565 |
+
attn = attn.view(bsz, -1)[batch_idxs].view(
|
566 |
+
new_bsz * beam_size, attn.size(1), -1
|
567 |
+
)
|
568 |
+
bsz = new_bsz
|
569 |
+
else:
|
570 |
+
batch_idxs = None
|
571 |
+
|
572 |
+
# Set active_mask so that values > cand_size indicate eos hypos
|
573 |
+
# and values < cand_size indicate candidate active hypos.
|
574 |
+
# After, the min values per row are the top candidate active hypos
|
575 |
+
|
576 |
+
# Rewrite the operator since the element wise or is not supported in torchscript.
|
577 |
+
|
578 |
+
eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
|
579 |
+
active_mask = torch.add(
|
580 |
+
eos_mask.type_as(cand_offsets) * cand_size,
|
581 |
+
cand_offsets[: eos_mask.size(1)],
|
582 |
+
)
|
583 |
+
|
584 |
+
# get the top beam_size active hypotheses, which are just
|
585 |
+
# the hypos with the smallest values in active_mask.
|
586 |
+
# {active_hypos} indicates which {beam_size} hypotheses
|
587 |
+
# from the list of {2 * beam_size} candidates were
|
588 |
+
# selected. Shapes: (batch size, beam size)
|
589 |
+
new_cands_to_ignore, active_hypos = torch.topk(
|
590 |
+
active_mask, k=beam_size, dim=1, largest=False
|
591 |
+
)
|
592 |
+
|
593 |
+
# update cands_to_ignore to ignore any finalized hypos.
|
594 |
+
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
|
595 |
+
# Make sure there is at least one active item for each sentence in the batch.
|
596 |
+
assert (~cands_to_ignore).any(dim=1).all()
|
597 |
+
|
598 |
+
# update cands_to_ignore to ignore any finalized hypos
|
599 |
+
|
600 |
+
# {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
|
601 |
+
# can be selected more than once).
|
602 |
+
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
|
603 |
+
active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
|
604 |
+
|
605 |
+
active_bbsz_idx = active_bbsz_idx.view(-1)
|
606 |
+
active_scores = active_scores.view(-1)
|
607 |
+
|
608 |
+
# copy tokens and scores for active hypotheses
|
609 |
+
|
610 |
+
# Set the tokens for each beam (can select the same row more than once)
|
611 |
+
tokens[:, : step + 1] = torch.index_select(
|
612 |
+
tokens[:, : step + 1], dim=0, index=active_bbsz_idx
|
613 |
+
)
|
614 |
+
# Select the next token for each of them
|
615 |
+
tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
|
616 |
+
cand_indices, dim=1, index=active_hypos
|
617 |
+
)
|
618 |
+
if step > 0:
|
619 |
+
scores[:, :step] = torch.index_select(
|
620 |
+
scores[:, :step], dim=0, index=active_bbsz_idx
|
621 |
+
)
|
622 |
+
scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
|
623 |
+
cand_scores, dim=1, index=active_hypos
|
624 |
+
)
|
625 |
+
|
626 |
+
# Update constraints based on which candidates were selected for the next beam
|
627 |
+
self.search.update_constraints(active_hypos)
|
628 |
+
|
629 |
+
# copy attention for active hypotheses
|
630 |
+
if attn is not None:
|
631 |
+
attn[:, :, : step + 2] = torch.index_select(
|
632 |
+
attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
|
633 |
+
)
|
634 |
+
|
635 |
+
# reorder incremental state in decoder
|
636 |
+
reorder_state = active_bbsz_idx
|
637 |
+
|
638 |
+
# if self.ctc_weight > 0:
|
639 |
+
# accum_best_id = torch.gather(cand_indices, dim=1, index=active_hypos)
|
640 |
+
# ctc_state = ctc_scorer.index_select_state(
|
641 |
+
# ctc_state, accum_best_id
|
642 |
+
# )
|
643 |
+
|
644 |
+
# sort by score descending
|
645 |
+
for sent in range(len(finalized)):
|
646 |
+
scores = torch.tensor(
|
647 |
+
[float(elem["score"].item()) for elem in finalized[sent]]
|
648 |
+
)
|
649 |
+
_, sorted_scores_indices = torch.sort(scores, descending=True)
|
650 |
+
finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
|
651 |
+
finalized[sent] = torch.jit.annotate(
|
652 |
+
List[Dict[str, Tensor]], finalized[sent]
|
653 |
+
)
|
654 |
+
return finalized
|
655 |
+
|
656 |
+
def _prefix_tokens(
|
657 |
+
self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
|
658 |
+
):
|
659 |
+
"""Handle prefix tokens"""
|
660 |
+
prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
|
661 |
+
prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
|
662 |
+
prefix_mask = prefix_toks.ne(self.pad)
|
663 |
+
lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1
|
664 |
+
lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
|
665 |
+
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
|
666 |
+
)
|
667 |
+
# if prefix includes eos, then we should make sure tokens and
|
668 |
+
# scores are the same across all beams
|
669 |
+
eos_mask = prefix_toks.eq(self.eos)
|
670 |
+
if eos_mask.any():
|
671 |
+
# validate that the first beam matches the prefix
|
672 |
+
first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
|
673 |
+
:, 0, 1 : step + 1
|
674 |
+
]
|
675 |
+
eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
|
676 |
+
target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
|
677 |
+
assert (first_beam == target_prefix).all()
|
678 |
+
|
679 |
+
# copy tokens, scores and lprobs from the first beam to all beams
|
680 |
+
tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
|
681 |
+
scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
|
682 |
+
lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
|
683 |
+
return lprobs, tokens, scores
|
684 |
+
|
685 |
+
def replicate_first_beam(self, tensor, mask, beam_size: int):
|
686 |
+
tensor = tensor.view(-1, beam_size, tensor.size(-1))
|
687 |
+
tensor[mask] = tensor[mask][:, :1, :]
|
688 |
+
return tensor.view(-1, tensor.size(-1))
|
689 |
+
|
690 |
+
def finalize_hypos(
|
691 |
+
self,
|
692 |
+
step: int,
|
693 |
+
bbsz_idx,
|
694 |
+
eos_scores,
|
695 |
+
tokens,
|
696 |
+
scores,
|
697 |
+
finalized: List[List[Dict[str, Tensor]]],
|
698 |
+
finished: List[bool],
|
699 |
+
beam_size: int,
|
700 |
+
attn: Optional[Tensor],
|
701 |
+
src_lengths,
|
702 |
+
max_len: int,
|
703 |
+
):
|
704 |
+
"""Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
|
705 |
+
A sentence is finalized when {beam_size} finished items have been collected for it.
|
706 |
+
Returns number of sentences (not beam items) being finalized.
|
707 |
+
These will be removed from the batch and not processed further.
|
708 |
+
Args:
|
709 |
+
bbsz_idx (Tensor):
|
710 |
+
"""
|
711 |
+
assert bbsz_idx.numel() == eos_scores.numel()
|
712 |
+
|
713 |
+
# clone relevant token and attention tensors.
|
714 |
+
# tokens is (batch * beam, max_len). So the index_select
|
715 |
+
# gets the newly EOS rows, then selects cols 1..{step + 2}
|
716 |
+
tokens_clone = tokens.index_select(0, bbsz_idx)[
|
717 |
+
:, 1 : step + 2
|
718 |
+
] # skip the first index, which is EOS
|
719 |
+
|
720 |
+
tokens_clone[:, step] = self.eos
|
721 |
+
attn_clone = (
|
722 |
+
attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
|
723 |
+
if attn is not None
|
724 |
+
else None
|
725 |
+
)
|
726 |
+
|
727 |
+
# compute scores per token position
|
728 |
+
pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
|
729 |
+
pos_scores[:, step] = eos_scores
|
730 |
+
# convert from cumulative to per-position scores
|
731 |
+
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
|
732 |
+
|
733 |
+
# normalize sentence-level scores
|
734 |
+
if self.normalize_scores:
|
735 |
+
eos_scores /= (step + 1) ** self.len_penalty
|
736 |
+
|
737 |
+
# cum_unfin records which sentences in the batch are finished.
|
738 |
+
# It helps match indexing between (a) the original sentences
|
739 |
+
# in the batch and (b) the current, possibly-reduced set of
|
740 |
+
# sentences.
|
741 |
+
cum_unfin: List[int] = []
|
742 |
+
prev = 0
|
743 |
+
for f in finished:
|
744 |
+
if f:
|
745 |
+
prev += 1
|
746 |
+
else:
|
747 |
+
cum_unfin.append(prev)
|
748 |
+
cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)
|
749 |
+
|
750 |
+
unfin_idx = bbsz_idx // beam_size
|
751 |
+
sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)
|
752 |
+
|
753 |
+
# Create a set of "{sent}{unfin_idx}", where
|
754 |
+
# "unfin_idx" is the index in the current (possibly reduced)
|
755 |
+
# list of sentences, and "sent" is the index in the original,
|
756 |
+
# unreduced batch
|
757 |
+
# For every finished beam item
|
758 |
+
# sentence index in the current (possibly reduced) batch
|
759 |
+
seen = (sent << 32) + unfin_idx
|
760 |
+
unique_seen: List[int] = torch.unique(seen).tolist()
|
761 |
+
|
762 |
+
if self.match_source_len:
|
763 |
+
condition = step > torch.index_select(src_lengths, 0, unfin_idx)
|
764 |
+
eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores)
|
765 |
+
sent_list: List[int] = sent.tolist()
|
766 |
+
for i in range(bbsz_idx.size()[0]):
|
767 |
+
# An input sentence (among those in a batch) is finished when
|
768 |
+
# beam_size hypotheses have been collected for it
|
769 |
+
if len(finalized[sent_list[i]]) < beam_size:
|
770 |
+
if attn_clone is not None:
|
771 |
+
# remove padding tokens from attn scores
|
772 |
+
hypo_attn = attn_clone[i]
|
773 |
+
else:
|
774 |
+
hypo_attn = torch.empty(0)
|
775 |
+
|
776 |
+
finalized[sent_list[i]].append(
|
777 |
+
{
|
778 |
+
"tokens": tokens_clone[i],
|
779 |
+
"score": eos_scores[i],
|
780 |
+
"attention": hypo_attn, # src_len x tgt_len
|
781 |
+
"alignment": torch.empty(0),
|
782 |
+
"positional_scores": pos_scores[i],
|
783 |
+
}
|
784 |
+
)
|
785 |
+
|
786 |
+
newly_finished: List[int] = []
|
787 |
+
for unique_s in unique_seen:
|
788 |
+
# check termination conditions for this sentence
|
789 |
+
unique_sent: int = unique_s >> 32
|
790 |
+
unique_unfin_idx: int = unique_s - (unique_sent << 32)
|
791 |
+
|
792 |
+
if not finished[unique_sent] and self.is_finished(
|
793 |
+
step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size
|
794 |
+
):
|
795 |
+
finished[unique_sent] = True
|
796 |
+
newly_finished.append(unique_unfin_idx)
|
797 |
+
|
798 |
+
return newly_finished
|
799 |
+
|
800 |
+
def is_finished(
|
801 |
+
self,
|
802 |
+
step: int,
|
803 |
+
unfin_idx: int,
|
804 |
+
max_len: int,
|
805 |
+
finalized_sent_len: int,
|
806 |
+
beam_size: int,
|
807 |
+
):
|
808 |
+
"""
|
809 |
+
Check whether decoding for a sentence is finished, which
|
810 |
+
occurs when the list of finalized sentences has reached the
|
811 |
+
beam size, or when we reach the maximum length.
|
812 |
+
"""
|
813 |
+
assert finalized_sent_len <= beam_size
|
814 |
+
if finalized_sent_len == beam_size or step == max_len:
|
815 |
+
return True
|
816 |
+
return False
|
817 |
+
|
818 |
+
|
819 |
+
class EnsembleModel(nn.Module):
|
820 |
+
"""A wrapper around an ensemble of models."""
|
821 |
+
|
822 |
+
def __init__(self, models):
|
823 |
+
super().__init__()
|
824 |
+
self.models_size = len(models)
|
825 |
+
# method '__len__' is not supported in ModuleList for torch script
|
826 |
+
self.single_model = models[0]
|
827 |
+
self.models = nn.ModuleList(models)
|
828 |
+
|
829 |
+
self.has_incremental: bool = False
|
830 |
+
if all(
|
831 |
+
hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
|
832 |
+
for m in models
|
833 |
+
):
|
834 |
+
self.has_incremental = True
|
835 |
+
|
836 |
+
def forward(self):
|
837 |
+
pass
|
838 |
+
|
839 |
+
def has_encoder(self):
|
840 |
+
return hasattr(self.single_model, "encoder")
|
841 |
+
|
842 |
+
def is_t5_structure(self):
|
843 |
+
t5_structure = hasattr(self.single_model, "text_encoder_prenet") and hasattr(self.single_model, "speech_encoder_prenet") or \
|
844 |
+
hasattr(self.single_model, "encoder_prenet") and hasattr(self.single_model, "encoder_prenet")
|
845 |
+
return t5_structure
|
846 |
+
|
847 |
+
def has_incremental_states(self):
|
848 |
+
return self.has_incremental
|
849 |
+
|
850 |
+
def max_decoder_positions(self):
|
851 |
+
return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])
|
852 |
+
|
853 |
+
@torch.jit.export
|
854 |
+
def forward_encoder(self, net_input: Dict[str, Tensor]):
|
855 |
+
if not self.has_encoder():
|
856 |
+
return None
|
857 |
+
elif self.is_t5_structure():
|
858 |
+
return [model.forward_encoder_torchscript(net_input) for model in self.models]
|
859 |
+
else:
|
860 |
+
return [model.encoder.forward_torchscript(net_input) for model in self.models]
|
861 |
+
|
862 |
+
@torch.jit.export
|
863 |
+
def forward_decoder(
|
864 |
+
self,
|
865 |
+
tokens,
|
866 |
+
encoder_outs: List[Dict[str, List[Tensor]]],
|
867 |
+
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
|
868 |
+
temperature: float = 1.0,
|
869 |
+
):
|
870 |
+
log_probs = []
|
871 |
+
avg_attn: Optional[Tensor] = None
|
872 |
+
encoder_out: Optional[Dict[str, List[Tensor]]] = None
|
873 |
+
for i, model in enumerate(self.models):
|
874 |
+
if self.has_encoder():
|
875 |
+
encoder_out = encoder_outs[i]
|
876 |
+
# decode each model
|
877 |
+
if self.has_incremental_states():
|
878 |
+
if self.is_t5_structure():
|
879 |
+
decoder_out = model.forward_decoder(
|
880 |
+
tokens,
|
881 |
+
encoder_out=encoder_out,
|
882 |
+
incremental_state=incremental_states[i]
|
883 |
+
)
|
884 |
+
else:
|
885 |
+
decoder_out = model.decoder.forward(
|
886 |
+
tokens,
|
887 |
+
encoder_out=encoder_out,
|
888 |
+
incremental_state=incremental_states[i],
|
889 |
+
)
|
890 |
+
else:
|
891 |
+
if hasattr(model, "decoder"):
|
892 |
+
decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)
|
893 |
+
else:
|
894 |
+
decoder_out = model.forward(tokens)
|
895 |
+
|
896 |
+
attn: Optional[Tensor] = None
|
897 |
+
decoder_len = len(decoder_out)
|
898 |
+
if decoder_len > 1 and decoder_out[1] is not None:
|
899 |
+
if isinstance(decoder_out[1], Tensor):
|
900 |
+
attn = decoder_out[1]
|
901 |
+
else:
|
902 |
+
attn_holder = decoder_out[1]["attn"]
|
903 |
+
if isinstance(attn_holder, Tensor):
|
904 |
+
attn = attn_holder
|
905 |
+
elif attn_holder is not None:
|
906 |
+
attn = attn_holder[0]
|
907 |
+
if attn is not None:
|
908 |
+
attn = attn[:, -1, :]
|
909 |
+
|
910 |
+
decoder_out_tuple = (
|
911 |
+
decoder_out[0][:, -1:, :].div_(temperature),
|
912 |
+
None if decoder_len <= 1 else decoder_out[1],
|
913 |
+
)
|
914 |
+
probs = model.get_normalized_probs(
|
915 |
+
decoder_out_tuple, log_probs=True, sample=None
|
916 |
+
)
|
917 |
+
probs = probs[:, -1, :]
|
918 |
+
if self.models_size == 1:
|
919 |
+
return probs, attn
|
920 |
+
|
921 |
+
log_probs.append(probs)
|
922 |
+
if attn is not None:
|
923 |
+
if avg_attn is None:
|
924 |
+
avg_attn = attn
|
925 |
+
else:
|
926 |
+
avg_attn.add_(attn)
|
927 |
+
|
928 |
+
avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
|
929 |
+
self.models_size
|
930 |
+
)
|
931 |
+
|
932 |
+
if avg_attn is not None:
|
933 |
+
avg_attn.div_(self.models_size)
|
934 |
+
return avg_probs, avg_attn
|
935 |
+
|
936 |
+
@torch.jit.export
|
937 |
+
def reorder_encoder_out(
|
938 |
+
self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order
|
939 |
+
):
|
940 |
+
"""
|
941 |
+
Reorder encoder output according to *new_order*.
|
942 |
+
|
943 |
+
Args:
|
944 |
+
encoder_out: output from the ``forward()`` method
|
945 |
+
new_order (LongTensor): desired order
|
946 |
+
|
947 |
+
Returns:
|
948 |
+
*encoder_out* rearranged according to *new_order*
|
949 |
+
"""
|
950 |
+
new_outs: List[Dict[str, List[Tensor]]] = []
|
951 |
+
if not self.has_encoder():
|
952 |
+
return new_outs
|
953 |
+
for i, model in enumerate(self.models):
|
954 |
+
assert encoder_outs is not None
|
955 |
+
new_outs.append(
|
956 |
+
model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
|
957 |
+
)
|
958 |
+
return new_outs
|
959 |
+
|
960 |
+
@torch.jit.export
|
961 |
+
def reorder_incremental_state(
|
962 |
+
self,
|
963 |
+
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
|
964 |
+
new_order,
|
965 |
+
):
|
966 |
+
if not self.has_incremental_states():
|
967 |
+
return
|
968 |
+
for i, model in enumerate(self.models):
|
969 |
+
model.decoder.reorder_incremental_state_scripting(
|
970 |
+
incremental_states[i], new_order
|
971 |
+
)
|
972 |
+
|
973 |
+
|
974 |
+
class SequenceGeneratorWithAlignment(SequenceGenerator):
|
975 |
+
def __init__(
|
976 |
+
self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs
|
977 |
+
):
|
978 |
+
"""Generates translations of a given source sentence.
|
979 |
+
|
980 |
+
Produces alignments following "Jointly Learning to Align and
|
981 |
+
Translate with Transformer Models" (Garg et al., EMNLP 2019).
|
982 |
+
|
983 |
+
Args:
|
984 |
+
left_pad_target (bool, optional): whether the
|
985 |
+
hypotheses should be left-padded when they are
|
986 |
+
teacher-forced for generating alignments.
|
987 |
+
"""
|
988 |
+
super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
|
989 |
+
self.left_pad_target = left_pad_target
|
990 |
+
|
991 |
+
if print_alignment == "hard":
|
992 |
+
self.extract_alignment = utils.extract_hard_alignment
|
993 |
+
elif print_alignment == "soft":
|
994 |
+
self.extract_alignment = utils.extract_soft_alignment
|
995 |
+
|
996 |
+
@torch.no_grad()
|
997 |
+
def generate(self, models, sample, **kwargs):
|
998 |
+
finalized = super()._generate(sample, **kwargs)
|
999 |
+
|
1000 |
+
src_tokens = sample["net_input"]["src_tokens"]
|
1001 |
+
bsz = src_tokens.shape[0]
|
1002 |
+
beam_size = self.beam_size
|
1003 |
+
(
|
1004 |
+
src_tokens,
|
1005 |
+
src_lengths,
|
1006 |
+
prev_output_tokens,
|
1007 |
+
tgt_tokens,
|
1008 |
+
) = self._prepare_batch_for_alignment(sample, finalized)
|
1009 |
+
if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
|
1010 |
+
attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
|
1011 |
+
else:
|
1012 |
+
attn = [
|
1013 |
+
finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
|
1014 |
+
for i in range(bsz * beam_size)
|
1015 |
+
]
|
1016 |
+
|
1017 |
+
if src_tokens.device != "cpu":
|
1018 |
+
src_tokens = src_tokens.to("cpu")
|
1019 |
+
tgt_tokens = tgt_tokens.to("cpu")
|
1020 |
+
attn = [i.to("cpu") for i in attn]
|
1021 |
+
|
1022 |
+
# Process the attn matrix to extract hard alignments.
|
1023 |
+
for i in range(bsz * beam_size):
|
1024 |
+
alignment = self.extract_alignment(
|
1025 |
+
attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
|
1026 |
+
)
|
1027 |
+
finalized[i // beam_size][i % beam_size]["alignment"] = alignment
|
1028 |
+
return finalized
|
1029 |
+
|
1030 |
+
def _prepare_batch_for_alignment(self, sample, hypothesis):
|
1031 |
+
src_tokens = sample["net_input"]["src_tokens"]
|
1032 |
+
bsz = src_tokens.shape[0]
|
1033 |
+
src_tokens = (
|
1034 |
+
src_tokens[:, None, :]
|
1035 |
+
.expand(-1, self.beam_size, -1)
|
1036 |
+
.contiguous()
|
1037 |
+
.view(bsz * self.beam_size, -1)
|
1038 |
+
)
|
1039 |
+
src_lengths = sample["net_input"]["src_lengths"]
|
1040 |
+
src_lengths = (
|
1041 |
+
src_lengths[:, None]
|
1042 |
+
.expand(-1, self.beam_size)
|
1043 |
+
.contiguous()
|
1044 |
+
.view(bsz * self.beam_size)
|
1045 |
+
)
|
1046 |
+
prev_output_tokens = data_utils.collate_tokens(
|
1047 |
+
[beam["tokens"] for example in hypothesis for beam in example],
|
1048 |
+
self.pad,
|
1049 |
+
self.eos,
|
1050 |
+
self.left_pad_target,
|
1051 |
+
move_eos_to_beginning=True,
|
1052 |
+
)
|
1053 |
+
tgt_tokens = data_utils.collate_tokens(
|
1054 |
+
[beam["tokens"] for example in hypothesis for beam in example],
|
1055 |
+
self.pad,
|
1056 |
+
self.eos,
|
1057 |
+
self.left_pad_target,
|
1058 |
+
move_eos_to_beginning=False,
|
1059 |
+
)
|
1060 |
+
return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
|
1061 |
+
|
1062 |
+
|
1063 |
+
class EnsembleModelWithAlignment(EnsembleModel):
|
1064 |
+
"""A wrapper around an ensemble of models."""
|
1065 |
+
|
1066 |
+
def __init__(self, models):
|
1067 |
+
super().__init__(models)
|
1068 |
+
|
1069 |
+
def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
|
1070 |
+
avg_attn = None
|
1071 |
+
for model in self.models:
|
1072 |
+
decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
|
1073 |
+
attn = decoder_out[1]["attn"][0]
|
1074 |
+
if avg_attn is None:
|
1075 |
+
avg_attn = attn
|
1076 |
+
else:
|
1077 |
+
avg_attn.add_(attn)
|
1078 |
+
if len(self.models) > 1:
|
1079 |
+
avg_attn.div_(len(self.models))
|
1080 |
+
return avg_attn
|
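Much of _generate above is index bookkeeping: per-sentence beam indices are mapped onto rows of the flat (bsz * beam_size, ...) token and score buffers through bbsz_offsets and cand_bbsz_idx. The toy example below reproduces only that arithmetic with made-up numbers; it illustrates the indexing scheme rather than reusing code from this file:

import torch

bsz, beam_size = 2, 3

# One row offset per sentence: sentence i owns rows [i * beam_size, (i + 1) * beam_size)
# of the flat (bsz * beam_size, max_len) token/score buffers.
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1)  # shape (bsz, 1)

# Beam indices picked by the search step for 2 * beam_size candidates per sentence.
cand_beams = torch.tensor([[0, 2, 1, 0, 1, 2],
                           [1, 0, 2, 2, 0, 1]])

# Flat row indices into the shared buffers, as computed for cand_bbsz_idx.
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
print(cand_bbsz_idx)
# tensor([[0, 2, 1, 0, 1, 2],
#         [4, 3, 5, 5, 3, 4]])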
artst/tasks/__init__.py
ADDED
File without changes
|
artst/tasks/artst.py
ADDED
@@ -0,0 +1,711 @@
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST

# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import logging
import os.path as op
from argparse import Namespace
from collections import OrderedDict

import torch
from fairseq.data import (
    Dictionary,
    encoders,
    PrependTokenDataset,
    AppendTokenDataset,
    data_utils,
    StripTokenDataset,
    TokenBlockDataset,
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq import utils
from artst.data.multitask_dataset import MultitaskDataset
from artst.data.speech_to_text_dataset import SpeechToTextDataset
from artst.data.text_to_speech_dataset import TextToSpeechDataset
from artst.data.speech_to_speech_dataset import SpeechToSpeechDataset
from artst.data.speech_to_class_dataset import SpeechToClassDataset
from artst.data.speech_dataset import SpeechPretrainDataset
from artst.data.text_dataset import TextPretrainDataset
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.tasks.hubert_pretraining import LabelEncoder

logger = logging.getLogger(__name__)

TASK_NAME = ["s2t", "t2s", "s2s", "s2c", "pretrain"]

@register_task("artst")
class ArTSTTask(LegacyFairseqTask):
    @staticmethod
    def add_args(parser):
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--max-speech-sample-size",
            default=None,
            type=int,
            metavar="N",
            help="max speech sample size",
        )
        parser.add_argument(
            "--min-speech-sample-size",
            default=None,
            type=int,
            metavar="N",
            help="min speech sample size",
        )
        parser.add_argument(
            "--max-speech-positions",
            default=4000,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-text-positions",
            default=450,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )
        parser.add_argument(
            '--t5-task',
            choices=TASK_NAME,
            help='task for training'
        )
        parser.add_argument(
            "--bpe-tokenizer",
            type=str,
            default=None,
            help="bpe tokenizer for s2t",
        )
        # Speaker Identification (SID)
        parser.add_argument(
            "--finetune-from-modules",
            default=None,
            # choices=[
            #     "encoder-decoder", "encoder", "decoder",
            #     "speech_encoder_prenet-encoder-decoder-text_decoder_prenet-text_decoder_postnet",  # ASR, T5 SID
            #     "speech_encoder_prenet-encoder-decoder-text_decoder_prenet-speaker_decoder_postnet",  # SID
            #     "speech_encoder_prenet-encoder-decoder-speech_decoder_prenet-speech_decoder_postnet",  # VC, SE
            #     "text_encoder_prenet-encoder-decoder-speech_decoder_prenet-speech_decoder_postnet",  # TTS
            # ],
            help="If set, using part modules of finetune model.",
        )
        parser.add_argument(
            "--finetune-out-of-modules",
            default=None,
            # choices=[
            #     "speaker_decoder_postnet",  # SID
            #     "speech_decoder_postnet",  # SE with reduction factor 1
            # ],
            help="If set, remove part modules of finetune model.",
        )
        # BART
        parser.add_argument(
            "--shorten-method",
            default="none",
            choices=["none", "truncate", "random_crop"],
            help="if not none, shorten sequences that exceed --tokens-per-sample",
        )
        parser.add_argument(
            "--shorten-data-split-list",
            default="",
            help="comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)',
        )

        parser.add_argument(
            "--tokens-per-sample",
            default=512,
            type=int,
            help="max number of total tokens over all segments"
            " per sample for dataset",
        )
        parser.add_argument(
            "--sample-break-mode",
            default="eos",
            type=str,
            help="mode for breaking sentence",
        )
        parser.add_argument(
            "--mask",
            default=0.3,
            type=float,
            help="fraction of words/subwords that will be masked",
        )
        parser.add_argument(
            "--mask-random",
            default=0.1,
            type=float,
            help="instead of using [MASK], use random token this often",
        )
        parser.add_argument(
            "--insert",
            default=0.0,
            type=float,
            help="insert this percentage of additional random tokens",
        )
        parser.add_argument(
            "--permute",
            default=0.0,
            type=float,
            help="take this proportion of subwords and permute them",
        )
        parser.add_argument(
            "--rotate",
            default=0.0,
            type=float,
            help="rotate this proportion of inputs",
        )
        parser.add_argument(
            "--poisson-lambda",
            default=3.5,
            type=float,
            help="randomly shuffle sentences for this proportion of inputs",
        )
        parser.add_argument(
            "--permute-sentences",
            default=0.0,
            type=float,
            help="shuffle this proportion of sentences in all inputs",
        )
        # parser.add_argument(
        #     "--mask-length",
        #     default="span-poisson",
        #     type=str,
        #     choices=["subword", "word", "span-poisson"],
        #     help="mask length to choose",
        # )
        parser.add_argument(
            "--replace-length",
            default=1,
            type=int,
            help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
        )
        parser.add_argument(
            "--iid-noise-target",
            action="store_true",
            help="whether to use t5 form target",
        )
        # Hubert
        parser.add_argument(
            "--hubert-labels",
            nargs="*",
            type=str,
            default=['km'],
            help="extension of the label files to load, frame-level labels for pre-training, and sequence-level label for fine-tuning",
        )
        parser.add_argument(
            "--hubert-label-dir",
            type=str,
            default=None,
            help="if set, looks for labels in this directory instead",
        )
        parser.add_argument(
            "--sample-rate",
            default=100,
            type=float,
            help="target sample rate. audio files will be up/down sampled to this rate",
        )
        parser.add_argument(
            "--label-rates",
            default=-1,
            type=float,
            help="if set, looks for labels in this directory instead",
        )
        parser.add_argument(
            "--normalize",
            action="store_true",
            help="if set, normalizes input to have 0 mean and unit variance",
        )
        parser.add_argument(
            "--enable-padding",
            action="store_true",
            help="pad shorter samples instead of cropping",
        )
        parser.add_argument(
            "--pad-audio",
            action="store_true",
            help="pad audio to the longest one in the batch if true",
        )
        parser.add_argument(
            "--random-crop",
            action="store_true",
            help="always crop from the beginning if false",
        )
        parser.add_argument(
            "--single-target",
            action="store_true",
            help="if set, AddTargetDatasets outputs same keys "
            "as AddTargetDataset",
        )
        parser.add_argument(
            "--batch-ratio",
            default=None,
            type=str,
            help="ratio of bach size for each dataset",
        )
        parser.add_argument(
            "--sample-ratios",
            default=None,
            type=str,
            help="ratio of sample for each dataset",
        )
        parser.add_argument(
            "--ctc-weight",
            type=float,
            default=0.0,
            help="ctc weight for inference",
        )
        parser.add_argument(
            "--inference-speech",
            type=bool,
            default=False,
            help="inference for TTS",
        )

    def __init__(self, args, dicts, config):
        super().__init__(args)
        self.dicts = dicts
        self.config = config
        self.t5_task = args.t5_task
        # Used for filter size
        if self.t5_task in ['s2t', 't2s', 's2s', 's2c']:
            self.max_pos = [self.args.max_speech_positions * 256]
        elif self.t5_task == 'pretrain':
            self.max_pos = [self.args.max_speech_positions * 256, self.args.max_text_positions]

        self.mask_idx = self.dicts["text"].add_symbol("<mask>")
        # add blank token for ctc
        # if args.ctc_weight > 0:
        self.blank_symbol_idx = self.dicts["text"].add_symbol("<ctc_blank>")
        self.blank_symbol = "<ctc_blank>"

        # add mask token
        if hasattr(args, "iid_noise_target") and args.iid_noise_target:
            self.uni_mask_idxs = []
            for i in range(600):
                self.uni_mask_idxs.append(self.dicts["text"].add_symbol("<mask>" + str(i)))
            self.uni_mask_idxs = torch.tensor(self.uni_mask_idxs)

        self.seed = args.seed

    @classmethod
    def setup_task(cls, args, **kwargs):
        # load dictionaries and config
        dicts = OrderedDict()
        if args.t5_task == 'pretrain' and not hasattr(args, "shuffle_instance"):
            args.shuffle_instance = False

        # Prepare config
        config = None
        logger.info('No config file for ' + args.t5_task)

        if args.t5_task == "pretrain":
            dicts["hubert"] = [Dictionary.load(f"{args.hubert_label_dir}/dict.{label}.txt") for label in args.hubert_labels]
            dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
        else:
            if config is None:
                dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
            else:
                dicts["text"] = Dictionary.load(op.join(args.data, config.vocab_filename))

        return cls(args, dicts, config)

    def build_criterion(self, args):
        from fairseq import criterions
        return criterions.build_criterion(args, self)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        sample_ratios = []
        if self.t5_task == "s2t":
            ## For speech to text task
            bpe_tokenizer = self.build_bpe(self.args)
            manifest = f"{self.args.data}/{split}.tsv"
            procs = [LabelEncoder(self.dicts["text"])]
            paths = [f"{self.args.hubert_label_dir}/{split}.txt"]
            # Hawau: view dataset...
            logger.info(f"Manifest: {manifest}")
            # logger.info(f"Paths: {paths}")
            self.datasets[split] = SpeechToTextDataset(
                manifest,
                sample_rate=self.args.sample_rate,
                label_paths=paths,
                label_processors=procs,
                max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size,
                min_keep_sample_size=self.args.min_speech_sample_size,
                normalize=self.args.normalize,
                store_labels=False,
                tgt_dict=self.dicts["text"],
                tokenizer=bpe_tokenizer,
            )
        elif self.t5_task == "t2s":
            ## For text to speech task
            from fairseq.data import ConcatDataset
            bpe_tokenizer = self.build_bpe(self.args)
            procs = [LabelEncoder(self.dicts["text"])]
            t2s_datasets = [
                TextToSpeechDataset(
                    manifest_path=f"{self.args.data}/{name}.tsv",
                    sample_rate=self.args.sample_rate,
                    label_paths=[f"{self.args.hubert_label_dir}/{name}.txt"],
                    label_processors=procs,
                    max_keep_sample_size=self.max_pos[0],
                    normalize=self.args.normalize,
                    store_labels=False,
                    src_dict=self.dicts["text"],
                    tokenizer=bpe_tokenizer,
                    reduction_factor=self.args.reduction_factor,
                    inference=self.args.inference_speech,
                )
                for name in split.split(",")
            ]
            self.datasets[split] = ConcatDataset(t2s_datasets) if len(t2s_datasets) > 1 else t2s_datasets[0]
        elif self.t5_task == "s2s":
            manifest = f"{self.args.data}/{split}.tsv"
            self.datasets[split] = SpeechToSpeechDataset(
                manifest_path=manifest,
                sample_rate=self.args.sample_rate,
                max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size,
                min_keep_sample_size=self.args.min_speech_sample_size,
                normalize=self.args.normalize,
                reduction_factor=self.args.reduction_factor,
            )
        elif self.t5_task == "s2c":
            is_train_split = ("train" in split)
            is_valid_split = ("valid" in split)
            if is_train_split:
                max_length = 51200
            elif is_valid_split:
                max_length = 76800
            else:
                max_length = 2560000
            manifest = op.join(f"{self.args.data}", f"{split}.tsv")
            procs = LabelEncoder(self.dicts["text"])  # map speaker to id
            self.datasets[split] = SpeechToClassDataset(
                manifest_path=manifest,
                sample_rate=self.args.sample_rate,
                label_processors=procs,
                max_keep_sample_size=self.max_pos[0] if self.args.max_speech_sample_size is None else self.args.max_speech_sample_size,
                min_keep_sample_size=self.args.min_speech_sample_size,
                normalize=self.args.normalize,
                tgt_dict=self.dicts["text"],
                max_length=max_length
            )
        elif self.t5_task == "pretrain":
            is_train_split = ("train" in split)
            pretrain_datasets = []
            speech_split, text_split = split.split('|')

            ## Speech pre-train
            manifest = f"{self.args.data}/{speech_split}.tsv"
            dicts = self.dicts["hubert"]
            pad_list = [dict.pad() for dict in dicts]
            eos_list = [dict.eos() for dict in dicts]
            procs = [LabelEncoder(dict) for dict in dicts]
            paths = [
                f"{self.args.hubert_label_dir}/{speech_split}.{l}" for l in self.args.hubert_labels
            ]
            # hubert v1: pad_audio=True, random_crop=False;
            self.args.dec_weight = getattr(self.args, "dec_weight", 1.0)
            pretrain_datasets.append(
                SpeechPretrainDataset(
                    manifest,
                    sample_rate=self.args.sample_rate,
                    label_paths=paths,
                    label_rates=self.args.label_rates,
                    pad_list=pad_list,
                    eos_list=eos_list,
                    label_processors=procs,
                    max_keep_sample_size=None,
                    min_keep_sample_size=32000,
                    max_sample_size=self.args.max_speech_sample_size,
                    pad_audio=self.args.pad_audio,
                    normalize=self.args.normalize,
                    store_labels=False,
                    random_crop=self.args.random_crop,
                    single_target=self.args.single_target,
                    reduction_factor=self.args.reduction_factor,
                )
            )
            sample_ratios.append(sum([pretrain_datasets[0].size(i) for i in range(len(pretrain_datasets[0]))]))

            ## Text pre-train
            paths = utils.split_paths(self.args.data)
            assert len(paths) > 0
            data_path = paths[(epoch - 1) % len(paths)]
            print(f"Loading {text_split} from data_path={data_path}")
            split_path = op.join(data_path, text_split)
            print(f"split_path={split_path}")
            bart_dataset = data_utils.load_indexed_dataset(
                split_path,
                self.dicts["text"],
                self.args.dataset_impl,
                combine=combine,
            )
            if bart_dataset is None:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(text_split, split_path)
                )
            bart_dataset = StripTokenDataset(bart_dataset, self.dicts["text"].eos())
            bart_dataset = maybe_shorten_dataset(
                bart_dataset,
                text_split,
                self.args.shorten_data_split_list,
                self.args.shorten_method,
                self.args.tokens_per_sample,
                self.args.seed,
            )
            # create continuous blocks of tokens
            bart_dataset = TokenBlockDataset(
                bart_dataset,
                bart_dataset.sizes,
                self.args.tokens_per_sample - 2,  # one less for <s> and one for </s>
                pad=self.dicts["text"].pad(),
                eos=self.dicts["text"].eos(),
                break_mode=self.args.sample_break_mode,
                document_sep_len=0,
            )
            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            bart_dataset = PrependTokenDataset(bart_dataset, self.dicts["text"].bos())
            bart_dataset = AppendTokenDataset(bart_dataset, self.dicts["text"].eos())
            mask_whole_words = (
                get_whole_word_mask(self.args, self.dicts["text"])
                if self.args.mask_length != "subword"
                else None
            )
            self.args.bert_weight = getattr(self.args, "bert_weight", 0.0)
            pretrain_datasets.append(
                TextPretrainDataset(
                    bart_dataset,
                    bart_dataset.sizes,
                    self.dicts["text"],
                    self.mask_idx,
                    mask_whole_words,
                    shuffle=self.args.shuffle_instance,
                    seed=self.seed,
                    args=self.args,
                    iid_noise_target=self.args.iid_noise_target,
                    uni_mask_idxs=self.uni_mask_idxs if self.args.iid_noise_target else None,
                )
            )
            sample_ratios.append(sum(pretrain_datasets[1].sizes))
            logger.info(
                "Task: {0}, Loaded {1} samples of denoising_dataset".format(
                    'bart',
                    len(pretrain_datasets[1]),
                )
            )

            logger.info('token ratio is ' + str(sample_ratios))
            if self.args.batch_ratio is not None:
                batch_ratio = eval(self.args.batch_ratio)
                assert len(batch_ratio) == len(sample_ratios)
                sample_ratios = [sample_ratios[i] / batch_ratio[i] for i in range(len(sample_ratios))]
            else:
                batch_ratio = None
                max_size = max(sample_ratios)
                sample_ratios = [max_size / r for r in sample_ratios]
            if hasattr(self.args, "sample_ratios") and self.args.sample_ratios is not None:
                sample_ratios = eval(self.args.sample_ratios)
            if is_train_split:
                self.datasets[split] = MultitaskDataset(
                    pretrain_datasets, sample_ratios, batch_ratio
                )
            else:
                self.datasets[split] = MultitaskDataset(
                    pretrain_datasets, batch_ratio=batch_ratio
                )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        model.set_num_updates(update_num)

        # Junyi: not use sample_size, but normalize the loss locally
        agg_loss, agg_sample_size, agg_logging_output = 0.0, 1.0, {}
        agg_logging_output['sample_size'] = 1

        def forward_backward(model, samples, weight=1.0):
            nonlocal agg_loss, agg_logging_output
            if samples is None or len(samples) == 0:
                return
            loss, sample_size, logging_output = criterion(model, samples)
            if ignore_grad:
                loss *= 0
            else:
                loss *= weight
            loss = loss / sample_size
            optimizer.backward(loss)
            agg_loss += loss.detach().item()
            # # TODO make summing of the sample sizes configurable
            for k in logging_output:
                if k == 'ntokens' or k == 'nsentences':
                    if k not in agg_logging_output:
                        agg_logging_output[k] = 0
                    agg_logging_output[k] += logging_output[k]
                    # continue
                # agg_logging_output[k] += logging_output[k]
            # agg_logging_output[task_name] += logging_output[k]
            agg_logging_output[samples['task_name']] = logging_output

        forward_backward(model, sample)

        agg_logging_output["loss"] = agg_loss

        return agg_loss, agg_sample_size, agg_logging_output

    def valid_step(self, sample, model, criterion):
        model.eval()
        with torch.no_grad():
            from collections import defaultdict

            agg_loss, agg_sample_size, agg_logging_output = 0.0, 1.0, defaultdict(float)
            agg_logging_output['sample_size'] = 1
            loss, sample_size, logging_output = criterion(model, sample)
            loss = loss / sample_size
            # agg_loss += loss.data.item() if isinstance(loss, torch.Tensor) else loss
            agg_loss += loss.item() if isinstance(loss, torch.Tensor) else loss
            agg_logging_output[sample['task_name']] = logging_output
            agg_logging_output["loss"] = agg_loss
        return agg_loss, agg_sample_size, agg_logging_output

    @property
    def target_dictionary(self):
        return self.dicts["text"]

    @property
    def source_dictionary(self):
        return None

    def build_model(self, args):
        try:
            args.input_feat_per_channel = self.config.input_feat_per_channel
            args.input_channels = self.config.input_channels
        except Exception as e:
            args.input_feat_per_channel = 80
            args.input_channels = 1
            logger.info(f"Cannot set input_feat_per_channel, input_channels, since: ")
            logger.warn(e)
            logger.info(f"Set to: {args.input_feat_per_channel} and {args.input_channels}")

        args.speech_odim = args.input_feat_per_channel * args.input_channels

        args.label_rates = self.args.label_rates
        args.sample_rate = self.args.sample_rate
        self.args.reduction_factor = args.reduction_factor
        return super(ArTSTTask, self).build_model(args)

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        from artst.sequence_generator import SequenceGenerator
        extra_gen_cls_kwargs = {
            "ctc_weight": self.args.ctc_weight,
            **extra_gen_cls_kwargs
        }
        return super().build_generator(
            models, args, seq_gen_cls=SequenceGenerator, extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )

    def build_tokenizer(self, args):
        if self.config is None:
            logger.info(f"pre-tokenizer: None")
            return encoders.build_tokenizer(Namespace(**{"tokenizer": None}))
        else:
            logger.info(f"pre-tokenizer: {self.config.pre_tokenizer}")
            return encoders.build_tokenizer(Namespace(**self.config.pre_tokenizer))

    def build_bpe(self, args):
        if self.config is not None:
            logger.info(f"tokenizer: {self.config.bpe_tokenizer}")
            return encoders.build_bpe(Namespace(**self.config.bpe_tokenizer))
        else:
            logger.info(f"tokenizer: {self.args.bpe_tokenizer}")
            return encoders.build_bpe(Namespace(**{"bpe": "sentencepiece", "sentencepiece_model": self.args.bpe_tokenizer}))

    def generate_class(self, models, net_input, prefix_tokens, **kwargs):
        with torch.no_grad():
            encoder_input = {
                k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
            }
            encoder_input.update(kwargs)
            encoder_input.update({"prev_output_tokens": prefix_tokens})
            return models[0].generate_class(**encoder_input)

    def generate_speech(self, models, net_input, **kwargs):
        with torch.no_grad():
            encoder_input = {
                k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
            }
            encoder_input.update(kwargs)
            return models[0].generate_speech(**encoder_input)

    def inference_t2s(
        self, models, sample
    ):
        with torch.no_grad():
            xs = sample['net_input']['src_tokens']
            spkemb = sample['net_input']['spkembs']
            return models[0].inference(xs, spkemb)

    def inference_s2s(
        self, models, sample, force_equal_length=False
    ):
        with torch.no_grad():
            x = sample['net_input']['src_tokens']
            xlen = sample['net_input']['src_lengths']
            spkemb = sample['net_input']['spkembs']
            prev_output_tokens = sample['net_input']['prev_output_tokens']
            padding_mask = sample['net_input']['padding_mask']
            tgt_lengths = sample['net_input']['tgt_lengths']
            return models[0].inference_s2s(x, xlen, spkemb, prev_output_tokens, tgt_lengths, force_equal_length=force_equal_length, padding_mask=padding_mask)

    def inference_s2c(
        self, models, sample
    ):
        with torch.no_grad():
            x = sample['net_input']['src_tokens']
            xlen = sample['net_input']['src_lengths']
            prev_output_tokens = sample['net_input']['prev_output_tokens']
            padding_mask = sample['net_input']['padding_mask']
            assert prev_output_tokens.size(1) == 1, prev_output_tokens.size()
            return models[0].inference_s2c(x, xlen, prev_output_tokens, padding_mask=padding_mask)

    def filter_indices_by_size(
        self, indices, dataset, max_positions=None, ignore_invalid_inputs=False
    ):
        """
        Filter examples that are too large

        Args:
            indices (np.array): original array of sample indices
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
        Returns:
            np.array: array of filtered sample indices
        """

        indices, ignored = dataset.filter_indices_by_size(
            indices,
            self.max_pos
        )
        return indices
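The registered "artst" task above is normally driven by fairseq's training/inference entry points, but it can also be instantiated directly. A minimal sketch under stated assumptions: the Namespace fields mirror the flags defined in add_args, "utils" is used as the manifest root only because this Space ships utils/dict.txt there, and whether app.py does exactly this is not claimed.

# Sketch only: build the ArTSTTask programmatically for the s2t sub-task.
from argparse import Namespace
from artst.tasks.artst import ArTSTTask

args = Namespace(
    data="utils",                        # must contain dict.txt (and {split}.tsv for load_dataset)
    t5_task="s2t",
    hubert_label_dir="/path/to/labels",  # placeholder: {split}.txt transcripts for s2t
    hubert_labels=["km"],
    bpe_tokenizer="utils/arabic.model",
    sample_rate=16000,
    max_speech_positions=4000,
    max_text_positions=450,
    max_speech_sample_size=None,
    min_speech_sample_size=None,
    normalize=False,
    ctc_weight=0.0,
    seed=1,
)
task = ArTSTTask.setup_task(args)
print(task.target_dictionary)  # character-level Dictionary loaded from utils/dict.txt
# task.load_dataset("train") would then read utils/train.tsv plus the label file.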
ckpts/mgb2_asr.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9aaf1c09f146aff9361e999885c861f8ab0a6b0c997ceb34682ad3366849cc91
size 1847116641
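The file above is a Git LFS pointer for the ~1.8 GB MGB-2 ASR checkpoint. A hedged sketch of how such a fairseq checkpoint can be loaded together with the task; the arg_overrides values are illustrative assumptions, not taken from app.py.

# Sketch only: load the checkpoint through fairseq's checkpoint utilities.
from argparse import Namespace
from fairseq import checkpoint_utils, utils

# Register the custom task/model code shipped in this repo (path assumed
# to be the local "artst" package directory).
utils.import_user_module(Namespace(user_dir="artst"))

models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    ["ckpts/mgb2_asr.pt"],
    arg_overrides={"data": "utils", "bpe_tokenizer": "utils/arabic.model"},  # assumed overrides
)
model = models[0].eval()
print(type(task).__name__)  # expected: ArTSTTask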
pre-requirements.txt
ADDED
@@ -0,0 +1,34 @@
cython==0.29.35
fairseq==0.12.2
datasets==2.12.0
editdistance==0.6.2
espnet==202304
espnet-tts-frontend==0.0.3
librosa==0.9.2
omegaconf==2.0.6
pandas==2.0.1
PyArabic==0.6.15
scipy
soundfile
tqdm==4.65.0
tweepy==4.14.0
tensorboard
kaldiio==2.18.0
numpy==1.23.5
cmake==3.26.4
pillow==10.0.0
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
tensorboardx==2.6
transformers
speechbrain
numpy==1.23.5
samples/sample_audio.wav
ADDED
Binary file (97.9 kB)
utils/arabic.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:17dd0feab129e71329f98fe9efd6143be4f6e1aede2408fe6c586f803f7d6cc0
size 539226
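Given that the task's build_bpe builds a sentencepiece encoder from --bpe-tokenizer, utils/arabic.model is presumably a SentencePiece model. A quick inspection sketch, assuming the sentencepiece package is available; nothing here is specific to this repo beyond the file path.

# Sketch only: inspect the (assumed) SentencePiece model shipped under utils/.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("utils/arabic.model")
print(sp.GetPieceSize())            # vocabulary size of the model
print(sp.EncodeAsPieces("مرحبا"))   # subword pieces for a sample Arabic word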
utils/audios.tsv
ADDED
@@ -0,0 +1,2 @@
./
/tmp/gradio/8f338949403b3fbd48ea5459ce1a148e9ebf96ee/sample_audio.wav/t30000
utils/dict.txt
ADDED
@@ -0,0 +1,92 @@
▁ 7888394
ا 5240837
ل 4086510
ي 2961895
م 2193275
ن 2167596
و 1852812
ت 1552697
ر 1543221
ع 1224967
ب 1116473
ه 1088784
د 1060960
ة 1037790
أ 940870
س 935180
ف 834969
ك 831374
ق 807989
ح 704969
ج 436806
ذ 368254
ط 344915
إ 327964
ش 314812
ى 307987
ص 290514
خ 277022
ض 252120
ث 197381
ز 166103
ئ 137402
ً 122431
غ 113416
ء 111245
ظ 98346
ُ 75189
آ 58959
ؤ 57561
ّ 39925
0 27260
ٍ 24728
َ 22922
ِ 21587
1 18933
2 15936
ٌ 10327
5 6938
9 6502
6 4829
7 4380
8 4041
e 3578
a 3330
% 2892
ـ 2768
t 2746
i 2683
n 2462
o 2390
r 2341
s 2073
c 1561
l 1504
m 1141
d 1018
p 988
u 920
h 888
g 777
b 762
f 603
y 562
k 515
w 470
v 293
j 254
z 193
x 105
@ 97
3 52
4 48
q 30
̇ 6
ٱ 2
⁄ 1
madeupword0000 0
madeupword0001 0
madeupword0002 0
madeupword0003 0
madeupword0004 0
madeupword0005 0
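Each line of utils/dict.txt is a "<symbol> <count>" pair in fairseq's dictionary format; fairseq prepends its four special symbols (<s>, <pad>, </s>, <unk>) when the file is loaded, and the ArTSTTask constructor further adds <mask> and <ctc_blank>. A small sketch, assuming fairseq is installed:

# Sketch only: load the character-level dictionary shipped with this Space.
from fairseq.data import Dictionary

d = Dictionary.load("utils/dict.txt")
print(len(d))        # 92 file entries + 4 special symbols = 96
print(d[4], d[5])    # first two file entries: "▁" and "ا"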