# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------

"""
Modified from: https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/fairseq_cli/generate.py
"""

import ast
import logging
import math
import os
import sys
from argparse import Namespace
from itertools import chain

import numpy as np
import torch
from omegaconf import DictConfig

from fairseq import checkpoint_utils, options, scoring, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter


def main(cfg: DictConfig):
    """Validate the generation config and run generation.

    Args:
        cfg: generation configuration. An ``argparse.Namespace`` is accepted
            too and is converted with ``convert_namespace_to_omegaconf``.

    Returns:
        Whatever ``_main`` returns (the scorer built for this run).

    Output goes to ``<results_path>/generate-<gen_subset>.txt`` when
    ``cfg.common_eval.results_path`` is set, otherwise to ``sys.stdout``.
    """
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    assert cfg.common_eval.path is not None, "--path required for generation!"
    assert (
        not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
    ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"

    if cfg.common_eval.results_path is not None:
        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
        output_path = os.path.join(
            cfg.common_eval.results_path,
            "generate-{}.txt".format(cfg.dataset.gen_subset),
        )
        # Line-buffered so results appear in the file as they are produced.
        with open(output_path, "w", buffering=1, encoding="utf-8") as h:
            return _main(cfg, h)
    else:
        return _main(cfg, sys.stdout)


def get_symbols_to_strip_from_output(generator):
    """Return the symbols that should be stripped from ``generator`` output.

    Uses ``generator.symbols_to_strip_from_output`` when the generator defines
    it, and falls back to a set containing only ``generator.eos``.
    """
    if hasattr(generator, "symbols_to_strip_from_output"):
        return generator.symbols_to_strip_from_output
    else:
        return {generator.eos}


def _main(cfg: DictConfig, output_file):
    """Load the model(s) and dataset, generate units, and write the results.

    For every sample the following tab-separated lines are written to
    ``output_file`` (unless ``cfg.common_eval.quiet`` is set):

    * ``S-<id>``: source string (only when the task has a source dictionary)
    * ``T-<id>``: target, as space-separated token ids (only with a target)
    * ``D-<id>``: hypothesis units (``hypo["unit"]``), space-separated
    * ``L-<id>``: predicted durations (``hypo["duration"]``), space-separated
    * ``U-<id>``: force-aligned upsampled source tokens (``hypo["fa_src"]``)

    Args:
        cfg: generation configuration.
        output_file: writable text stream; it receives both the result lines
            and the log records (it is passed to ``logging.basicConfig``).

    Returns:
        The scorer built by ``scoring.build_scorer``.
    """
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(cfg.common)

    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Load dataset splits
    task = tasks.setup_task(cfg.task)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # loading the dataset should happen after the checkpoint has been loaded
    # so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    if cfg.generation.lm_path is not None:
        overrides["data"] = cfg.task.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                [cfg.generation.lm_path], arg_overrides=overrides, task=None
            )
        except Exception:
            # Add a hint about the most likely cause, then propagate the error.
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise

        assert len(lms) == 1
    else:
        lms = [None]

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    def _fp_convert_sample(sample):
        """Cast float32 tensors in ``sample`` to half/bfloat16 per the config."""

        def apply_half(t):
            if t.dtype is torch.float32:
                return t.to(dtype=torch.half)
            return t

        def apply_bfloat16(t):
            if t.dtype is torch.float32:
                return t.to(dtype=torch.bfloat16)
            return t

        if cfg.common.fp16:
            sample = utils.apply_to_sample(apply_half, sample)

        if cfg.common.bf16:
            sample = utils.apply_to_sample(apply_bfloat16, sample)

        return sample

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()

    extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
    generator = task.build_generator(
        models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
    )

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def decode_fn(x):
        """Undo BPE and then tokenization, for whichever of the two is set."""
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(cfg.scoring, None)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        sample = _fp_convert_sample(sample)
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        # Only the first model of the ensemble is handed to the inference step.
        hypos = task.inference_step(
            generator,
            models[0],
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        num_generated_tokens = sum(len(h["unit"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample["id"].tolist()):
            has_target = sample["target"] is not None

            # Remove padding
            if "src_tokens" in sample["net_input"]:
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                ).cpu()
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = (
                    utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).cpu()
                )

            # Either retrieve the original sentences or regenerate them from tokens.
            target_str = None
            if align_dict is not None:
                src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
                    sample_id
                )
                target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
                    sample_id
                )
            else:
                # src_tokens is None when the batch carries no "src_tokens";
                # there is nothing to stringify in that case.
                if src_dict is not None and src_tokens is not None:
                    src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
                else:
                    src_str = ""
                if has_target:
                    target_str = " ".join(map(str, target_tokens.numpy().tolist()))

            src_str = decode_fn(src_str)

            if not cfg.common_eval.quiet:
                if src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str), file=output_file)
                if has_target:
                    print("T-{}\t{}".format(sample_id, target_str), file=output_file)

            # Process top predictions: exactly one hypothesis per sample is read.
            hypo = hypos[i]
            hypo_tokens = hypo["unit"].int().cpu()
            hypo_str = " ".join(map(str, hypo_tokens.numpy().tolist()))
            # No alignment is produced here, so the A- lines below are only
            # emitted if a future change populates ``alignment``.
            alignment = None
            detok_hypo_str = hypo_str

            # add duration prediction
            hypo_duration = " ".join(
                map(str, hypo["duration"].int().cpu().numpy().tolist())
            )
            fa_src = hypo["fa_src"].cpu()
            if src_dict is not None:
                hypo_fa_src_str = src_dict.string(
                    fa_src.numpy(), cfg.common_eval.post_process
                )
            else:
                # Without a source dictionary the indices cannot be mapped to
                # symbols; emit the raw indices instead of failing.
                hypo_fa_src_str = " ".join(map(str, fa_src.int().numpy().tolist()))

            if not cfg.common_eval.quiet:
                # The hypothesis score is not read from ``hypo``; a constant
                # placeholder is written. H- and P- lines are not emitted.
                score = 0.00
                # detokenized hypothesis
                print(
                    "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
                    file=output_file,
                )
                # duration prediction
                print(
                    "L-{}\t{}\t{}".format(sample_id, score, hypo_duration),
                    file=output_file,
                )
                # force-aligned upsampled src-tokens
                print(
                    "U-{}\t{}\t{}".format(sample_id, score, hypo_fa_src_str),
                    file=output_file,
                )

                # Guarded: iterating ``alignment`` while it is None would raise
                # TypeError when --print-alignment is requested.
                if alignment is not None:
                    if cfg.generation.print_alignment == "hard":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [
                                        "{}-{}".format(src_idx, tgt_idx)
                                        for src_idx, tgt_idx in alignment
                                    ]
                                ),
                            ),
                            file=output_file,
                        )
                    if cfg.generation.print_alignment == "soft":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [",".join(src_probs) for src_probs in alignment]
                                ),
                            ),
                            file=output_file,
                        )

            # Score the (single) hypothesis against the target.
            if has_target:
                if hasattr(scorer, "add_string"):
                    scorer.add_string(target_str, detok_hypo_str)
                else:
                    scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += (
            sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
        )

    # Guard the throughput figures: both denominators are zero when nothing
    # was generated (e.g. an empty subset).
    elapsed = gen_timer.sum
    avg_time_per_token = gen_timer.avg
    sentences_per_sec = num_sentences / elapsed if elapsed > 0 else 0.0
    tokens_per_sec = 1.0 / avg_time_per_token if avg_time_per_token > 0 else 0.0

    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info(
        "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
            num_sentences,
            gen_timer.n,
            elapsed,
            sentences_per_sec,
            tokens_per_sec,
        )
    )
    if has_target:
        if cfg.bpe and not cfg.generation.sacrebleu:
            if cfg.common_eval.post_process:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words.  Use --sacrebleu for standard 13a BLEU tokenization"
                )
        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
        print(
            "Generate {} with beam={}: {}".format(
                cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
            ),
            file=output_file,
        )

    return scorer


def cli_main():
    """Parse command-line arguments and run ``main``."""
    parser = options.get_generation_parser()
    # TODO: replace this workaround with refactoring of `AudioPretraining`
    parser.add_argument(
        "--arch",
        "-a",
        metavar="ARCH",
        default="wav2vec2",
        help="Model architecture. For constructing tasks that rely on "
        "model args (e.g. `AudioPretraining`)",
    )
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()