Spaces:
Build error
Build error
# -*- coding: utf-8 -*- | |
''' | |
@Author : Jiangjie Chen | |
@Time : 2020/7/25 18:23 | |
@Contact : [email protected] | |
@Description: | |
''' | |
import os | |
import cjjpy as cjj | |
import sys | |
import tensorflow as tf | |
import ujson as json | |
from tqdm import tqdm | |
import argparse | |
try: | |
sys.path.append(cjj.AbsParentDir(__file__, '..')) | |
from hparams import * | |
from pseudo_multiproc_toolkit import * | |
from dataloaders import FEVERLoader | |
from parsing_client.sentence_parser import SentenceParser, deal_bracket | |
from qg_client.question_generator import QuestionGenerator | |
except: | |
from .hparams import * | |
from .pseudo_multiproc_toolkit import * | |
from ..dataloaders import FEVERLoader | |
from ..parsing_client.sentence_parser import SentenceParser, deal_bracket | |
from ..qg_client.question_generator import QuestionGenerator | |
def prepare_answers(version, role, evi_key='bert_evidence', overwrite=False): | |
''' | |
:return | |
{ | |
'id': id, | |
'claim': c, | |
'label': x, | |
'evidence': [e1, e2, ...], # n | |
'answers': [a1, a2, ...], # m | |
'answer_roles': [noun, noun, adj, verb, ...] # m | |
} | |
''' | |
assert role in ['val', 'test', 'train', 'eval'], role | |
def _proc_one(js): | |
js.pop('all_evidence') | |
evidence = [deal_bracket(ev[2], True, ev[0]) for ev in js[evi_key]] | |
results = sent_client.identify_NPs(deal_bracket(js['claim'], True), | |
candidate_NPs=[x[0] for x in js[evi_key]]) | |
NPs = results['NPs'] | |
claim = results['text'] | |
verbs = results['verbs'] | |
adjs = results['adjs'] | |
_cache = {'id': js['id'], | |
'claim': claim, | |
'evidence': evidence, | |
'answers': NPs + verbs + adjs, | |
'answer_roles': ['noun'] * len(NPs) + ['verb'] * len(verbs) + ['adj'] * len(adjs)} | |
if js.get('label'): | |
_cache.update({'label': js['label']}) | |
return _cache | |
cached_ = QG_PREFIX.format(version=version) + CACHED_ANSEWR_FILE.format(role=role) | |
tf.io.gfile.makedirs(QG_PREFIX.format(version=version)) | |
if tf.io.gfile.exists(cached_) and not overwrite: | |
print(f'* Skipped, exising {cached_}') | |
return cached_ | |
sent_client = SentenceParser(device='cuda:0') | |
floader = FEVERLoader(role) | |
floader.load_fever(evi_key.split('_')[0]) | |
with tf.io.gfile.GFile(cached_, 'w') as f: | |
for id in tqdm(floader, desc=f'{role} answer'): | |
res = _proc_one(floader[id]) | |
f.write(json.dumps(res) + '\n') | |
cjj.lark(f'* NPs baked in {cached_}') | |
return cached_ | |
def prepare_questions(version, role, qg_model='t5', batch_size=64, overwrite=False): | |
''' | |
After prepare_nps | |
:return | |
{ | |
'id': id, | |
'claim': c, | |
'label': x, | |
'evidence': [e1, e2, ...], # n | |
'answers': [a1, a2, ...], # m | |
'questions': [q1, q2, ...], # m | |
'cloze_qs': [q1, q2, ...], #m | |
'regular_qs': [q1, q2, ...], #m | |
'answer_roles': [noun, noun, adj, verb, ...] # m | |
} | |
''' | |
cached_answer = QG_PREFIX.format(version=version) + CACHED_ANSEWR_FILE.format(role=role) | |
cached_question = QG_PREFIX.format(version=version) + CACHED_QUESTION_FILE.format(role=role) | |
if tf.io.gfile.exists(cached_question) and not overwrite: | |
print(f'* Skipped, existing {cached_question}') | |
return cached_question | |
qg_client = QuestionGenerator(qg_model) | |
with tf.io.gfile.GFile(cached_answer, 'r') as f, \ | |
tf.io.gfile.GFile(cached_question, 'w') as fo: | |
data = f.read().splitlines() | |
data_dict = {} | |
_cache = [] | |
for line in data: | |
js = json.loads(line) | |
data_dict[js['id']] = js | |
if len(js['answers']) == 0: | |
# TODO: hack empty answer | |
print('Empty answer:', js) | |
pseudo_answer = js['claim'].split()[0] | |
js['answers'] = [(pseudo_answer, 0, len(pseudo_answer))] | |
js['answer_roles'] = ['noun'] | |
for answer in js['answers']: | |
_cache.append((js['claim'], [answer], js['id'])) | |
print(_cache[:5]) | |
qa_pairs = qg_client.generate([(x, y) for x, y, z in _cache], batch_size=batch_size) | |
print(qa_pairs[:5]) | |
for (q, clz_q, a), (_, _, id) in zip(qa_pairs, _cache): | |
if 'questions' in data_dict[id]: | |
data_dict[id]['cloze_qs'].append(clz_q) | |
data_dict[id]['regular_qs'].append(q) | |
data_dict[id]['questions'].append(qg_client.assemble_question(q, clz_q)) | |
else: | |
data_dict[id]['cloze_qs'] = [clz_q] | |
data_dict[id]['regular_qs'] = [q] | |
data_dict[id]['questions'] = [qg_client.assemble_question(q, clz_q)] | |
_ = [_sanity_check(data_dict[k]) for k in data_dict] | |
for k in data_dict: | |
fo.write(json.dumps(data_dict[k]) + '\n') | |
cjj.lark(f'* Questions baked in {cached_question}') | |
return cached_question | |
def _sanity_check(js): | |
try: | |
assert len(js['questions']) == len(js['answers']) | |
assert len(js['answers']) == len(js['answer_roles']) | |
except: | |
print(js) | |
raise Exception | |
if __name__ == '__main__': | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--overwrite', action='store_true') | |
parser.add_argument('--batch_size', '-b', type=int, default=64) | |
parser.add_argument('--evi_key', '-e', type=str, default='bert_evidence') | |
parser.add_argument('--version', '-v', type=str, help='v1, v2, ...', default='v5') | |
parser.add_argument('--roles', nargs='+', required=True, | |
help='train val test eval') | |
parser.add_argument('--qg_model', '-m', type=str, default='t5') | |
args = parser.parse_args() | |
for role in args.roles: | |
prepare_answers(args.version, role, args.evi_key, args.overwrite) | |
prepare_questions(args.version, role, args.qg_model, args.batch_size, args.overwrite) | |