|
|
|
|
|
|
|
|
|
|
|
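"""

Sample continuations from a trained language model (adapted from
fairseq-interactive).

Reads ``<seq_id>|<prompt>`` lines from ``--prompts`` and writes one
``<seq_id>__<hypo_idx>|<text>`` line per generated hypothesis to ``--output``.

"""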
|
from collections import namedtuple |
|
import os |
|
import ast |
|
import numpy as np |
|
|
|
from fairseq import checkpoint_utils, options, tasks, utils |
|
|
|
import tqdm |
|
|
|
# One inference batch as yielded by make_batches(): the dataset ids of the
# examples, their source token tensor and the per-example source lengths.
Batch = namedtuple('Batch', ['ids', 'src_tokens', 'src_lengths'])

# Per-example generation record (source string, hypotheses, positional
# scores, alignments).  NOTE(review): appears unused in this script.
Translation = namedtuple(
    'Translation', ['src_str', 'hypos', 'pos_scores', 'alignments'],
)
|
|
|
|
|
def make_batches(lines, args, task, max_positions):
    """Encode raw prompt strings and yield them as inference batches.

    Each string in ``lines`` is turned into a ``long`` tensor of token ids
    with the task's source dictionary (words missing from the dictionary are
    not added to it).  The tensors are wrapped in the task's inference
    dataset and grouped according to ``args.dataset.max_tokens`` and
    ``args.dataset.batch_size``; inputs exceeding ``max_positions`` are
    skipped or rejected depending on
    ``args.dataset.skip_invalid_size_inputs_valid_test``.

    Yields:
        Batch: ``ids`` is the batch's ``'id'`` entry (presumably each
        example's index into ``lines`` -- confirm against the task's
        dataset); ``src_tokens`` and ``src_lengths`` come from the batch's
        ``'net_input'``.
    """
    src_dict = task.source_dictionary
    encoded = []
    for text in lines:
        encoded.append(src_dict.encode_line(text, add_if_not_exist=False).long())
    sizes = [tensor.numel() for tensor in encoded]

    dataset_cfg = args.dataset
    epoch_itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(encoded, sizes),
        max_tokens=dataset_cfg.max_tokens,
        max_sentences=dataset_cfg.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=dataset_cfg.skip_invalid_size_inputs_valid_test,
    )

    # shuffle=False: iterate batches in the iterator's deterministic order.
    for sample in epoch_itr.next_epoch_itr(shuffle=False):
        net_input = sample['net_input']
        yield Batch(
            ids=sample['id'],
            src_tokens=net_input['src_tokens'],
            src_lengths=net_input['src_lengths'],
        )
|
|
|
|
|
def _read_prompts(path):
    """Parse a prompt file into parallel lists of sequence ids and prompts.

    Each non-blank line must have the form ``<seq_id>|<prompt text>``.  Only
    the first ``|`` separates the two parts, so the prompt may itself contain
    ``|``.  Blank lines are skipped.

    Returns:
        tuple[list[str], list[str]]: the sequence ids and the prompt strings,
        in file order.  Prompt strings keep their trailing newline, exactly
        as read from the file.

    Raises:
        ValueError: if a non-blank line has no ``|`` separator.
    """
    seq_ids, prompts = [], []
    with open(path, 'r') as fin:
        for line_no, line in enumerate(fin, start=1):
            if not line.strip():
                continue
            seq, sep, prompt = line.partition('|')
            if not sep:
                raise ValueError(
                    f'{path}:{line_no}: expected "<seq_id>|<prompt>", '
                    f'got {line!r}'
                )
            seq_ids.append(seq)
            prompts.append(prompt)
    return seq_ids, prompts


def main(args):
    """Sample continuations for every prompt in ``args.prompts``.

    ``args`` is the argparse namespace produced by ``cli_main``.  The
    script-specific options (prompts, output, debug, samples_per_prompt) are
    read first; the namespace is then converted to fairseq's structured
    config (``args.common``, ``args.generation``, ...), which the rest of the
    function uses.

    Every hypothesis produced for a prompt is written to ``args.output`` as
    one ``<seq_id>__<hypo_idx>|<hypothesis>`` line.  Models are moved to
    CUDA, so a GPU is required.
    """
    arg_prompts = args.prompts
    arg_output = args.output
    arg_debug = args.debug
    arg_sample_size = args.samples_per_prompt

    # Only a missing module is tolerated; errors raised by the conversion
    # itself must surface instead of being swallowed.
    try:
        from fairseq.dataclass.utils import convert_namespace_to_omegaconf
    except ImportError:
        # NOTE(review): best-effort for fairseq versions without
        # fairseq.dataclass; the code below still expects the structured
        # config, so confirm such versions are actually supported.
        pass
    else:
        args = convert_namespace_to_omegaconf(args)

    if args.common.seed is not None:
        np.random.seed(args.common.seed)
        utils.set_torch_seed(args.common.seed)

    # When sampling, keep samples_per_prompt hypotheses per prompt.
    if args.generation.sampling:
        args.generation.nbest = args.generation.beam = arg_sample_size

    task = tasks.setup_task(args.task)

    overrides = ast.literal_eval(args.common_eval.model_overrides)

    # The checkpoint suffix lives in the "checkpoint" section of the
    # structured config; fall back to "" when it is absent.
    checkpoint_suffix = getattr(
        getattr(args, "checkpoint", None), "checkpoint_suffix", ""
    )
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.common_eval.path.split(os.pathsep),
        arg_overrides=overrides,
        task=task,
        suffix=checkpoint_suffix,
    )

    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    for model in models:
        model.prepare_for_inference_(args)
        model.cuda()

    align_dict = utils.load_align_dict(args.generation.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    seq_ids, prompts = _read_prompts(arg_prompts)

    # A non-negative prefix size keeps only the first `prefix_size`
    # whitespace-separated tokens of each prompt (0 yields empty prompts);
    # a negative value keeps prompts whole.
    prefix_size = args.generation.prefix_size
    if prefix_size >= 0:
        prompts = [' '.join(p.split()[:prefix_size]) for p in prompts]

    if arg_debug:
        prompts = prompts[:10]

    generator = task.build_generator(models, args.generation)

    with open(arg_output, 'w') as output_file:
        with tqdm.tqdm(total=len(prompts)) as pbar:
            for batch in make_batches(prompts, args, task, max_positions):
                src_tokens = batch.src_tokens.cuda()
                src_lengths = batch.src_lengths.cuda()

                sample = {
                    'net_input': {
                        'src_tokens': src_tokens,
                        'src_lengths': src_lengths,
                    },
                }

                translations = task.inference_step(generator, models, sample)

                # Key every result by the dataset id from the batch rather
                # than by its position in the batch stream: make_batches is
                # called once over all prompts, so the id indexes
                # `prompts`/`seq_ids` even if the iterator reorders or skips
                # examples.
                results = []
                for i, (sample_id, hypos) in enumerate(
                        zip(batch.ids.tolist(), translations)):
                    src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                    results.append((sample_id, src_tokens_i, hypos))

                for sample_id, src_tokens_i, hypos in sorted(
                        results, key=lambda r: r[0]):
                    # Without a source dictionary there is no source string.
                    src_str = ''
                    if src_dict is not None:
                        src_str = src_dict.string(
                            src_tokens_i, args.common_eval.post_process)

                    for hypo_id, hypo in enumerate(hypos):
                        _hypo_tokens, hypo_str, _alignment = utils.post_process_prediction(
                            hypo_tokens=hypo['tokens'].int().cpu(),
                            src_str=src_str,
                            alignment=hypo['alignment'],
                            align_dict=align_dict,
                            tgt_dict=tgt_dict,
                            remove_bpe=args.common_eval.post_process,
                        )
                        print(f'{seq_ids[sample_id]}__{hypo_id}|{hypo_str}',
                              file=output_file)

                    # One tick per prompt, matching total=len(prompts).
                    pbar.update(1)
|
|
|
|
|
|
|
|
|
def cli_main():
    """Command-line entry point.

    Builds fairseq's interactive-generation parser, adds this script's own
    flags, and parses the command line:

    * ``--prompts`` (required): file of ``<seq_id>|<prompt>`` lines.
    * ``--output`` (required): file receiving the generated hypotheses.
    * ``--debug``: only process the first 10 prompts.
    * ``--samples-per-prompt`` (default 1): hypotheses per prompt when
      sampling.

    It then seeds NumPy and torch from ``--seed`` and hands the parsed
    arguments to ``main``.
    """
    parser = options.get_interactive_generation_parser()
    script_flags = (
        ('--prompts', {'type': str, 'default': None, 'required': True}),
        ('--output', {'type': str, 'default': None, 'required': True}),
        ('--debug', {'action': 'store_true'}),
        ('--samples-per-prompt', {'type': int, 'default': 1}),
    )
    for flag, flag_kwargs in script_flags:
        parser.add_argument(flag, **flag_kwargs)

    parsed = options.parse_args_and_arch(parser)

    seed = parsed.seed
    np.random.seed(seed)
    utils.set_torch_seed(seed)

    main(parsed)
|
|
|
|
|
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == '__main__':
    cli_main()
|
|