# Source: Hugging Face Spaces file page residue (author jiangjiechen,
# commit 7f7285f, "init loren for spaces"; page chrome: raw/history/blame, 1.86 kB).
# Commented out so the module parses as valid Python.
# -*- coding:utf-8 -*-
"""
@Author : Bao
@Date : 2021/9/7
@Desc :
@Last modified by : Bao
@Last modified date : 2021/9/7
"""
import json
import numpy as np
import argparse
from collections import defaultdict
from sklearn.metrics import precision_recall_fscore_support
# Collapse the model's 3-way class indices to binary labels for scoring:
# index 0 (refuted, per the note below) maps to positive label 1;
# indices 1 and 2 (NEI and supported) map to 0.
# ref --> label 1, nei & sup --> label 0
idx2label = {0: 1, 1: 0, 2: 0}
def read_json_lines(filename, mode='r', encoding='utf-8', skip=0):
    """Lazily yield one decoded JSON object per line of a JSON-lines file.

    Args:
        filename: path to the .jsonl file.
        mode: open mode passed through to ``open`` (default ``'r'``).
        encoding: text encoding passed through to ``open`` (default ``'utf-8'``).
        skip: number of leading lines to discard before parsing begins.

    Yields:
        The result of ``json.loads`` for each remaining line.
    """
    remaining = skip
    with open(filename, mode, encoding=encoding) as handle:
        for raw_line in handle:
            if remaining > 0:
                remaining -= 1
            else:
                yield json.loads(raw_line)
def process(filein, ref_file='eval.human.ref.merged.json'):
    """Score predicted culprit phrases against human reference annotations.

    Args:
        filein: path to a jsonl file of predictions; each line must carry
            'id' and 'z_prob' (per-phrase class probabilities).
        ref_file: path to the human-annotated reference jsonl; each line must
            carry 'id', 'questions', and 'culprit' (indices of culprit phrases).
            Defaults to the original hard-coded filename for compatibility.

    Returns:
        (precision, recall, f1): per-example binary scores averaged over all
        ids that appear in both files.

    Raises:
        ValueError: if no prediction id matches a reference id.
    """
    id2info = defaultdict(dict)

    # Build gold binary labels: 1 at each annotated culprit index, else 0.
    for entry in read_json_lines(ref_file):
        labels = [0] * len(entry['questions'])
        for cul in entry['culprit']:
            labels[cul] = 1
        id2info[entry['id']].update({'id': entry['id'], 'labels': labels})

    # Attach predictions, keeping only ids present in the reference set.
    for entry in read_json_lines(filein):
        if entry['id'] not in id2info:
            continue
        # argmax over the 3-way class probabilities, collapsed to binary.
        predicted = [idx2label[_] for _ in np.argmax(entry['z_prob'], axis=-1)]
        id2info[entry['id']]['predicted'] = predicted

    ps, rs, fs = [], [], []
    for info in id2info.values():
        # Reference ids with no matching prediction would otherwise raise
        # KeyError on info['predicted']; skip them instead.
        if 'predicted' not in info:
            continue
        p, r, f, _ = precision_recall_fscore_support(
            info['labels'], info['predicted'], average='binary')
        ps.append(p)
        rs.append(r)
        fs.append(f)

    if not ps:
        # Original code would hit ZeroDivisionError here; fail with a clear message.
        raise ValueError('no overlapping ids between {} and {}'.format(filein, ref_file))

    # Macro-average over examples (computed once, not twice as before).
    avg_p = sum(ps) / len(ps)
    avg_r = sum(rs) / len(rs)
    avg_f = sum(fs) / len(fs)
    print(filein)
    print('Precision: {}'.format(avg_p))
    print('Recall: {}'.format(avg_r))
    print('F1: {}'.format(avg_f))
    return avg_p, avg_r, avg_f
if __name__ == '__main__':
    # Command-line entry point: score the given prediction file.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-i', type=str, help='predicted jsonl file with phrasal veracity predictions.')
    cli_args = arg_parser.parse_args()
    process(cli_args.i)