Spaces:
Runtime error
Runtime error
""" Official evaluation script for CUAD dataset. """ | |
import argparse | |
import json | |
import re | |
import string | |
import sys | |
import numpy as np | |
# Minimum token-level Jaccard overlap (as computed by get_jaccard) for a
# predicted span to count as matching a ground-truth span in
# compute_precision_recall.
IOU_THRESH = 0.5
def get_jaccard(prediction, ground_truth):
    """Return the token-level Jaccard similarity of two answer strings.

    Each string is normalised the same way before comparison:

    * the characters ``.``, ``,``, ``;`` and ``:`` are deleted,
    * the text is lowercased,
    * ``/`` is treated as a space,
    * the text is split on single spaces.

    Because of the last step, consecutive spaces yield an empty-string token.
    The result is ``|A & B| / |A | B|`` over the two token sets, a float in
    ``[0, 1]``. The union is never empty, because ``str.split(" ")`` always
    yields at least one token.
    """

    def _token_set(text):
        for mark in (".", ",", ";", ":"):
            text = text.replace(mark, "")
        return set(text.lower().replace("/", " ").split(" "))

    truth_tokens = _token_set(ground_truth)
    pred_tokens = _token_set(prediction)
    shared = pred_tokens & truth_tokens
    combined = pred_tokens | truth_tokens
    return len(shared) / len(combined)
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace.

    Steps run in this order:

    1. Lowercase the text.
    2. Drop every character in ``string.punctuation``.
    3. Replace the standalone words "a", "an" and "the" with a space.
    4. Collapse all whitespace runs to single spaces, trimming the ends.
    """
    punctuation = set(string.punctuation)
    text = s.lower()
    text = "".join(char for char in text if char not in punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())
def compute_precision_recall(predictions, ground_truths, qa_id, iou_thresh=None):
    """Compute precision and recall of the predicted spans for one question.

    A prediction matches a ground-truth span when their ``get_jaccard``
    overlap is at least ``iou_thresh``. For questions whose id contains
    "Parties", a prediction that contains the ground-truth text as a
    substring also matches.

    Counts (each (prediction, ground truth) pair is scored exactly once):
      * tp: ground-truth spans matched by at least one prediction. Several
        predictions matching the same span still count as a single tp.
      * fn: ground-truth spans matched by no prediction.
      * fp: predictions that match no ground-truth span.

    Args:
        predictions: list of predicted answer strings.
        ground_truths: list of non-empty ground-truth answer strings.
        qa_id: question id; only used to decide whether substring matches
            are accepted.
        iou_thresh: match threshold. ``None`` (the default) means the
            module-level ``IOU_THRESH``, looked up at call time.

    Returns:
        ``(precision, recall)``. Precision is ``np.nan`` when there are no
        predictions. Recall is ``np.nan`` when there are no ground truths.

    Raises:
        ValueError: if any ground-truth answer is empty.
    """
    if iou_thresh is None:
        iou_thresh = IOU_THRESH

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    for ground_truth in ground_truths:
        if len(ground_truth) == 0:
            raise ValueError("Empty ground-truth answer for question {!r}".format(qa_id))

    substr_ok = "Parties" in qa_id

    def is_match(pred, ground_truth):
        if get_jaccard(pred, ground_truth) >= iou_thresh:
            return True
        return substr_ok and ground_truth in pred

    # One pass over all pairs.
    # - gt_matched records which ground truths were hit (for tp/fn).
    # - A prediction that hits nothing is a false positive.
    # With no ground truths, every prediction is a false positive.
    gt_matched = [False] * len(ground_truths)
    fp = 0
    for pred in predictions:
        pred_matched = False
        for idx, ground_truth in enumerate(ground_truths):
            if is_match(pred, ground_truth):
                gt_matched[idx] = True
                pred_matched = True
        if not pred_matched:
            fp += 1

    tp = sum(gt_matched)
    fn = len(ground_truths) - tp

    precision = tp / (tp + fp) if tp + fp > 0 else np.nan
    recall = tp / (tp + fn) if tp + fn > 0 else np.nan
    return precision, recall
def process_precisions(precisions):
    """Make precision non-increasing along the recall axis.

    Each entry is replaced by the largest value found at its own position or
    any later position, i.e. a running maximum taken from the right. When
    ``precisions`` is ordered by increasing recall, this gives the usual
    interpolated precision: the best precision achievable at that recall or
    higher.

    A list argument is not modified, because slicing copies it. A new list
    is returned.
    """
    flipped = precisions[::-1]
    carry = None
    for pos, value in enumerate(flipped):
        # Argument order matches max(previous_best, current).
        carry = value if pos == 0 else max(carry, value)
        flipped[pos] = carry
    return flipped[::-1]
def get_aupr(precisions, recalls):
    """Return the area under the interpolated precision-recall curve.

    Precisions are first smoothed with ``process_precisions`` and then
    integrated over ``recalls`` with the trapezoidal rule. ``recalls`` is
    expected to be sorted in increasing order and aligned with
    ``precisions``. A single point integrates to 0.

    Returns 0 if the integral is NaN.
    """
    processed_precisions = process_precisions(precisions)
    # np.trapz was deprecated in NumPy 2.0 in favour of np.trapezoid. Use the
    # new name when present and fall back to np.trapz on older NumPy.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    aupr = trapezoid(processed_precisions, recalls)
    if np.isnan(aupr):
        return 0
    return aupr
def get_prec_at_recall(precisions, recalls, recall_thresh):
    """Return the interpolated precision at the first recall >= ``recall_thresh``.

    Precisions are smoothed with ``process_precisions`` first. ``recalls``
    must already be sorted in increasing order and aligned with
    ``precisions``. Returns 0 when no recall reaches the threshold.
    """
    smoothed = process_precisions(precisions)
    candidates = (
        prec for prec, recall in zip(smoothed, recalls) if recall >= recall_thresh
    )
    return next(candidates, 0)
def exact_match_score(prediction, ground_truth):
    """Return True when both strings are equal after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth
def metric_max_over_ground_truths(metric_fn, predictions, ground_truths):
    """Return the best ``metric_fn`` score over all (prediction, ground truth) pairs.

    Args:
        metric_fn: callable ``(prediction, ground_truth) -> score``.
        predictions: iterable of predicted answer strings.
        ground_truths: sequence of ground-truth answer strings. It is
            iterated once per prediction.

    Returns:
        The maximum score over all pairs. Returns 0 when there are no pairs,
        i.e. either input is empty.
    """
    best = None
    for pred in predictions:
        for ground_truth in ground_truths:
            score = metric_fn(pred, ground_truth)
            # Track the running maximum. The previous version returned the
            # *last* score, not the best one.
            if best is None or score > best:
                best = score
            # Short-circuit on a perfect score. This assumes scores never
            # exceed 1, which holds for exact_match_score (it returns a bool).
            if best == 1:
                return best
    return 0 if best is None else best
def compute_score(dataset, predictions):
    """Compute corpus-level CUAD metrics.

    Args:
        dataset: the list under the dataset JSON's ``"data"`` key.
            - Each article has ``"paragraphs"``.
            - Each paragraph has ``"qas"``.
            - Each qa has an ``"id"`` and ``"answers"``, each answer with a
              ``"text"``.
        predictions: mapping from question id to a list of predicted answer
            strings.

    Returns:
        A dict with these keys:

        - ``exact_match`` and ``f1``: percentages averaged over *all*
          questions. Questions missing from ``predictions`` score 0.
        - ``aupr``, ``prec_at_80_recall`` and ``prec_at_90_recall``: computed
          from the per-question (recall, precision) points where both values
          are defined.

    A question with no ground-truth answers and no predictions gets full
    credit (1) for both F1 and exact match. A question where exactly one side
    is empty gets 0. This follows the SQuAD 2.0 convention. Previously these
    cases produced NaN precision/recall, which turned the whole F1 into NaN.
    """
    f1 = exact_match = total = 0
    precisions = []
    recalls = []
    for article in dataset:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                total += 1
                if qa["id"] not in predictions:
                    message = "Unanswered question " + qa["id"] + " will receive score 0."
                    print(message, file=sys.stderr)
                    continue
                ground_truths = [answer["text"] for answer in qa["answers"]]
                prediction = predictions[qa["id"]]
                precision, recall = compute_precision_recall(prediction, ground_truths, qa["id"])
                precisions.append(precision)
                recalls.append(recall)

                if len(ground_truths) == 0 or len(prediction) == 0:
                    # Precision and/or recall is NaN here. Score the question
                    # directly instead of letting NaN poison the sums.
                    credit = int(len(ground_truths) == 0 and len(prediction) == 0)
                    f1 += credit
                    exact_match += credit
                    continue

                # Both sides non-empty, so precision and recall are defined.
                if precision + recall > 0:
                    f1 += 2 * (precision * recall) / (precision + recall)
                exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)

    # Build the PR curve from points where both coordinates are defined.
    # Sort the (recall, precision) pairs together so the lists stay aligned;
    # sorting the two lists independently misaligns them when NaNs are present.
    curve = sorted(
        (r, p) for r, p in zip(recalls, precisions) if not (np.isnan(r) or np.isnan(p))
    )
    recalls = [r for r, _ in curve]
    precisions = [p for _, p in curve]

    # Guard against an empty dataset instead of raising ZeroDivisionError.
    scale = 100.0 / total if total else 0.0
    f1 = scale * f1
    exact_match = scale * exact_match
    aupr = get_aupr(precisions, recalls)

    prec_at_90_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.9)
    prec_at_80_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.8)

    return {
        "exact_match": exact_match,
        "f1": f1,
        "aupr": aupr,
        "prec_at_80_recall": prec_at_80_recall,
        "prec_at_90_recall": prec_at_90_recall,
    }
if __name__ == "__main__":
    # Command-line entry point: evaluate a prediction file against a dataset file
    # and print the metrics as a JSON object on stdout.
    parser = argparse.ArgumentParser(description="Evaluation for CUAD")
    parser.add_argument("dataset_file", help="Dataset file")
    parser.add_argument("prediction_file", help="Prediction File")
    args = parser.parse_args()
    # JSON interchange text is UTF-8 (RFC 8259). Open the files explicitly as
    # UTF-8 rather than relying on the platform's locale encoding.
    with open(args.dataset_file, encoding="utf-8") as dataset_file:
        dataset_json = json.load(dataset_file)
    dataset = dataset_json["data"]
    with open(args.prediction_file, encoding="utf-8") as prediction_file:
        predictions = json.load(prediction_file)
    print(json.dumps(compute_score(dataset, predictions)))