jiangjiechen's picture
init loren for spaces
7f7285f
raw
history blame
2.62 kB
# -*- coding: utf-8 -*-
'''
@Author : Jiangjie Chen
@Time : 2020/7/20 17:54
@Contact : [email protected]
@Description:
'''
import os
import ujson as json
import tensorflow as tf
import argparse
import random
from transformers import BartTokenizer
try:
from .hparams import CACHED_QUESTION_FILE, QG_PREFIX
except:
from hparams import CACHED_QUESTION_FILE, QG_PREFIX
random.seed(1111)
def pproc_seq2seq(input_file, output_dir, role):
'''
:param input_file:
{
'id': id,
'claim': c,
'label': x,
'evidence': [e1, e2, ...], # n
'answers': [a1, a2, ...], # m
'questions': [q1, q2, ...], # m
'cloze_qs': [q1, q2, ...], #m
'regular_qs': [q1, q2, ...], #m
'answer_roles': [noun, noun, adj, verb, ...] # m
}
'''
assert role in ['val', 'test', 'train'], role
use_rag = 'v6' in input_file
if not use_rag:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
tf.io.gfile.makedirs(output_dir)
src_fname = os.path.join(output_dir, f'{role}.source')
tgt_fname = os.path.join(output_dir, f'{role}.target')
with tf.io.gfile.GFile(input_file) as fin, \
tf.io.gfile.GFile(src_fname, 'w') as srcf, \
tf.io.gfile.GFile(tgt_fname, 'w') as tgtf:
data = fin.readlines()
for line in data:
js = json.loads(line)
if js['label'] == 'SUPPORTS':
evidence = ' '.join(js['evidence'])
questions = js['questions']
i = random.randint(0, len(questions) - 1)
if use_rag:
srcf.write(f'{questions[i]}\n')
else:
srcf.write(f'{questions[i]} {tokenizer.sep_token} {evidence}\n')
tgtf.write(js['answers'][i][0] + '\n')
return src_fname, tgt_fname
def pproc_for_mrc(output_dir, version):
assert version in ['v5']
for role in ['val', 'train', 'test']:
_role = 'val' if role == 'test' else role
input_file = os.path.join(QG_PREFIX.format(version=version),
CACHED_QUESTION_FILE.format(role=_role))
pproc_seq2seq(input_file, output_dir, role)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--output_dir', '-o', required=True, default='data/mrc_seq2seq_v5',
help='data path, e.g. data/mrc_seq2seq_v5')
parser.add_argument('--version', '-v', type=str, default='v5')
args = parser.parse_args()
pproc_for_mrc(args.output_dir, args.version)