oiTrans / inference /custom_interactive.py
harveen
Harveen | Adding code
74fc30d
raw
history blame
11.4 kB
# python wrapper for fairseq-interactive command line tool
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Translate raw text with a trained model. Batches data on-the-fly.
"""
import ast
from collections import namedtuple
import torch
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output
import codecs
# One mini-batch of encoded source sentences as handed to the generator.
# ``constraints`` is None when constrained decoding is not in use.
Batch = namedtuple(
    "Batch",
    ["ids", "src_tokens", "src_lengths", "constraints"],
)

# Container for one translated sentence and its per-hypothesis details.
Translation = namedtuple(
    "Translation",
    ["src_str", "hypos", "pos_scores", "alignments"],
)
def make_batches(
    lines, cfg, task, max_positions, encode_fn, constrainted_decoding=False
):
    """Encode raw input lines and yield them as ``Batch`` tuples.

    This is a generator: nothing is encoded until the first batch is
    requested.

    Args:
        lines: list of input strings. In constrained mode each string may
            carry tab-separated constraints after the source text.
        cfg: fairseq config; ``cfg.dataset.max_tokens``,
            ``cfg.dataset.batch_size`` and
            ``cfg.dataset.skip_invalid_size_inputs_valid_test`` are read.
        task: fairseq task used to tokenize, build the inference dataset and
            create the batch iterator.
        max_positions: maximum positions passed to the batch iterator.
        encode_fn: callable applied to raw text (source lines and
            constraints) before dictionary encoding.
        constrainted_decoding: when True, split constraints off every line
            and pass them to the dataset as a packed tensor.

    Yields:
        Batch: ids, source tokens, source lengths and (possibly None)
        constraints for each mini-batch produced by the iterator.
    """
    constraints_tensor = None
    if constrainted_decoding:
        # Build fresh lists instead of rewriting ``lines`` in place, so the
        # caller's list is never modified.
        source_texts = []
        batch_constraints = []
        for line in lines:
            # Everything before the first tab is the source text; each
            # further tab-delimited field is one constraint (none if the
            # line contains no tab).
            text, *raw_constraints = line.split("\t")
            source_texts.append(text)
            # Convert List[str] to List[Tensor] with the target dictionary.
            batch_constraints.append(
                [
                    task.target_dictionary.encode_line(
                        encode_fn(raw_constraint),
                        append_eos=False,
                        add_if_not_exist=False,
                    )
                    for raw_constraint in raw_constraints
                ]
            )
        lines = source_texts
        constraints_tensor = pack_constraints(batch_constraints)

    tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)
    dataset = task.build_dataset_for_inference(
        tokens, lengths, constraints=constraints_tensor
    )
    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)

    for batch in itr:
        net_input = batch["net_input"]
        yield Batch(
            ids=batch["id"],
            src_tokens=net_input["src_tokens"],
            src_lengths=net_input["src_lengths"],
            # The key is absent when no constraints were given to the dataset.
            constraints=batch.get("constraints", None),
        )
class Translator:
    """Programmatic wrapper around fairseq's interactive generation.

    Construction parses fairseq generation options, sets up the task, loads
    the model ensemble from ``checkpoint_path`` and builds the generator.
    ``translate`` then turns a list of raw strings into translated strings.
    """

    def __init__(
        self, data_dir, checkpoint_path, batch_size=25, constrained_decoding=False
    ):
        """Load the model ensemble and prepare everything needed to translate.

        Args:
            data_dir: passed to fairseq as the positional data argument; the
                original author notes that in deployment it only holds the
                SRC and TGT dictionaries.
            checkpoint_path: model checkpoint path(s), given to fairseq as
                ``path``.
            batch_size: maximum number of sentences per batch.
            constrained_decoding: when True the generator is configured with
                ``constraints="ordered"`` and ``translate`` requires
                constraints.
        """
        self.constrained_decoding = constrained_decoding
        self.parser = options.get_generation_parser(interactive=True)

        # buffer_size is currently not used, but it is initialised to
        # batch_size + 1 so the batch-size/buffer-size assertion below holds.
        parser_defaults = dict(
            path=checkpoint_path,
            remove_bpe="subword_nmt",
            num_workers=-1,
            batch_size=batch_size,
            buffer_size=batch_size + 1,
        )
        if self.constrained_decoding:
            parser_defaults["constraints"] = "ordered"
        self.parser.set_defaults(**parser_defaults)

        args = options.parse_args_and_arch(self.parser, input_args=[data_dir])
        # We explicitly set src_lang and tgt_lang here. Generally the data_dir
        # contains {split}-{src_lang}-{tgt_lang}.*.idx files from which
        # fairseq infers the languages (if they are not passed). In deployment
        # we don't use any idx files and only store the SRC and TGT
        # dictionaries.
        args.source_lang = "SRC"
        args.target_lang = "TGT"
        # Since sentences are truncated to max_seq_len in the engine, this can
        # be set to False here.
        args.skip_invalid_size_inputs_valid_test = False
        # Custom architectures live in this folder; let fairseq import it.
        args.user_dir = "model_configs"

        self.cfg = convert_namespace_to_omegaconf(args)
        utils.import_user_module(self.cfg.common)

        if self.cfg.interactive.buffer_size < 1:
            self.cfg.interactive.buffer_size = 1
        if self.cfg.dataset.max_tokens is None and self.cfg.dataset.batch_size is None:
            self.cfg.dataset.batch_size = 1

        assert (
            not self.cfg.generation.sampling
            or self.cfg.generation.nbest == self.cfg.generation.beam
        ), "--sampling requires --nbest to be equal to --beam"
        assert (
            not self.cfg.dataset.batch_size
            or self.cfg.dataset.batch_size <= self.cfg.interactive.buffer_size
        ), "--batch-size cannot be larger than --buffer-size"

        # NOTE: no random seed is set here, so stochastic decoding (sampling)
        # is not seeded by this class.
        self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu

        # Setup task, e.g., translation
        self.task = tasks.setup_task(self.cfg.task)

        # Load ensemble
        overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
        self.models, self._model_args = checkpoint_utils.load_model_ensemble(
            utils.split_paths(self.cfg.common_eval.path),
            arg_overrides=overrides,
            task=self.task,
            suffix=self.cfg.checkpoint.checkpoint_suffix,
            strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
            num_shards=self.cfg.checkpoint.checkpoint_shard_count,
        )

        # Set dictionaries
        self.src_dict = self.task.source_dictionary
        self.tgt_dict = self.task.target_dictionary

        # Optimize ensemble for generation
        for model in self.models:
            if model is None:
                continue
            if self.cfg.common.fp16:
                model.half()
            if (
                self.use_cuda
                and not self.cfg.distributed_training.pipeline_model_parallel
            ):
                model.cuda()
            model.prepare_for_inference_(self.cfg)

        # Initialize generator
        self.generator = self.task.build_generator(self.models, self.cfg.generation)

        # Handle tokenization and BPE
        self.tokenizer = self.task.build_tokenizer(self.cfg.tokenizer)
        self.bpe = self.task.build_bpe(self.cfg.bpe)

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align
        # dictionary)
        self.align_dict = utils.load_align_dict(self.cfg.generation.replace_unk)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in self.models]
        )

    def encode_fn(self, x):
        """Apply the tokenizer, then BPE, to ``x``; each step is skipped if unset."""
        if self.tokenizer is not None:
            x = self.tokenizer.encode(x)
        if self.bpe is not None:
            x = self.bpe.encode(x)
        return x

    def decode_fn(self, x):
        """Undo ``encode_fn``: BPE decoding first, then tokenizer decoding."""
        if self.bpe is not None:
            x = self.bpe.decode(x)
        if self.tokenizer is not None:
            x = self.tokenizer.decode(x)
        return x

    def translate(self, inputs, constraints=None):
        """Translate ``inputs`` and return the decoded hypothesis strings.

        Args:
            inputs: list of source sentences (strings).
            constraints: one constraint string per input when the translator
                was built with ``constrained_decoding=True``; must be None
                otherwise. Each constraint is appended to its input after a
                tab character.

        Returns:
            A flat list of detokenized hypothesis strings, ordered by input
            position. Up to ``cfg.generation.nbest`` hypotheses are emitted
            per input. An empty ``inputs`` yields an empty list.

        Raises:
            ValueError: if constraints are missing in constrained mode, given
                in normal mode, or differ in number from ``inputs``.
        """
        if self.constrained_decoding and constraints is None:
            raise ValueError("Constraints cant be None in constrained decoding mode")
        if not self.constrained_decoding and constraints is not None:
            raise ValueError("Cannot pass constraints during normal translation")

        # Nothing to translate: skip dataset/iterator construction entirely.
        if not inputs:
            return []

        if constraints:
            # zip() would silently drop the unmatched tail, so reject
            # mismatched lengths instead of returning too few translations.
            if len(constraints) != len(inputs):
                raise ValueError(
                    "Number of constraints must match number of inputs"
                )
            constrained_decoding = True
            inputs = [
                text + f"\t{constraint}"
                for text, constraint in zip(inputs, constraints)
            ]
        else:
            constrained_decoding = False

        results = []
        for batch in make_batches(
            inputs,
            self.cfg,
            self.task,
            self.max_positions,
            self.encode_fn,
            constrained_decoding,
        ):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            batch_constraints = batch.constraints
            if self.use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if batch_constraints is not None:
                    batch_constraints = batch_constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }
            translations = self.task.inference_step(
                self.generator, self.models, sample, constraints=batch_constraints
            )

            if constrained_decoding:
                list_constraints = [unpack_constraints(c) for c in batch_constraints]
            else:
                list_constraints = [[] for _ in range(bsz)]

            for i, (sample_id, hypos) in enumerate(
                zip(batch.ids.tolist(), translations)
            ):
                # NOTE(review): padding is stripped using the target
                # dictionary's pad index, as in the original code; presumably
                # source and target dictionaries share it -- confirm.
                src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
                results.append(
                    (
                        sample_id,
                        src_tokens_i,
                        hypos,
                        {"constraints": list_constraints[i]},
                    )
                )

        # Same for every hypothesis, so compute it once.
        symbols_to_strip = get_symbols_to_strip_from_output(self.generator)

        final_translations = []
        # Sort output by sample id to match input order.
        for _, sample_src_tokens, hypos, _ in sorted(results, key=lambda r: r[0]):
            src_str = ""
            if self.src_dict is not None:
                src_str = self.src_dict.string(
                    sample_src_tokens, self.cfg.common_eval.post_process
                )

            # Process top predictions
            for hypo in hypos[: min(len(hypos), self.cfg.generation.nbest)]:
                _, hypo_str, _ = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=self.align_dict,
                    tgt_dict=self.tgt_dict,
                    remove_bpe="subword_nmt",
                    extra_symbols_to_ignore=symbols_to_strip,
                )
                final_translations.append(self.decode_fn(hypo_str))

        return final_translations