# -*- coding: utf-8 -*-
'''
@Author : Jiangjie Chen
@Time : 2020/8/12 14:44
@Contact : [email protected]
@Description: Seq2seq-based answer generation (MRC) for LOREN: fills cloze
              questions with top-k candidate answers from an evidence reader.
'''
import random
from pathlib import Path
import torch
from logging import getLogger
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import ujson as json
try:
    from .seq2seq.seq2seq_utils import (
        use_task_specific_params,
        chunks,
    )
except ImportError:
    import cjjpy as cjj
    import sys
    sys.path.append(cjj.AbsParentDir(__file__, '.'))
    from seq2seq.seq2seq_utils import (
        use_task_specific_params,
        chunks,
    )
logger = getLogger(__name__)
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
random.seed(1111)


def assemble_answers_to_one(js, k=5, mask_token='<mask>', mask_rate=0.):
    if isinstance(js, str):
        js = json.loads(js)
    # With probability `mask_rate`, keep the mask instead of the candidate
    # answers (a simple answer-dropout augmentation); drawn once per record.
    should_keep = random.random() > mask_rate
    js.pop('evidential_assembled', None)  # default avoids a KeyError if absent
    for q, answers in zip(js['cloze_qs'], js['evidential']):
        if mask_token in q:
            s = q.find(mask_token)
            e = s + len(mask_token)
            # Fill the cloze slot with each of the top-k candidate answers.
            nq_list = []
            for i in range(k):
                answer_span = answers[i] if should_keep else mask_token
                nq_list.append(q[:s] + answer_span + q[e:])
            js.setdefault('evidential_assembled', []).append(' '.join(nq_list))
    assert len(js['evidential_assembled']) == len(js['answers'])
    return js
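

# A minimal, hypothetical example of the record schema this function expects,
# inferred from the field names above (values are illustrative only):
#   js = {'cloze_qs': ['Shakespeare wrote <mask>.'],
#         'evidential': [['Hamlet', 'Macbeth', ...]],  # top-k candidate answers
#         'answers': ['Hamlet']}                       # one entry per cloze question
#   assemble_answers_to_one(js, k=2)['evidential_assembled']
#   == ['Shakespeare wrote Hamlet. Shakespeare wrote Macbeth.']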


class AnswerGenerator:
    def __init__(self, model_name, device=DEFAULT_DEVICE):
        self.model_name = str(model_name)
        self.device = device
        self.model = None  # loaded lazily in init_model()
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def init_model(self):
        if self.model is None:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(self.device)

    def assemble(self, question, context):
        # unifiedqa checkpoints expect question and context separated by a
        # newline; other seq2seq tokenizers use their own sep token.
        sep = '\n' if 'unifiedqa' in self.tokenizer.name_or_path else self.tokenizer.sep_token
        return f'{question} {sep} {context}'
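
    # For illustration (assuming a unifiedqa tokenizer):
    #   assemble('Who wrote Hamlet?', 'Hamlet is a play.')
    #   -> 'Who wrote Hamlet? \n Hamlet is a play.'
    # BART-style tokenizers would use their sep token (e.g. '</s>') instead.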

    def generate(self, examples, out_file=None, batch_size=16, verbose=True,
                 max_length=20, min_length=1, num_beams=4, num_return_sequences=4,
                 prefix=None, fp16=False, task='summarization', **generate_kwargs):
        '''
        :param examples: list of N assembled inputs
        :return: nested list of shape [N x num_return_sequences],
                 or [] if results are written to out_file instead
        '''
        self.init_model()
        if fp16:
            self.model = self.model.half()
        # update config with task-specific params (e.g. beam size, length penalty)
        use_task_specific_params(self.model, task)
        fout = None if out_file is None else Path(out_file).open("w", encoding="utf-8")
        generated = []
        batches = list(chunks(examples, batch_size))
        if verbose:
            batches = tqdm(batches, desc="MRC")
        if prefix is None:
            prefix = getattr(self.model.config, "prefix", "") or ""
        for examples_chunk in batches:
            examples_chunk = [prefix + text for text in examples_chunk]
            batch = self.tokenizer(examples_chunk, return_tensors="pt", truncation=True,
                                   padding="longest").to(self.device)
            summaries = self.model.generate(
                input_ids=batch.input_ids,
                attention_mask=batch.attention_mask,
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams,
                num_return_sequences=num_return_sequences,
                length_penalty=1.2,
                repetition_penalty=1.2,
                **generate_kwargs,
            )
            dec = self.tokenizer.batch_decode(summaries, skip_special_tokens=True,
                                              clean_up_tokenization_spaces=False)
            if fout is not None:
                for hypothesis in dec:
                    fout.write(hypothesis.strip() + "\n")
                fout.flush()
            else:
                generated += dec
        if fout is not None:
            fout.close()
        # group the flat decode list back into num_return_sequences per input
        generated = [x.strip() for x in generated]
        generated = list(chunks(generated, num_return_sequences))
        return generated
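

if __name__ == '__main__':
    # Minimal usage sketch. The checkpoint name is illustrative only (any
    # seq2seq QA model on the Hub with a compatible tokenizer should behave
    # the same); LOREN's actual checkpoint is configured by the caller.
    ag = AnswerGenerator('allenai/unifiedqa-t5-base')
    src = ag.assemble('Who wrote Hamlet?',
                      'Hamlet is a tragedy written by William Shakespeare.')
    hypos = ag.generate([src], num_beams=4, num_return_sequences=4)
    print(hypos)  # nested list of shape [1 x 4]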