Spaces:
Build error
Build error
# -*- coding: utf-8 -*- | |
''' | |
@Author : Jiangjie Chen | |
@Time : 2020/7/20 17:54 | |
@Contact : [email protected] | |
@Description: | |
''' | |
import os | |
import ujson as json | |
import tensorflow as tf | |
import argparse | |
import random | |
from transformers import BartTokenizer | |
try: | |
from .hparams import CACHED_QUESTION_FILE, QG_PREFIX | |
except: | |
from hparams import CACHED_QUESTION_FILE, QG_PREFIX | |
random.seed(1111) | |
def pproc_seq2seq(input_file, output_dir, role): | |
''' | |
:param input_file: | |
{ | |
'id': id, | |
'claim': c, | |
'label': x, | |
'evidence': [e1, e2, ...], # n | |
'answers': [a1, a2, ...], # m | |
'questions': [q1, q2, ...], # m | |
'cloze_qs': [q1, q2, ...], #m | |
'regular_qs': [q1, q2, ...], #m | |
'answer_roles': [noun, noun, adj, verb, ...] # m | |
} | |
''' | |
assert role in ['val', 'test', 'train'], role | |
use_rag = 'v6' in input_file | |
if not use_rag: | |
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') | |
tf.io.gfile.makedirs(output_dir) | |
src_fname = os.path.join(output_dir, f'{role}.source') | |
tgt_fname = os.path.join(output_dir, f'{role}.target') | |
with tf.io.gfile.GFile(input_file) as fin, \ | |
tf.io.gfile.GFile(src_fname, 'w') as srcf, \ | |
tf.io.gfile.GFile(tgt_fname, 'w') as tgtf: | |
data = fin.readlines() | |
for line in data: | |
js = json.loads(line) | |
if js['label'] == 'SUPPORTS': | |
evidence = ' '.join(js['evidence']) | |
questions = js['questions'] | |
i = random.randint(0, len(questions) - 1) | |
if use_rag: | |
srcf.write(f'{questions[i]}\n') | |
else: | |
srcf.write(f'{questions[i]} {tokenizer.sep_token} {evidence}\n') | |
tgtf.write(js['answers'][i][0] + '\n') | |
return src_fname, tgt_fname | |
def pproc_for_mrc(output_dir, version): | |
assert version in ['v5'] | |
for role in ['val', 'train', 'test']: | |
_role = 'val' if role == 'test' else role | |
input_file = os.path.join(QG_PREFIX.format(version=version), | |
CACHED_QUESTION_FILE.format(role=_role)) | |
pproc_seq2seq(input_file, output_dir, role) | |
if __name__ == '__main__': | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--output_dir', '-o', required=True, default='data/mrc_seq2seq_v5', | |
help='data path, e.g. data/mrc_seq2seq_v5') | |
parser.add_argument('--version', '-v', type=str, default='v5') | |
args = parser.parse_args() | |
pproc_for_mrc(args.output_dir, args.version) | |