Spaces:
Build error
Build error
# -*- coding: utf-8 -*- | |
""" | |
@Author : Bao | |
@Date : 2020/8/12 | |
@Desc : | |
@Last modified by : Bao | |
@Last modified date : 2020/8/12 | |
""" | |
import logging | |
from numpy.core.fromnumeric import argmax | |
import ujson as json | |
import torch | |
from plm_checkers.checker_utils import soft_logic | |
def init_logger(level, filename=None, mode='a', encoding='utf-8'):
    """Configure root logging to stderr and, optionally, to a log file.

    Args:
        level: logging level handed to ``logging.basicConfig``.
        filename: path of an additional log file; a falsy value means
            "console only".
        mode: open mode for the log file.
        encoding: text encoding for the log file.

    As with any ``logging.basicConfig`` call, this has no effect on a root
    logger that already has handlers (the file handler is still created).
    """
    handlers = [logging.StreamHandler()]
    if filename:
        handlers.append(logging.FileHandler(filename, mode, encoding))
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s:\t%(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=level,
        handlers=handlers,
    )
def read_json(filename, mode='r', encoding='utf-8'):
    """Parse the file at ``filename`` as a single JSON document and return it."""
    with open(filename, mode, encoding=encoding) as handle:
        payload = json.load(handle)
    return payload
def save_json(data, filename, mode='w', encoding='utf-8'):
    """Serialize ``data`` to ``filename`` as pretty-printed (4-space) JSON.

    Non-ASCII characters are written as-is rather than ``\\u`` escaped.
    """
    dump_options = {'ensure_ascii': False, 'indent': 4}
    with open(filename, mode, encoding=encoding) as handle:
        json.dump(data, handle, **dump_options)
def read_json_lines(filename, mode='r', encoding='utf-8', skip=0):
    """Lazily yield one parsed JSON object per line of ``filename``.

    The first ``skip`` lines are discarded without being parsed; a
    non-positive ``skip`` discards nothing. The file is opened on the first
    ``next()`` and closed when the generator finishes or is closed.
    """
    with open(filename, mode, encoding=encoding) as handle:
        for line_no, raw_line in enumerate(handle):
            if line_no < skip:
                continue
            yield json.loads(raw_line)
def save_json_lines(data, filename, mode='w', encoding='utf-8', skip=0):
    """Write every item of ``data`` to ``filename`` as one JSON line each.

    The first ``skip`` items are left out; non-ASCII characters are written
    unescaped. The file is created/truncated even if nothing is written.
    """
    with open(filename, mode, encoding=encoding) as handle:
        for position, record in enumerate(data):
            if position < skip:
                continue
            handle.write(json.dumps(record, ensure_ascii=False) + '\n')
def read_json_dict(filename, mode='r', encoding='utf-8'):
    """Load a JSON key->id mapping and also build its inverse.

    Returns:
        ``(key_2_id, id_2_key)``. If several keys share an id, the key that
        appears last in the file wins in ``id_2_key``.
    """
    with open(filename, mode, encoding=encoding) as handle:
        key_2_id = json.load(handle)
    id_2_key = {idx: key for key, idx in key_2_id.items()}
    return key_2_id, id_2_key
def save_json_dict(data, filename, mode='w', encoding='utf-8'):
    """Persist a mapping (e.g. a key->id dict) as 4-space indented JSON.

    Non-ASCII characters are kept literal instead of being escaped.
    """
    with open(filename, mode, encoding=encoding) as out_file:
        json.dump(data, out_file, indent=4, ensure_ascii=False)
# Precision, recall and F1 following the GERBIL conventions:
# https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure
def get_prf(res):
    """Return ``(precision, recall, f1)`` for a dict of ``'TP'``, ``'FP'``, ``'FN'`` counts.

    When there are no true positives the three scores are all 1.0 if there
    were also no false positives and no false negatives, and all 0.0
    otherwise.
    """
    true_pos = res['TP']
    if true_pos != 0:
        precision = 1.0 * true_pos / (true_pos + res['FP'])
        recall = 1.0 * true_pos / (true_pos + res['FN'])
        f1 = 2 * precision * recall / (precision + recall)
        return precision, recall, f1
    # No true positives: a perfect score only if nothing was expected and
    # nothing was predicted.
    score = 1.0 if res['FP'] == 0 and res['FN'] == 0 else 0.0
    return score, score, score
def _hard_logic_label(evidence_labels):
    """Collapse per-evidence labels into one claim label with a hard rule.

    Returns 0 (REFUTES) if any evidence label is 0, otherwise 1 (NEI) if any
    is 1, otherwise 2 (SUPPORTS).
    """
    if 0 in evidence_labels:  # REFUTES
        return 0
    if 1 in evidence_labels:  # NEI
        return 1
    return 2  # SUPPORTS


def compute_metrics(truth, predicted, z_predicted, mask):
    """Score claim predictions and the claim labels implied by evidence predictions.

    Besides plain accuracy of ``predicted`` against ``truth``, each example's
    claim label is re-derived from its per-evidence predictions ``z`` in two
    ways:

    * soft: argmax of ``soft_logic(z, m)`` (from ``plm_checkers.checker_utils``);
    * hard: argmax of each of the first ``sum(m)`` rows of ``z``, combined
      by ``_hard_logic_label``.

    Args:
        truth: gold claim labels (0 = REFUTES, 1 = NEI, 2 = SUPPORTS).
        predicted: model claim predictions, same length as ``truth``.
        z_predicted: per-example nested lists of per-evidence class scores,
            presumably shaped (num_evidence, 3) -- TODO confirm with caller.
        mask: per-example evidence masks; ``sum(m)`` (truncated to an int)
            is the number of leading rows of ``z`` the hard rule looks at.

    Returns:
        ``(outputs, results)``. ``outputs`` holds one dict per example with
        keys ``'label'``, ``'prediction'`` and ``'pred_z'``. ``results`` maps
        ``'Accuracy'``, ``'z_Acc_hard'``, ``'z_Acc_soft'``,
        ``'Agreement_hard'`` and ``'Agreement_soft'`` to fractions of
        examples: ``z_Acc_*`` compares the z-derived label with ``truth``,
        ``Agreement_*`` compares it with ``predicted``. Every metric is 0.0
        when ``truth`` is empty.

    Raises:
        AssertionError: if the four input sequences differ in length.
    """
    n_examples = len(truth)
    # zip() below silently truncates, so every input must be checked, not
    # only ``predicted``.
    assert len(predicted) == n_examples, 'truth and predicted differ in length'
    assert len(z_predicted) == n_examples, 'truth and z_predicted differ in length'
    assert len(mask) == n_examples, 'truth and mask differ in length'

    outputs = []
    cnt = z_cnt_h = z_cnt_s = agree_h = agree_s = 0
    for x, y, z, m in zip(truth, predicted, z_predicted, mask):
        if x == y:
            cnt += 1

        # Soft rule: run a batch of one through soft_logic, take the top class.
        y_soft = soft_logic(torch.tensor([z]), torch.tensor([m]))[0].argmax(-1).item()
        if y_soft == x:
            z_cnt_s += 1
        if y_soft == y:
            agree_s += 1

        # Hard rule: per-evidence argmax over the valid (masked-in) rows only.
        # int() keeps the slice bound a plain integer even for float masks.
        n_valid = int(torch.tensor(m).sum().item())
        z_h = torch.tensor(z[:n_valid]).argmax(-1).tolist()  # one label per valid row
        y_hard = _hard_logic_label(z_h)
        if y_hard == x:
            z_cnt_h += 1
        if y_hard == y:
            agree_h += 1

        outputs.append({'label': x, 'prediction': y, 'pred_z': z})

    denom = max(n_examples, 1)  # empty input -> 0.0 instead of ZeroDivisionError
    results = {
        'Accuracy': cnt / denom,
        'z_Acc_hard': z_cnt_h / denom,
        'z_Acc_soft': z_cnt_s / denom,
        'Agreement_hard': agree_h / denom,
        'Agreement_soft': agree_s / denom,
    }
    return outputs, results