# -*- coding: utf-8 -*-

'''
@Author     : Jiangjie Chen
@Time       : 2020/7/20 17:54
@Contact    : [email protected]
@Description: Preprocess cached question/answer data into seq2seq source/target files for MRC answer generation.
'''

import os
import ujson as json
import tensorflow as tf
import argparse
import random
from transformers import BartTokenizer

try:
    from .hparams import CACHED_QUESTION_FILE, QG_PREFIX
except ImportError:
    from hparams import CACHED_QUESTION_FILE, QG_PREFIX


random.seed(1111)


def pproc_seq2seq(input_file, output_dir, role):
    '''
    Convert one cached-question JSONL file into {role}.source / {role}.target
    files for seq2seq answer generation.

    :param input_file: path to a JSONL file, one JSON object per line:
        {
            'id': id,
            'claim': c,
            'label': x,
            'evidence': [e1, e2, ...],      # n evidence sentences
            'answers': [a1, a2, ...],       # m answers
            'questions': [q1, q2, ...],     # m questions
            'cloze_qs': [q1, q2, ...],      # m cloze-style questions
            'regular_qs': [q1, q2, ...],    # m regular questions
            'answer_roles': [noun, noun, adj, verb, ...]  # m answer roles
        }
    :param output_dir: directory where {role}.source and {role}.target are written
    :param role: one of 'train', 'val' or 'test'
    :return: paths of the written source and target files
    '''
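    # Illustrative example (values are made up, not from the dataset): a SUPPORTS
    # line with questions = ['Who wrote Hamlet?'] and answers = [['Shakespeare']]
    # produces the source line "Who wrote Hamlet? </s> <evidence text>"
    # (BART's sep_token is '</s>') and the target line "Shakespeare".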
    assert role in ['val', 'test', 'train'], role

    # 'v6' data is prepared for RAG, which takes only the question as input;
    # other versions append the evidence after BART's separator token, so a
    # tokenizer is needed to obtain that token.
    use_rag = 'v6' in input_file
    if not use_rag:
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

    tf.io.gfile.makedirs(output_dir)
    src_fname = os.path.join(output_dir, f'{role}.source')
    tgt_fname = os.path.join(output_dir, f'{role}.target')

    with tf.io.gfile.GFile(input_file) as fin, \
            tf.io.gfile.GFile(src_fname, 'w') as srcf, \
            tf.io.gfile.GFile(tgt_fname, 'w') as tgtf:
        data = fin.readlines()
        for line in data:
            js = json.loads(line)
            # keep only SUPPORTS instances; one question/answer pair is sampled per claim
            if js['label'] == 'SUPPORTS':
                evidence = ' '.join(js['evidence'])
                questions = js['questions']
                i = random.randint(0, len(questions) - 1)
                if use_rag:
                    srcf.write(f'{questions[i]}\n')
                else:
                    srcf.write(f'{questions[i]} {tokenizer.sep_token} {evidence}\n')
                tgtf.write(js['answers'][i][0] + '\n')

    return src_fname, tgt_fname


def pproc_for_mrc(output_dir, version):
    assert version in ['v5']
    for role in ['val', 'train', 'test']:
        # the cached 'val' questions are reused when building the 'test' split
        _role = 'val' if role == 'test' else role
        input_file = os.path.join(QG_PREFIX.format(version=version),
                                  CACHED_QUESTION_FILE.format(role=_role))
        pproc_seq2seq(input_file, output_dir, role)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', '-o', required=True,
                        help='output data path, e.g. data/mrc_seq2seq_v5')
    parser.add_argument('--version', '-v', type=str, default='v5')
    args = parser.parse_args()
    pproc_for_mrc(args.output_dir, args.version)
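
# Example invocation (the script filename below is illustrative, not from the repo):
#   python pproc_mrc.py --output_dir data/mrc_seq2seq_v5 --version v5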