JustinLin610
update
8437114
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Run inference for pre-processed data with a trained model.
"""
import ast
from collections import namedtuple
from dataclasses import dataclass, field
from enum import Enum, auto
import hydra
from hydra.core.config_store import ConfigStore
import logging
import math
import os
from omegaconf import OmegaConf
from typing import Optional
import sys
import editdistance
import torch
from hydra.core.hydra_config import HydraConfig
from fairseq import checkpoint_utils, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.dataclass.configs import FairseqDataclass, FairseqConfig
from fairseq.logging.meters import StopwatchMeter
from omegaconf import open_dict
from examples.speech_recognition.kaldi.kaldi_decoder import KaldiDecoderConfig
# Set the root logger to INFO and configure logging to write to stdout.
logging.root.setLevel(logging.INFO)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
class DecoderType(Enum):
    """Decoder back-ends that can be selected through ``w2l_decoder``."""

    VITERBI = 1
    KENLM = 2
    FAIRSEQ = 3
    KALDI = 4
@dataclass
class UnsupGenerateConfig(FairseqDataclass):
    """Options for decoding with a trained model and scoring the output.

    The nested ``fairseq`` node carries the regular fairseq configuration;
    the remaining fields select the decoder, its language model and the
    quantities used for (unsupervised) hyper-parameter tuning.
    """

    # Built through a factory so every config gets its own FairseqConfig.
    # A shared instance default is rejected by dataclasses on Python >= 3.11.
    fairseq: FairseqConfig = field(default_factory=FairseqConfig)
    lm_weight: float = field(
        default=2.0,
        metadata={"help": "language model weight"},
    )
    w2l_decoder: DecoderType = field(
        default=DecoderType.VITERBI,
        metadata={"help": "type of decoder to use"},
    )
    # Only required when w2l_decoder is KALDI.
    kaldi_decoder_config: Optional[KaldiDecoderConfig] = None
    lexicon: Optional[str] = field(
        default=None,
        metadata={
            "help": "path to lexicon. This is also used to 'phonemize' for unsupvised param tuning"
        },
    )
    lm_model: Optional[str] = field(
        default=None,
        metadata={"help": "path to language model (kenlm or fairseq)"},
    )
    unit_lm: bool = field(
        default=False,
        metadata={"help": "whether to use unit lm"},
    )
    beam_threshold: float = field(
        default=50.0,
        metadata={"help": "beam score threshold"},
    )
    beam_size_token: float = field(
        default=100.0,
        metadata={"help": "max tokens per beam"},
    )
    beam: int = field(
        default=5,
        metadata={"help": "decoder beam size"},
    )
    nbest: int = field(
        default=1,
        metadata={"help": "number of results to return"},
    )
    word_score: float = field(
        default=1.0,
        metadata={"help": "word score to add at end of word"},
    )
    unk_weight: float = field(
        default=-math.inf,
        metadata={"help": "unknown token weight"},
    )
    sil_weight: float = field(
        default=0.0,
        metadata={"help": "silence token weight"},
    )
    targets: Optional[str] = field(
        default=None,
        metadata={"help": "extension of ground truth labels to compute UER"},
    )
    results_path: Optional[str] = field(
        default=None,
        metadata={"help": "where to store results"},
    )
    post_process: Optional[str] = field(
        default=None,
        metadata={"help": "how to post process results"},
    )
    vocab_usage_power: float = field(
        default=2,
        metadata={"help": "for unsupervised param tuning"},
    )
    viterbi_transcript: Optional[str] = field(
        default=None,
        metadata={"help": "for unsupervised param tuning"},
    )
    min_lm_ppl: float = field(
        default=0,
        metadata={"help": "for unsupervised param tuning"},
    )
    min_vt_uer: float = field(
        default=0,
        metadata={"help": "for unsupervised param tuning"},
    )
    blank_weight: float = field(
        default=0,
        metadata={"help": "value to add or set for blank emission"},
    )
    blank_mode: str = field(
        default="set",
        metadata={
            "help": "can be add or set, how to modify blank emission with blank weight"
        },
    )
    sil_is_blank: bool = field(
        default=False,
        metadata={"help": "if true, <SIL> token is same as blank token"},
    )
    unsupervised_tuning: bool = field(
        default=False,
        metadata={
            "help": "if true, returns a score based on unsupervised param selection metric instead of UER"
        },
    )
    is_ax: bool = field(
        default=False,
        metadata={
            "help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume"
        },
    )
def get_dataset_itr(cfg, task):
    """Return a non-shuffled epoch iterator over the ``gen_subset`` split.

    Batching, sharding and worker options are taken from
    ``cfg.fairseq.dataset``; ``max_positions`` is set to
    ``(sys.maxsize, sys.maxsize)``.
    """
    ds_cfg = cfg.fairseq.dataset
    batch_iterator = task.get_batch_iterator(
        dataset=task.dataset(ds_cfg.gen_subset),
        max_tokens=ds_cfg.max_tokens,
        max_sentences=ds_cfg.batch_size,
        max_positions=(sys.maxsize, sys.maxsize),
        ignore_invalid_inputs=ds_cfg.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=ds_cfg.required_batch_size_multiple,
        num_shards=ds_cfg.num_shards,
        shard_id=ds_cfg.shard_id,
        num_workers=ds_cfg.num_workers,
        data_buffer_size=ds_cfg.data_buffer_size,
    )
    return batch_iterator.next_epoch_itr(shuffle=False)
def process_predictions(
    cfg: UnsupGenerateConfig,
    hypos,
    tgt_dict,
    target_tokens,
    res_files,
):
    """Score the n-best hypotheses of one utterance and write out the best.

    Args:
        cfg: generation config; ``nbest``, ``post_process`` and
            ``fairseq.common_eval.quiet`` are read here.
        hypos: hypotheses for one utterance; at most ``cfg.nbest`` are used.
            Each is a mapping with ``"tokens"`` and optionally ``"words"``,
            ``"score"``, ``"am_score"`` and ``"lm_score"``.
        tgt_dict: dictionary used to turn token tensors into strings.
        target_tokens: reference as a string or token tensor, or ``None``.
        res_files: mapping of stream name to open file, or ``None``.

    Returns:
        ``(errors, n_hyp_words, n_ref_words, hyp_pieces, hyp_words)`` for the
        hypothesis with the smallest word edit distance to the reference
        (the first one on ties).
    """
    # One (result tuple, {file: line to write}) pair per considered hypothesis.
    scored = []

    for i, hypo in enumerate(hypos[: cfg.nbest]):
        if torch.is_tensor(hypo["tokens"]):
            tokens = hypo["tokens"].int().cpu()
            # Drop ids below nspecial before rendering the unit string.
            tokens = tokens[tokens >= tgt_dict.nspecial]
            hyp_pieces = tgt_dict.string(tokens)
        else:
            hyp_pieces = " ".join(hypo["tokens"])

        if "words" in hypo and len(hypo["words"]) > 0:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, cfg.post_process)

        to_write = {}
        if res_files is not None:
            to_write[res_files["hypo.units"]] = hyp_pieces
            to_write[res_files["hypo.words"]] = hyp_words

        tgt_words = ""
        if target_tokens is not None:
            if isinstance(target_tokens, str):
                tgt_pieces = tgt_words = target_tokens
            else:
                tgt_pieces = tgt_dict.string(target_tokens)
                tgt_words = post_process(tgt_pieces, cfg.post_process)

            if res_files is not None:
                to_write[res_files["ref.units"]] = tgt_pieces
                to_write[res_files["ref.words"]] = tgt_words

        if not cfg.fairseq.common_eval.quiet:
            logger.info(f"HYPO {i}:" + hyp_words)
            if tgt_words:
                logger.info("TARGET:" + tgt_words)

            if "am_score" in hypo and "lm_score" in hypo:
                logger.info(
                    f"DECODER AM SCORE: {hypo['am_score']}, DECODER LM SCORE: {hypo['lm_score']}, DECODER SCORE: {hypo['score']}"
                )
            elif "score" in hypo:
                logger.info(f"DECODER SCORE: {hypo['score']}")

            logger.info("___________________")

        hyp_words_arr = hyp_words.split()
        tgt_words_arr = tgt_words.split()

        result = (
            editdistance.eval(hyp_words_arr, tgt_words_arr),
            len(hyp_words_arr),
            len(tgt_words_arr),
            hyp_pieces,
            hyp_words,
        )
        scored.append((result, to_write))

    # Callers unpack a 5-tuple, so an utterance without hypotheses is an error.
    assert len(scored) >= 1

    # min() returns the first of equally good entries, matching a strict "<" scan.
    best_result, best_lines = min(scored, key=lambda pair: pair[0][0])
    for dest, tran in best_lines.items():
        print(tran, file=dest)
        dest.flush()
    return best_result
def prepare_result_files(cfg: UnsupGenerateConfig):
    """Open the output files for the current subset under ``cfg.results_path``.

    Returns ``None`` when ``cfg.results_path`` is unset or empty; otherwise a
    dict mapping a stream name to a line-buffered text file opened for
    writing. When more than one shard is configured the shard id is inserted
    into the file name.
    """
    if not cfg.results_path:
        return None

    def get_res_file(file_prefix):
        ds_cfg = cfg.fairseq.dataset
        if ds_cfg.num_shards > 1:
            file_prefix = f"{ds_cfg.shard_id}_{file_prefix}"
        filename = "{}{}.txt".format(ds_cfg.gen_subset, file_prefix)
        return open(os.path.join(cfg.results_path, filename), "w", buffering=1)

    # (stream name, file-name suffix) in the order the files are opened.
    streams = (
        ("hypo.words", ""),
        ("hypo.units", "_units"),
        ("ref.words", "_ref"),
        ("ref.units", "_ref_units"),
        ("hypo.nbest.words", "_nbest_words"),
    )
    return {name: get_res_file(suffix) for name, suffix in streams}
def optimize_models(cfg: UnsupGenerateConfig, use_cuda, models):
    """Prepare every model of the ensemble for generation.

    Each model is put in eval mode, converted with ``half()`` when
    ``cfg.fairseq.common.fp16`` is set, and moved with ``cuda()`` when
    ``use_cuda`` is true. The models are modified in place.
    """
    for net in models:
        net.eval()
        wants_half = cfg.fairseq.common.fp16
        if wants_half:
            net.half()
        if use_cuda:
            net.cuda()
# Statistics bundle returned by generate(). Instances are built positionally,
# so the field order below is part of the contract.
GenResult = namedtuple(
    "GenResult",
    "count errs_t gen_timer lengths_hyp_unit_t lengths_hyp_t lengths_t "
    "lm_score_t num_feats num_sentences num_symbols vt_err_t vt_length_t",
)
def generate(cfg: UnsupGenerateConfig, models, saved_cfg, use_cuda):
    """Decode the ``gen_subset`` split and accumulate evaluation statistics.

    Args:
        cfg: generation config (decoder type, LM path, targets extension,
            results path, viterbi transcript, lexicon, ...).
        models: models passed on to ``gen_hypos`` for inference.
        saved_cfg: config whose ``task`` node is used to load the dataset.
            Its ``task.labels`` is overwritten with
            ``cfg.fairseq.task.labels``.
        use_cuda: whether samples are moved to GPU before inference.

    Returns:
        A ``GenResult`` with the utterance count, summed word errors, the
        stopwatch, hypothesis/reference length totals, the summed LM score
        (0 when no LM is configured), the feature count, the sentence count,
        the number of dictionary symbols that are neither special nor
        "madeup", and the edit distance / total length against the viterbi
        transcript (both 0 when none is given).
    """
    task = tasks.setup_task(cfg.fairseq.task)
    saved_cfg.task.labels = cfg.fairseq.task.labels
    task.load_dataset(cfg.fairseq.dataset.gen_subset, task_cfg=saved_cfg.task)
    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info(
        "| {} {} {} examples".format(
            cfg.fairseq.task.data,
            cfg.fairseq.dataset.gen_subset,
            len(task.dataset(cfg.fairseq.dataset.gen_subset)),
        )
    )

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(cfg, task)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(cfg: UnsupGenerateConfig):
        # Decoder modules are imported lazily so only the selected back-end
        # has to be importable.
        w2l_decoder = cfg.w2l_decoder
        if w2l_decoder == DecoderType.VITERBI:
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.KENLM:
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.FAIRSEQ:
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.KALDI:
            from examples.speech_recognition.kaldi.kaldi_decoder import KaldiDecoder

            assert cfg.kaldi_decoder_config is not None

            return KaldiDecoder(
                cfg.kaldi_decoder_config,
                cfg.beam,
            )
        else:
            raise NotImplementedError(
                "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment but found "
                + str(w2l_decoder)
            )

    generator = build_generator(cfg)

    # Optional KenLM model used only to score the final hypotheses.
    lm = None
    if cfg.lm_model is not None:
        import kenlm

        lm = kenlm.Model(cfg.lm_model)

    num_sentences = 0
    errs_t = 0
    lengths_hyp_t = 0
    lengths_hyp_unit_t = 0
    lengths_t = 0
    count = 0
    num_feats = 0

    all_hyp_pieces = []
    all_hyp_words = []

    num_symbols = (
        len([s for s in tgt_dict.symbols if not s.startswith("madeup")])
        - tgt_dict.nspecial
    )

    # Use the same "is a results path configured" test as prepare_result_files
    # so an empty string means "no output" instead of crashing in makedirs.
    if cfg.results_path:
        os.makedirs(cfg.results_path, exist_ok=True)
    res_files = prepare_result_files(cfg)

    try:
        # Reference transcripts read from "<data>/<subset>.<targets>", if present.
        targets = None
        if cfg.targets is not None:
            tgt_path = os.path.join(
                cfg.fairseq.task.data, cfg.fairseq.dataset.gen_subset + "." + cfg.targets
            )
            if os.path.exists(tgt_path):
                with open(tgt_path, "r") as f:
                    targets = f.read().splitlines()

        viterbi_transcript = None
        if cfg.viterbi_transcript is not None and len(cfg.viterbi_transcript) > 0:
            logger.info(f"loading viterbi transcript from {cfg.viterbi_transcript}")
            with open(cfg.viterbi_transcript, "r") as vf:
                viterbi_transcript = vf.readlines()
                viterbi_transcript = [v.rstrip().split() for v in viterbi_transcript]

        gen_timer.start()

        start = 0
        end = len(itr)

        hypo_futures = None
        if cfg.w2l_decoder == DecoderType.KALDI:
            # First pass: run inference on every batch and keep the returned
            # handles; their results are collected with .result() below.
            logger.info("Extracting features")
            hypo_futures = []
            samples = []
            with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t:
                for i, sample in enumerate(t):
                    if "net_input" not in sample or i < start or i >= end:
                        continue
                    if "padding_mask" not in sample["net_input"]:
                        sample["net_input"]["padding_mask"] = None

                    hypos, num_feats = gen_hypos(
                        generator, models, num_feats, sample, task, use_cuda
                    )
                    hypo_futures.append(hypos)
                    samples.append(sample)
            itr = list(zip(hypo_futures, samples))
            start = 0
            end = len(itr)
            logger.info("Finished extracting features")

        with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t:
            for i, sample in enumerate(t):
                if i < start or i >= end:
                    continue

                if hypo_futures is not None:
                    hypos, sample = sample
                    hypos = [h.result() for h in hypos]
                else:
                    if "net_input" not in sample:
                        continue

                    hypos, num_feats = gen_hypos(
                        generator, models, num_feats, sample, task, use_cuda
                    )

                # `row` indexes the utterance inside the batch; a separate
                # name avoids shadowing the batch index `i`.
                for row, sample_id in enumerate(sample["id"].tolist()):
                    if targets is not None:
                        target_tokens = targets[sample_id]
                    elif "target" in sample or "target_label" in sample:
                        toks = (
                            sample["target"][row, :]
                            if "target_label" not in sample
                            else sample["target_label"][row, :]
                        )

                        target_tokens = (
                            utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                        )
                    else:
                        target_tokens = None

                    # Process top predictions
                    (
                        errs,
                        length_hyp,
                        length,
                        hyp_pieces,
                        hyp_words,
                    ) = process_predictions(
                        cfg,
                        hypos[row],
                        tgt_dict,
                        target_tokens,
                        res_files,
                    )
                    errs_t += errs
                    lengths_hyp_t += length_hyp
                    lengths_hyp_unit_t += (
                        len(hyp_pieces) if len(hyp_pieces) > 0 else len(hyp_words)
                    )
                    lengths_t += length
                    count += 1
                    all_hyp_pieces.append(hyp_pieces)
                    all_hyp_words.append(hyp_words)

                num_sentences += (
                    sample["nsentences"]
                    if "nsentences" in sample
                    else sample["id"].numel()
                )

        lm_score_sum = 0
        if lm is not None:
            scored_texts = all_hyp_pieces if cfg.unit_lm else all_hyp_words
            lm_score_sum = sum(lm.score(w) for w in scored_texts)

        vt_err_t = 0
        vt_length_t = 0
        if viterbi_transcript is not None:
            unit_hyps = []
            if cfg.targets is not None and cfg.lexicon is not None:
                # Expand every hypothesis token through the lexicon before
                # comparing with the viterbi transcript.
                lex = {}
                with open(cfg.lexicon, "r") as lf:
                    for line in lf:
                        items = line.rstrip().split()
                        if not items:
                            # tolerate blank lines in the lexicon file
                            continue
                        lex[items[0]] = items[1:]
                for h in all_hyp_pieces:
                    hyp_ws = []
                    for w in h.split():
                        assert w in lex, w
                        hyp_ws.extend(lex[w])
                    unit_hyps.append(hyp_ws)
            else:
                unit_hyps.extend([h.split() for h in all_hyp_words])

            vt_err_t = sum(
                editdistance.eval(vt, h)
                for vt, h in zip(viterbi_transcript, unit_hyps)
            )

            vt_length_t = sum(len(h) for h in viterbi_transcript)
    finally:
        # Close the result files even when decoding fails part-way.
        if res_files is not None:
            for r in res_files.values():
                r.close()

    gen_timer.stop(lengths_hyp_t)

    return GenResult(
        count,
        errs_t,
        gen_timer,
        lengths_hyp_unit_t,
        lengths_hyp_t,
        lengths_t,
        lm_score_sum,
        num_feats,
        num_sentences,
        num_symbols,
        vt_err_t,
        vt_length_t,
    )
def gen_hypos(generator, models, num_feats, sample, task, use_cuda):
    """Run one inference step on ``sample`` and update the feature count.

    When the network input carries ``"features"``, ``dense_x_only`` is set on
    it and the product of the first two feature dimensions is added to
    ``num_feats``.

    Returns:
        ``(hypos, num_feats)`` where ``hypos`` is whatever
        ``task.inference_step`` returns.
    """
    if use_cuda:
        sample = utils.move_to_cuda(sample)

    net_input = sample["net_input"]
    if "features" in net_input:
        net_input["dense_x_only"] = True
        feat_shape = net_input["features"].shape
        num_feats += feat_shape[0] * feat_shape[1]

    return task.inference_step(generator, models, sample, None), num_feats
def main(cfg: UnsupGenerateConfig, model=None):
    """Load the model(s), decode the configured subset and compute a score.

    Args:
        cfg: generation config.
        model: optional already-built model. When given, no checkpoint is
            loaded and ``cfg.fairseq`` is used as the saved config.

    Returns:
        ``(task, weighted_score)``. The score is the WER (``None`` when there
        are no reference words) unless ``cfg.unsupervised_tuning`` is set, in
        which case it is ``log(lm_ppl)`` times the UER against the viterbi
        transcript (times 1.0 when that UER is unavailable or zero).
    """
    if (
        cfg.fairseq.dataset.max_tokens is None
        and cfg.fairseq.dataset.batch_size is None
    ):
        # Neither batching limit was given: fall back to a token budget.
        cfg.fairseq.dataset.max_tokens = 1024000

    use_cuda = torch.cuda.is_available() and not cfg.fairseq.common.cpu

    task = tasks.setup_task(cfg.fairseq.task)

    overrides = ast.literal_eval(cfg.fairseq.common_eval.model_overrides)

    model_overrides = {
        "blank_weight": cfg.blank_weight,
        "blank_mode": cfg.blank_mode,
    }
    if cfg.fairseq.task._name == "unpaired_audio_text":
        model_overrides.update(
            {
                "blank_is_sil": cfg.sil_is_blank,
                "no_softmax": True,
                "segmentation": {
                    "type": "NONE",
                },
            }
        )
    overrides["model"] = model_overrides

    if model is None:
        # Load ensemble
        logger.info("| loading model(s) from {}".format(cfg.fairseq.common_eval.path))
        models, saved_cfg = checkpoint_utils.load_model_ensemble(
            cfg.fairseq.common_eval.path.split("\\"),
            arg_overrides=overrides,
            task=task,
            suffix=cfg.fairseq.checkpoint.checkpoint_suffix,
            strict=(cfg.fairseq.checkpoint.checkpoint_shard_count == 1),
            num_shards=cfg.fairseq.checkpoint.checkpoint_shard_count,
        )
        optimize_models(cfg, use_cuda, models)
    else:
        models = [model]
        saved_cfg = cfg.fairseq

    # Disable shuffling and length sorting on the task config used for loading.
    with open_dict(saved_cfg.task):
        saved_cfg.task.shuffle = False
        saved_cfg.task.sort_by_length = False

    gen_result = generate(cfg, models, saved_cfg, use_cuda)

    wer = None
    if gen_result.lengths_t > 0:
        wer = gen_result.errs_t * 100.0 / gen_result.lengths_t
        logger.info(f"WER: {wer}")

    lm_ppl = float("inf")

    if gen_result.lm_score_t != 0 and gen_result.lengths_hyp_t > 0:
        hyp_len = gen_result.lengths_hyp_t
        # The summed LM score is used as a base-10 exponent; the denominator
        # adds one per sentence to the hypothesis word count.
        lm_ppl = math.pow(
            10, -gen_result.lm_score_t / (hyp_len + gen_result.num_sentences)
        )
        logger.info(f"LM PPL: {lm_ppl}")

    logger.info(
        "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
        " sentences/s, {:.2f} tokens/s)".format(
            gen_result.num_sentences,
            gen_result.gen_timer.n,
            gen_result.gen_timer.sum,
            gen_result.num_sentences / gen_result.gen_timer.sum,
            1.0 / gen_result.gen_timer.avg,
        )
    )

    vt_diff = None
    if gen_result.vt_length_t > 0:
        vt_diff = gen_result.vt_err_t / gen_result.vt_length_t
        vt_diff = max(cfg.min_vt_uer, vt_diff)

    lm_ppl = max(cfg.min_lm_ppl, lm_ppl)

    # The previous test `not cfg.unsupervised_tuning == 0` parsed as
    # `not (flag == 0)`, which returned WER exactly when tuning was requested.
    if not cfg.unsupervised_tuning:
        weighted_score = wer
    else:
        weighted_score = math.log(lm_ppl) * (vt_diff or 1.0)

    res = (
        f"| Generate {cfg.fairseq.dataset.gen_subset} with beam={cfg.beam}, "
        f"lm_weight={cfg.kaldi_decoder_config.acoustic_scale if cfg.kaldi_decoder_config else cfg.lm_weight}, "
        f"word_score={cfg.word_score}, sil_weight={cfg.sil_weight}, blank_weight={cfg.blank_weight}, "
        f"WER: {wer}, LM_PPL: {lm_ppl}, num feats: {gen_result.num_feats}, "
        f"length: {gen_result.lengths_hyp_t}, UER to viterbi: {(vt_diff or 0) * 100}, score: {weighted_score}"
    )

    logger.info(res)

    return task, weighted_score
@hydra.main(
    config_path=os.path.join("../../..", "fairseq", "config"), config_name="config"
)
def hydra_main(cfg):
    """Hydra entry point: finalize the config, run ``main`` and return its score.

    Returns ``(score, None)`` when ``cfg.is_ax`` is set, otherwise ``score``.
    """
    with open_dict(cfg):
        # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
        job_logging = HydraConfig.get().job_logging
        cfg.job_logging_cfg = OmegaConf.to_container(job_logging, resolve=True)

    # Rebuild the config from a plain container (interpolations unresolved,
    # enums kept as enums) and lock its structure.
    plain_cfg = OmegaConf.to_container(cfg, resolve=False, enum_to_str=False)
    cfg = OmegaConf.create(plain_cfg)
    OmegaConf.set_struct(cfg, True)
    logger.info(cfg)

    utils.import_user_module(cfg.fairseq.common)

    _task, score = main(cfg)

    if not cfg.is_ax:
        return score
    return score, None
def cli_main():
    """Register ``UnsupGenerateConfig`` with hydra's config store and run.

    The config is stored under the name given by hydra's command-line
    arguments, falling back to ``"config"`` when that cannot be determined.
    """
    try:
        # get_args lives in a private hydra module, so its import or call
        # may fail depending on the installed hydra version.
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        # Catch Exception rather than everything so KeyboardInterrupt and
        # SystemExit still propagate.
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    cs = ConfigStore.instance()
    cs.store(name=cfg_name, node=UnsupGenerateConfig)
    hydra_main()
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    cli_main()