Spaces:
Runtime error
Runtime error
# ---------------------------------------------------------------------------- | |
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329) | |
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM | |
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4 | |
# | |
# Copyright (c) 2022 Microsoft | |
# Licensed under The MIT License [see LICENSE for details] | |
# ---------------------------------------------------------------------------- | |
""" | |
Modified form: https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/examples/speech_recognition/new/infer.py | |
1. add "utils.import_user_module(cfg.common)" so that usr-dir can be loaded | |
""" | |
import ast | |
import hashlib | |
import logging | |
import os | |
import shutil | |
import sys | |
from dataclasses import dataclass, field, is_dataclass | |
from pathlib import Path | |
from typing import Any, Dict, List, Optional, Tuple, Union | |
import editdistance | |
import torch | |
import torch.distributed as dist | |
import examples | |
from examples.speech_recognition.new.decoders.decoder_config import ( | |
DecoderConfig, | |
FlashlightDecoderConfig, | |
) | |
from examples.speech_recognition.new.decoders.decoder import Decoder | |
from fairseq import checkpoint_utils, distributed_utils, progress_bar, tasks, utils | |
from fairseq.data.data_utils import post_process | |
from fairseq.dataclass.configs import ( | |
CheckpointConfig, | |
CommonConfig, | |
CommonEvalConfig, | |
DatasetConfig, | |
DistributedTrainingConfig, | |
FairseqDataclass, | |
) | |
from fairseq.logging.meters import StopwatchMeter, TimeMeter | |
from fairseq.logging.progress_bar import BaseProgressBar | |
from fairseq.models.fairseq_model import FairseqModel | |
from omegaconf import OmegaConf | |
import hydra | |
from hydra.core.config_store import ConfigStore | |
logging.root.setLevel(logging.INFO) | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
config_path = Path(examples.speech_recognition.new.__path__[0]).resolve() / "conf" | |
class DecodingConfig(DecoderConfig, FlashlightDecoderConfig): | |
unique_wer_file: bool = field( | |
default=False, | |
metadata={"help": "If set, use a unique file for storing WER"}, | |
) | |
results_path: Optional[str] = field( | |
default=None, | |
metadata={ | |
"help": "If set, write hypothesis and reference sentences into this directory" | |
}, | |
) | |
class InferConfig(FairseqDataclass): | |
task: Any = None | |
decoding: DecodingConfig = DecodingConfig() | |
common: CommonConfig = CommonConfig() | |
common_eval: CommonEvalConfig = CommonEvalConfig() | |
checkpoint: CheckpointConfig = CheckpointConfig() | |
distributed_training: DistributedTrainingConfig = DistributedTrainingConfig() | |
dataset: DatasetConfig = DatasetConfig() | |
is_ax: bool = field( | |
default=False, | |
metadata={ | |
"help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume" | |
}, | |
) | |
def reset_logging(): | |
root = logging.getLogger() | |
for handler in root.handlers: | |
root.removeHandler(handler) | |
root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper()) | |
handler = logging.StreamHandler(sys.stdout) | |
handler.setFormatter( | |
logging.Formatter( | |
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", | |
datefmt="%Y-%m-%d %H:%M:%S", | |
) | |
) | |
root.addHandler(handler) | |
class InferenceProcessor: | |
cfg: InferConfig | |
def __init__(self, cfg: InferConfig) -> None: | |
self.cfg = cfg | |
self.task = tasks.setup_task(cfg.task) | |
models, saved_cfg = self.load_model_ensemble() | |
self.models = models | |
self.saved_cfg = saved_cfg | |
self.tgt_dict = self.task.target_dictionary | |
self.task.load_dataset( | |
self.cfg.dataset.gen_subset, | |
task_cfg=saved_cfg.task, | |
) | |
self.generator = Decoder(cfg.decoding, self.tgt_dict) | |
self.gen_timer = StopwatchMeter() | |
self.wps_meter = TimeMeter() | |
self.num_sentences = 0 | |
self.total_errors = 0 | |
self.total_length = 0 | |
self.hypo_words_file = None | |
self.hypo_units_file = None | |
self.ref_words_file = None | |
self.ref_units_file = None | |
self.progress_bar = self.build_progress_bar() | |
def __enter__(self) -> "InferenceProcessor": | |
if self.cfg.decoding.results_path is not None: | |
self.hypo_words_file = self.get_res_file("hypo.word") | |
self.hypo_units_file = self.get_res_file("hypo.units") | |
self.ref_words_file = self.get_res_file("ref.word") | |
self.ref_units_file = self.get_res_file("ref.units") | |
return self | |
def __exit__(self, *exc) -> bool: | |
if self.cfg.decoding.results_path is not None: | |
self.hypo_words_file.close() | |
self.hypo_units_file.close() | |
self.ref_words_file.close() | |
self.ref_units_file.close() | |
return False | |
def __iter__(self) -> Any: | |
for sample in self.progress_bar: | |
if not self.cfg.common.cpu: | |
sample = utils.move_to_cuda(sample) | |
# Happens on the last batch. | |
if "net_input" not in sample: | |
continue | |
yield sample | |
def log(self, *args, **kwargs): | |
self.progress_bar.log(*args, **kwargs) | |
def print(self, *args, **kwargs): | |
self.progress_bar.print(*args, **kwargs) | |
def get_res_file(self, fname: str) -> None: | |
fname = os.path.join(self.cfg.decoding.results_path, fname) | |
if self.data_parallel_world_size > 1: | |
fname = f"{fname}.{self.data_parallel_rank}" | |
return open(fname, "w", buffering=1) | |
def merge_shards(self) -> None: | |
"""Merges all shard files into shard 0, then removes shard suffix.""" | |
shard_id = self.data_parallel_rank | |
num_shards = self.data_parallel_world_size | |
if self.data_parallel_world_size > 1: | |
def merge_shards_with_root(fname: str) -> None: | |
fname = os.path.join(self.cfg.decoding.results_path, fname) | |
logger.info("Merging %s on shard %d", fname, shard_id) | |
base_fpath = Path(f"{fname}.0") | |
with open(base_fpath, "a") as out_file: | |
for s in range(1, num_shards): | |
shard_fpath = Path(f"{fname}.{s}") | |
with open(shard_fpath, "r") as in_file: | |
for line in in_file: | |
out_file.write(line) | |
shard_fpath.unlink() | |
shutil.move(f"{fname}.0", fname) | |
dist.barrier() # ensure all shards finished writing | |
if shard_id == (0 % num_shards): | |
merge_shards_with_root("hypo.word") | |
if shard_id == (1 % num_shards): | |
merge_shards_with_root("hypo.units") | |
if shard_id == (2 % num_shards): | |
merge_shards_with_root("ref.word") | |
if shard_id == (3 % num_shards): | |
merge_shards_with_root("ref.units") | |
dist.barrier() | |
def optimize_model(self, model: FairseqModel) -> None: | |
model.make_generation_fast_() | |
if self.cfg.common.fp16: | |
model.half() | |
if not self.cfg.common.cpu: | |
model.cuda() | |
def load_model_ensemble(self) -> Tuple[List[FairseqModel], FairseqDataclass]: | |
arg_overrides = ast.literal_eval(self.cfg.common_eval.model_overrides) | |
models, saved_cfg = checkpoint_utils.load_model_ensemble( | |
utils.split_paths(self.cfg.common_eval.path, separator="\\"), | |
arg_overrides=arg_overrides, | |
task=self.task, | |
suffix=self.cfg.checkpoint.checkpoint_suffix, | |
strict=(self.cfg.checkpoint.checkpoint_shard_count == 1), | |
num_shards=self.cfg.checkpoint.checkpoint_shard_count, | |
) | |
for model in models: | |
self.optimize_model(model) | |
return models, saved_cfg | |
def get_dataset_itr(self, disable_iterator_cache: bool = False) -> None: | |
return self.task.get_batch_iterator( | |
dataset=self.task.dataset(self.cfg.dataset.gen_subset), | |
max_tokens=self.cfg.dataset.max_tokens, | |
max_sentences=self.cfg.dataset.batch_size, | |
max_positions=(sys.maxsize, sys.maxsize), | |
ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test, | |
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, | |
seed=self.cfg.common.seed, | |
num_shards=self.data_parallel_world_size, | |
shard_id=self.data_parallel_rank, | |
num_workers=self.cfg.dataset.num_workers, | |
data_buffer_size=self.cfg.dataset.data_buffer_size, | |
disable_iterator_cache=disable_iterator_cache, | |
).next_epoch_itr(shuffle=False) | |
def build_progress_bar( | |
self, | |
epoch: Optional[int] = None, | |
prefix: Optional[str] = None, | |
default_log_format: str = "tqdm", | |
) -> BaseProgressBar: | |
return progress_bar.progress_bar( | |
iterator=self.get_dataset_itr(), | |
log_format=self.cfg.common.log_format, | |
log_interval=self.cfg.common.log_interval, | |
epoch=epoch, | |
prefix=prefix, | |
tensorboard_logdir=self.cfg.common.tensorboard_logdir, | |
default_log_format=default_log_format, | |
) | |
def data_parallel_world_size(self): | |
if self.cfg.distributed_training.distributed_world_size == 1: | |
return 1 | |
return distributed_utils.get_data_parallel_world_size() | |
def data_parallel_rank(self): | |
if self.cfg.distributed_training.distributed_world_size == 1: | |
return 0 | |
return distributed_utils.get_data_parallel_rank() | |
def process_sentence( | |
self, | |
sample: Dict[str, Any], | |
hypo: Dict[str, Any], | |
sid: int, | |
batch_id: int, | |
) -> Tuple[int, int]: | |
speaker = None # Speaker can't be parsed from dataset. | |
if "target_label" in sample: | |
toks = sample["target_label"] | |
else: | |
toks = sample["target"] | |
toks = toks[batch_id, :] | |
# Processes hypothesis. | |
hyp_pieces = self.tgt_dict.string(hypo["tokens"].int().cpu()) | |
if "words" in hypo: | |
hyp_words = " ".join(hypo["words"]) | |
else: | |
hyp_words = post_process(hyp_pieces, self.cfg.common_eval.post_process) | |
# Processes target. | |
target_tokens = utils.strip_pad(toks, self.tgt_dict.pad()) | |
tgt_pieces = self.tgt_dict.string(target_tokens.int().cpu()) | |
tgt_words = post_process(tgt_pieces, self.cfg.common_eval.post_process) | |
if self.cfg.decoding.results_path is not None: | |
print(f"{hyp_pieces} ({speaker}-{sid})", file=self.hypo_units_file) | |
print(f"{hyp_words} ({speaker}-{sid})", file=self.hypo_words_file) | |
print(f"{tgt_pieces} ({speaker}-{sid})", file=self.ref_units_file) | |
print(f"{tgt_words} ({speaker}-{sid})", file=self.ref_words_file) | |
if not self.cfg.common_eval.quiet: | |
logger.info(f"HYPO: {hyp_words}") | |
logger.info(f"REF: {tgt_words}") | |
logger.info("---------------------") | |
hyp_words, tgt_words = hyp_words.split(), tgt_words.split() | |
return editdistance.eval(hyp_words, tgt_words), len(tgt_words) | |
def process_sample(self, sample: Dict[str, Any]) -> None: | |
self.gen_timer.start() | |
hypos = self.task.inference_step( | |
generator=self.generator, | |
models=self.models, | |
sample=sample, | |
) | |
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) | |
self.gen_timer.stop(num_generated_tokens) | |
self.wps_meter.update(num_generated_tokens) | |
for batch_id, sample_id in enumerate(sample["id"].tolist()): | |
errs, length = self.process_sentence( | |
sample=sample, | |
sid=sample_id, | |
batch_id=batch_id, | |
hypo=hypos[batch_id][0], | |
) | |
self.total_errors += errs | |
self.total_length += length | |
self.log({"wps": round(self.wps_meter.avg)}) | |
if "nsentences" in sample: | |
self.num_sentences += sample["nsentences"] | |
else: | |
self.num_sentences += sample["id"].numel() | |
def log_generation_time(self) -> None: | |
logger.info( | |
"Processed %d sentences (%d tokens) in %.1fs %.2f " | |
"sentences per second, %.2f tokens per second)", | |
self.num_sentences, | |
self.gen_timer.n, | |
self.gen_timer.sum, | |
self.num_sentences / (self.gen_timer.sum + 1e-6), | |
1.0 / (self.gen_timer.avg + 1e-6), | |
) | |
def parse_wer(wer_file: Path) -> float: | |
with open(wer_file, "r") as f: | |
return float(f.readline().strip().split(" ")[1]) | |
def get_wer_file(cfg: InferConfig) -> Path: | |
"""Hashes the decoding parameters to a unique file ID.""" | |
base_path = "wer" | |
if cfg.decoding.results_path is not None: | |
base_path = os.path.join(cfg.decoding.results_path, base_path) | |
if cfg.decoding.unique_wer_file: | |
yaml_str = OmegaConf.to_yaml(cfg.decoding) | |
fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16) | |
return Path(f"{base_path}.{fid % 1000000}") | |
else: | |
return Path(base_path) | |
def main(cfg: InferConfig) -> float: | |
"""Entry point for main processing logic. | |
Args: | |
cfg: The inferance configuration to use. | |
wer: Optional shared memory pointer for returning the WER. If not None, | |
the final WER value will be written here instead of being returned. | |
Returns: | |
The final WER if `wer` is None, otherwise None. | |
""" | |
utils.import_user_module(cfg.common) | |
yaml_str, wer_file = OmegaConf.to_yaml(cfg.decoding), get_wer_file(cfg) | |
# Validates the provided configuration. | |
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: | |
cfg.dataset.max_tokens = 4000000 | |
if not cfg.common.cpu and not torch.cuda.is_available(): | |
raise ValueError("CUDA not found; set `cpu=True` to run without CUDA") | |
logger.info(cfg.common_eval.path) | |
with InferenceProcessor(cfg) as processor: | |
for sample in processor: | |
processor.process_sample(sample) | |
processor.log_generation_time() | |
if cfg.decoding.results_path is not None: | |
processor.merge_shards() | |
errs_t, leng_t = processor.total_errors, processor.total_length | |
if cfg.common.cpu: | |
logger.warning("Merging WER requires CUDA.") | |
elif processor.data_parallel_world_size > 1: | |
stats = torch.LongTensor([errs_t, leng_t]).cuda() | |
dist.all_reduce(stats, op=dist.ReduceOp.SUM) | |
errs_t, leng_t = stats[0].item(), stats[1].item() | |
wer = errs_t * 100.0 / leng_t | |
if distributed_utils.is_master(cfg.distributed_training): | |
with open(wer_file, "w") as f: | |
f.write( | |
( | |
f"WER: {wer}\n" | |
f"err / num_ref_words = {errs_t} / {leng_t}\n\n" | |
f"{yaml_str}" | |
) | |
) | |
return wer | |
def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]: | |
container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True) | |
cfg = OmegaConf.create(container) | |
OmegaConf.set_struct(cfg, True) | |
if cfg.common.reset_logging: | |
reset_logging() | |
utils.import_user_module(cfg.common) | |
# logger.info("Config:\n%s", OmegaConf.to_yaml(cfg)) | |
wer = float("inf") | |
try: | |
if cfg.common.profile: | |
with torch.cuda.profiler.profile(): | |
with torch.autograd.profiler.emit_nvtx(): | |
distributed_utils.call_main(cfg, main) | |
else: | |
distributed_utils.call_main(cfg, main) | |
wer = parse_wer(get_wer_file(cfg)) | |
except BaseException as e: # pylint: disable=broad-except | |
if not cfg.common.suppress_crashes: | |
raise | |
else: | |
logger.error("Crashed! %s", str(e)) | |
logger.info("Word error rate: %.4f", wer) | |
if cfg.is_ax: | |
return wer, None | |
return wer | |
def cli_main() -> None: | |
try: | |
from hydra._internal.utils import ( | |
get_args, | |
) # pylint: disable=import-outside-toplevel | |
cfg_name = get_args().config_name or "infer" | |
except ImportError: | |
logger.warning("Failed to get config name from hydra args") | |
cfg_name = "infer" | |
cs = ConfigStore.instance() | |
cs.store(name=cfg_name, node=InferConfig) | |
for k in InferConfig.__dataclass_fields__: | |
if is_dataclass(InferConfig.__dataclass_fields__[k].type): | |
v = InferConfig.__dataclass_fields__[k].default | |
cs.store(name=k, node=v) | |
hydra_main() # pylint: disable=no-value-for-parameter | |
if __name__ == "__main__": | |
cli_main() | |