File size: 1,860 Bytes
7f7285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# -*- coding:utf-8  -*-

"""
@Author             : Bao
@Date               : 2021/9/7
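@Desc               : Evaluate phrasal veracity predictions against human-annotated culprit
                      phrases; reports precision / recall / F1 computed per claim and averaged.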
@Last modified by   : Bao
@Last modified date : 2021/9/7
"""

import argparse
import json
import sys
from collections import defaultdict

import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# Maps the argmax index of each prediction row to a binary culprit label.
# REF is the positive class (culprit = 1); NEI and SUP are both negative (0).
idx2label = {
    0: 1,  # REF -> culprit
    1: 0,  # NEI or SUP -> not culprit (both map to 0, so their order does not matter here)
    2: 0,  # NEI or SUP -> not culprit
}


def read_json_lines(filename, mode='r', encoding='utf-8', skip=0):
    """Yield one decoded JSON object per non-blank line of a JSON-lines file.

    Args:
        filename: path of the file to read.
        mode: mode passed to ``open``. It must be a text mode, because
            ``encoding`` is always passed as well.
        encoding: text encoding of the file.
        skip: number of leading raw lines (blank or not) to discard before
            decoding. Values <= 0 skip nothing.

    Yields:
        The object decoded from each remaining non-blank line.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(filename, mode, encoding=encoding) as fin:
        for line_no, line in enumerate(fin):
            if line_no < skip:
                continue
            # Blank / whitespace-only lines (e.g. a trailing empty line) are not
            # records; json.loads would raise on them.
            if not line.strip():
                continue
            yield json.loads(line)


def process(filein, gold_file='eval.human.ref.merged.json'):
    """Score phrasal veracity predictions against human culprit annotations.

    Each gold record provides an ``id``, a ``questions`` list and ``culprit``
    indices into that list. Every question becomes a binary label
    (1 = culprit, 0 = otherwise).

    Each prediction record provides an ``id`` and a ``z_prob`` array. The
    argmax over its last axis is mapped to a binary label through
    ``idx2label``.

    Precision, recall and F1 are computed per id with ``average='binary'``,
    then averaged with equal weight per id.

    Args:
        filein: path to the JSON-lines prediction file.
        gold_file: path to the JSON-lines gold annotation file. The default
            is the file name this function always used, resolved relative to
            the current working directory.

    Returns:
        Tuple ``(precision, recall, f1)`` of the averaged scores.

    Raises:
        ValueError: if no gold id has a matching prediction.
    """
    id2info = defaultdict(dict)
    for line in read_json_lines(gold_file):
        labels = [0] * len(line['questions'])
        for cul in line['culprit']:
            labels[cul] = 1
        id2info[line['id']].update({'id': line['id'], 'labels': labels})

    for line in read_json_lines(filein):
        if line['id'] not in id2info:
            continue  # prediction for an example that has no gold annotation
        predicted = [idx2label[idx] for idx in np.argmax(line['z_prob'], axis=-1)]
        id2info[line['id']]['predicted'] = predicted

    # Gold examples without a prediction previously crashed with
    # KeyError('predicted'); score only the covered ones and report the rest.
    evaluated = [info for info in id2info.values() if 'predicted' in info]
    missing = len(id2info) - len(evaluated)
    if missing:
        print('Warning: {} gold example(s) have no prediction in {} and were skipped.'.format(
            missing, filein), file=sys.stderr)
    if not evaluated:
        raise ValueError('No predictions in {} match any gold id in {}.'.format(filein, gold_file))

    ps, rs, fs = [], [], []
    for info in evaluated:
        # With sklearn's default zero_division handling, ill-defined scores
        # (e.g. no positive predictions) count as 0 and emit a warning.
        p, r, f, _ = precision_recall_fscore_support(info['labels'], info['predicted'], average='binary')
        ps.append(p)
        rs.append(r)
        fs.append(f)

    precision = sum(ps) / len(ps)
    recall = sum(rs) / len(rs)
    f1 = sum(fs) / len(fs)

    print(filein)
    print('Precision: {}'.format(precision))
    print('Recall: {}'.format(recall))
    print('F1: {}'.format(f1))

    return precision, recall, f1


if __name__ == '__main__':
    # Command-line entry point: evaluate one prediction file against the gold annotations.
    parser = argparse.ArgumentParser(
        description='Evaluate phrasal veracity predictions against human culprit annotations.')
    # Required: without it process() would fail later trying to open None.
    parser.add_argument('-i', type=str, required=True,
                        help='predicted jsonl file with phrasal veracity predictions.')
    args = parser.parse_args()
    process(args.i)