Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 -u | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Translate pre-processed data with a trained model. | |
""" | |
import ast | |
import logging | |
import argparse | |
import math | |
import os | |
import sys | |
from argparse import Namespace | |
from itertools import chain | |
import numpy as np | |
import torch | |
from omegaconf import DictConfig | |
from fairseq import checkpoint_utils, options, scoring, tasks, utils | |
from fairseq.dataclass.utils import convert_namespace_to_omegaconf | |
from fairseq.logging import progress_bar | |
from fairseq.logging.meters import StopwatchMeter, TimeMeter | |
import os | |
import torch | |
import gradio as gr | |
import numpy as np | |
import os.path as op | |
import pyarabic.araby as araby | |
import subprocess | |
import soundfile as sf | |
from artst.tasks.artst import ArTSTTask | |
from artst.models.artst import ArTSTTransformerModel | |
from fairseq.tasks.hubert_pretraining import LabelEncoder | |
from fairseq import checkpoint_utils, options, scoring, tasks, utils | |
from loguru import logger | |
from fairseq.logging.meters import StopwatchMeter, TimeMeter | |
def postprocess(wav, cur_sample_rate): | |
if wav.dim() == 2: | |
wav = wav.mean(-1) | |
assert wav.dim() == 1, wav.dim() | |
if cur_sample_rate != 16000: | |
raise Exception(f"sr {cur_sample_rate} != {16000}") | |
return wav | |
def main(cfg: DictConfig, audio_path): | |
print('config') | |
print(cfg) | |
if isinstance(cfg, Namespace): | |
cfg = convert_namespace_to_omegaconf(cfg) | |
assert cfg.common_eval.path is not None, "--path required for generation!" | |
assert ( | |
not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam | |
), "--sampling requires --nbest to be equal to --beam" | |
assert ( | |
cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw" | |
), "--replace-unk requires a raw text dataset (--dataset-impl=raw)" | |
if cfg.common_eval.results_path is not None: | |
os.makedirs(cfg.common_eval.results_path, exist_ok=True) | |
output_path = os.path.join( | |
cfg.common_eval.results_path, | |
"generate-{}.txt".format(cfg.dataset.gen_subset), | |
) | |
with open(output_path, "w", buffering=1, encoding="utf-8") as h: | |
return _main(cfg, h) | |
else: | |
return _main(cfg, sys.stdout, audio_path) | |
def get_symbols_to_strip_from_output(generator): | |
if hasattr(generator, "symbols_to_strip_from_output"): | |
return generator.symbols_to_strip_from_output | |
else: | |
return {generator.eos} | |
def _main(cfg: DictConfig, output_file, audio_path): | |
logging.basicConfig( | |
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", | |
datefmt="%Y-%m-%d %H:%M:%S", | |
level=os.environ.get("LOGLEVEL", "INFO").upper(), | |
stream=output_file, | |
) | |
logger = logging.getLogger("fairseq_cli.generate") | |
utils.import_user_module(cfg.common) | |
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: | |
cfg.dataset.max_tokens = 12000 | |
logger.info(cfg) | |
# Fix seed for stochastic decoding | |
if cfg.common.seed is not None and not cfg.generation.no_seed_provided: | |
np.random.seed(cfg.common.seed) | |
utils.set_torch_seed(cfg.common.seed) | |
use_cuda = torch.cuda.is_available() and not cfg.common.cpu | |
# Load dataset splits | |
task = tasks.setup_task(cfg.task) | |
# Set dictionaries | |
try: | |
src_dict = getattr(task, "source_dictionary", None) | |
except NotImplementedError: | |
src_dict = None | |
tgt_dict = task.target_dictionary | |
overrides = ast.literal_eval(cfg.common_eval.model_overrides) | |
# Load ensemble | |
logger.info("loading model(s) from {}".format(cfg.common_eval.path)) | |
models, saved_cfg = checkpoint_utils.load_model_ensemble( | |
utils.split_paths(cfg.common_eval.path), | |
arg_overrides=overrides, | |
task=task, | |
suffix=cfg.checkpoint.checkpoint_suffix, | |
strict=(cfg.checkpoint.checkpoint_shard_count == 1), | |
num_shards=cfg.checkpoint.checkpoint_shard_count, | |
) | |
# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config | |
# task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task) | |
if cfg.generation.lm_path is not None: | |
overrides["data"] = cfg.task.data | |
try: | |
lms, _ = checkpoint_utils.load_model_ensemble( | |
[cfg.generation.lm_path], arg_overrides=overrides, task=None | |
) | |
except: | |
logger.warning( | |
f"Failed to load language model! Please make sure that the language model dict is the same " | |
f"as target dict and is located in the data dir ({cfg.task.data})" | |
) | |
raise | |
assert len(lms) == 1 | |
else: | |
lms = [None] | |
# Optimize ensemble for generation | |
for model in chain(models, lms): | |
if model is None: | |
continue | |
if cfg.common.fp16: | |
model.half() | |
if use_cuda and not cfg.distributed_training.pipeline_model_parallel: | |
model.cuda() | |
model.prepare_for_inference_(cfg) | |
# Load alignment dictionary for unknown word replacement | |
# (None if no unknown word replacement, empty if no path to align dictionary) | |
align_dict = utils.load_align_dict(cfg.generation.replace_unk) | |
# Initialize generator | |
gen_timer = StopwatchMeter() | |
extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight} | |
generator = task.build_generator( | |
models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs | |
) | |
# Handle tokenization and BPE | |
tokenizer = task.build_tokenizer(cfg.tokenizer) | |
bpe = task.build_bpe(cfg.bpe) | |
def decode_fn(x): | |
if bpe is not None: | |
x = bpe.decode(x) | |
if tokenizer is not None: | |
x = tokenizer.decode(x) | |
return x | |
scorer = scoring.build_scorer(cfg.scoring, tgt_dict) | |
num_sentences = 0 | |
has_target = True | |
wps_meter = TimeMeter() | |
wav, cur_sample_rate = sf.read(audio_path) | |
wav = torch.from_numpy(wav).float() | |
wav = postprocess(wav, cur_sample_rate) | |
sample = {'index': 0, 'net_input': {'source': torch.tensor(wav).unsqueeze(dim=0), 'padding_mask': | |
torch.BoolTensor(wav.shape).fill_(False).unsqueeze(dim=0)}, 'id': [0], 'target': [[None], ]} | |
prefix_tokens = None | |
if cfg.generation.prefix_size > 0: | |
prefix_tokens = sample["target"][:, : cfg.generation.prefix_size] | |
constraints = None | |
if "constraints" in sample: | |
constraints = sample["constraints"] | |
gen_timer.start() | |
hypos = task.inference_step( | |
generator, | |
models, | |
sample, | |
prefix_tokens=prefix_tokens, | |
constraints=constraints, | |
) | |
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) | |
gen_timer.stop(num_generated_tokens) | |
for i, sample_id in enumerate(sample["id"]): | |
has_target = False | |
# Remove padding | |
if "src_tokens" in sample["net_input"]: | |
src_tokens = utils.strip_pad( | |
sample["net_input"]["src_tokens"][i, :], tgt_dict.pad() | |
) | |
else: | |
src_tokens = None | |
target_tokens = None | |
if has_target: | |
target_tokens = ( | |
utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu() | |
) | |
# Either retrieve the original sentences or regenerate them from tokens. | |
if align_dict is not None: | |
src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text( | |
sample_id | |
) | |
target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text( | |
sample_id | |
) | |
else: | |
if src_dict is not None: | |
src_str = src_dict.string(src_tokens, cfg.common_eval.post_process) | |
else: | |
src_str = "" | |
if has_target: | |
target_str = tgt_dict.string( | |
target_tokens, | |
cfg.common_eval.post_process, | |
escape_unk=True, | |
extra_symbols_to_ignore=get_symbols_to_strip_from_output( | |
generator | |
), | |
) | |
src_str = decode_fn(src_str) | |
if has_target: | |
target_str = decode_fn(target_str) | |
if not cfg.common_eval.quiet: | |
if src_dict is not None: | |
print("S-{}\t{}".format(sample_id, src_str), file=output_file) | |
if has_target: | |
print("T-{}\t{}".format(sample_id, target_str), file=output_file) | |
# Process top predictions | |
for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]): | |
hypo_tokens, hypo_str, alignment = utils.post_process_prediction( | |
hypo_tokens=hypo["tokens"].int().cpu(), | |
src_str=src_str, | |
alignment=hypo["alignment"], | |
align_dict=align_dict, | |
tgt_dict=tgt_dict, | |
remove_bpe=cfg.common_eval.post_process, | |
extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), | |
) | |
detok_hypo_str = decode_fn(hypo_str) | |
if not cfg.common_eval.quiet: | |
score = hypo["score"] / math.log(2) # convert to base 2 | |
# original hypothesis (after tokenization and BPE) | |
print( | |
"H-{}\t{}\t{}".format(sample_id, score, hypo_str), | |
file=output_file, | |
) | |
# detokenized hypothesis | |
print( | |
"D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str), | |
file=output_file, | |
) | |
print( | |
"P-{}\t{}".format( | |
sample_id, | |
" ".join( | |
map( | |
lambda x: "{:.4f}".format(x), | |
# convert from base e to base 2 | |
hypo["positional_scores"] | |
.div_(math.log(2)) | |
.tolist(), | |
) | |
), | |
), | |
file=output_file, | |
) | |
if cfg.generation.print_alignment == "hard": | |
print( | |
"A-{}\t{}".format( | |
sample_id, | |
" ".join( | |
[ | |
"{}-{}".format(src_idx, tgt_idx) | |
for src_idx, tgt_idx in alignment | |
] | |
), | |
), | |
file=output_file, | |
) | |
if cfg.generation.print_alignment == "soft": | |
print( | |
"A-{}\t{}".format( | |
sample_id, | |
" ".join( | |
[",".join(src_probs) for src_probs in alignment] | |
), | |
), | |
file=output_file, | |
) | |
if cfg.generation.print_step: | |
print( | |
"I-{}\t{}".format(sample_id, hypo["steps"]), | |
file=output_file, | |
) | |
if cfg.generation.retain_iter_history: | |
for step, h in enumerate(hypo["history"]): | |
_, h_str, _ = utils.post_process_prediction( | |
hypo_tokens=h["tokens"].int().cpu(), | |
src_str=src_str, | |
alignment=None, | |
align_dict=None, | |
tgt_dict=tgt_dict, | |
remove_bpe=None, | |
) | |
print( | |
"E-{}_{}\t{}".format(sample_id, step, h_str), | |
file=output_file, | |
) | |
# Score only the top hypothesis | |
if has_target and j == 0: | |
if ( | |
align_dict is not None | |
or cfg.common_eval.post_process is not None | |
): | |
# Convert back to tokens for evaluation with unk replacement and/or without BPE | |
target_tokens = tgt_dict.encode_line( | |
target_str, add_if_not_exist=True | |
) | |
hypo_tokens = tgt_dict.encode_line( | |
detok_hypo_str, add_if_not_exist=True | |
) | |
if hasattr(scorer, "add_string"): | |
scorer.add_string(target_str, detok_hypo_str) | |
else: | |
scorer.add(target_tokens, hypo_tokens) | |
wps_meter.update(num_generated_tokens) | |
# progress.log({"wps": round(wps_meter.avg)}) | |
logger.info("NOTE: hypothesis and token scores are output in base 2") | |
if has_target: | |
if cfg.bpe and not cfg.generation.sacrebleu: | |
if cfg.common_eval.post_process: | |
logger.warning( | |
"BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization" | |
) | |
else: | |
logger.warning( | |
"If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization" | |
) | |
# use print to be consistent with other main outputs: S-, H-, T-, D- and so on | |
print( | |
"Generate {} with beam={}: {}".format( | |
cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string() | |
), | |
file=output_file, | |
) | |
return detok_hypo_str | |
def inference(audio_path): | |
# parser = options.get_generation_parser() | |
# TODO: replace this workaround with refactoring of `AudioPretraining` | |
parser = argparse.ArgumentParser(description='Process some integers.') | |
parser.add_argument( | |
"--arch", | |
"-a", | |
metavar="ARCH", | |
default="wav2vec2", | |
help="Model architecture. For constructing tasks that rely on " | |
"model args (e.g. `AudioPretraining`)", | |
) | |
parser.add_argument('--data', type=str, default='./utils', metavar='data') | |
parser.add_argument('--bpe-tokenizer', type=str, default='./utils/arabic.model') | |
parser.add_argument('--user-dir', type=str, default='./artst/') | |
parser.add_argument('--task', type=str, default='artst') | |
parser.add_argument('--t5-task', type=str, default='s2t') | |
parser.add_argument('--path', type=str, default='./ckpts/mgb2_asr.pt') | |
parser.add_argument('--ctc-weight', type=float, default=0.25) | |
parser.add_argument('--max-tokens', type=int, default=350000) | |
parser.add_argument('--beam', type=int, default=5) | |
parser.add_argument('--scoring', type=str, default='wer') | |
parser.add_argument('--max-len-a', type=float, default=0) | |
parser.add_argument('--max-len-b', type=int, default=1000) | |
parser.add_argument('--sample-rate', type=int, default=16000) | |
parser.add_argument('--batch-size', type=int, default=1) | |
# parser.add_argument('--num-workers', type=int, default=4) | |
parser.add_argument('--seed', type=int, default=4) | |
parser.add_argument('--normalize', type=bool, default=True) | |
args = parser.parse_args() | |
return main(args, audio_path=audio_path) | |
text_box = gr.Textbox(label="Arabic Text") | |
input_audio = gr.Audio(label="Upload Audio", type="filepath", sources="upload") | |
title="ArTST: Arabic Speech Recognition" | |
description="ArTST: Arabic text and speech transformer based on the T5 transformer. This space demonstarates the ASR checkpoint finetuned on \ | |
the MGB-2 dataset. The model is pre-trained on the MGB-2 dataset." | |
examples=["samples/sample_audio.wav"] | |
article = """ | |
<div style='margin:20px auto;'> | |
<p>References: <a href="https://arxiv.org/abs/2310.16621">ArTST paper</a> | | |
<a href="https://github.com/mbzuai-nlp/ArTST">GitHub</a> | | |
<a href="https://huggingface.co/MBZUAI/ArTST">Weights and Tokenizer</a></p> | |
<pre> | |
@misc{toyin2023artst, | |
title={ArTST: Arabic Text and Speech Transformer}, | |
author={Hawau Olamide Toyin and Amirbek Djanibekov and Ajinkya Kulkarni and Hanan Aldarmaki}, | |
year={2023}, | |
eprint={2310.16621}, | |
archivePrefix={arXiv}, | |
primaryClass={cs.CL} | |
} | |
</pre> | |
<p>Speaker embeddings were generated from <a href="http://www.festvox.org/cmu_arctic/">CMU ARCTIC</a>.</p> | |
<p>ArTST is based on <a href="https://arxiv.org/abs/2110.07205">SpeechT5 architecture</a>.</p> | |
</div> | |
""" | |
demo = gr.Interface(inference, \ | |
inputs=input_audio, outputs=text_box, title=title, description=description, examples=examples, article=article) | |
if __name__ == "__main__": | |
demo.launch(share=True) | |