DarrenChensformer committed
Commit ba6a59b · 1 Parent(s): 3360bee
Add main evaluation method

Changed files: relation_extraction.py (+77 -5)
relation_extraction.py
CHANGED
@@ -15,6 +15,7 @@
 
 import evaluate
 import datasets
+import numpy as np
 
 
 # TODO: Add BibTeX citation
@@ -86,10 +87,81 @@ class relation_extraction(evaluate.Metric):
         # TODO: Download external resources if needed
         pass
 
-    def _compute(self,
+    def _compute(self, pred_relations, gt_relations, mode="strict", relation_types=[]):
         """Returns the scores"""
         # TODO: Compute the different scores of the module
-
-
-
-
+
+        assert mode in ["strict", "boundaries"]
+
+        # construct relation_types from ground truth if not given
+        if len(relation_types) == 0:
+            for triplets in gt_relations:
+                for triplet in triplets:
+                    relation = triplet["type"]
+                    if relation not in relation_types:
+                        relation_types.append(relation)
+
+        scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]}
+
+        # Count GT relations and Predicted relations
+        n_sents = len(gt_relations)
+        n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
+        n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
+
+        # Count TP, FP and FN per type
+        for pred_sent, gt_sent in zip(pred_relations, gt_relations):
+            for rel_type in relation_types:
+                # strict mode takes argument types into account
+                if mode == "strict":
+                    pred_rels = {(rel["head"], rel["head_type"], rel["tail"], rel["tail_type"]) for rel in pred_sent if
+                                 rel["type"] == rel_type}
+                    gt_rels = {(rel["head"], rel["head_type"], rel["tail"], rel["tail_type"]) for rel in gt_sent if
+                               rel["type"] == rel_type}
+
+                # boundaries mode only takes argument spans into account
+                elif mode == "boundaries":
+                    pred_rels = {(rel["head"], rel["tail"]) for rel in pred_sent if rel["type"] == rel_type}
+                    gt_rels = {(rel["head"], rel["tail"]) for rel in gt_sent if rel["type"] == rel_type}
+
+                scores[rel_type]["tp"] += len(pred_rels & gt_rels)
+                scores[rel_type]["fp"] += len(pred_rels - gt_rels)
+                scores[rel_type]["fn"] += len(gt_rels - pred_rels)
+
+        # Compute per entity Precision / Recall / F1
+        for rel_type in scores.keys():
+            if scores[rel_type]["tp"]:
+                scores[rel_type]["p"] = 100 * scores[rel_type]["tp"] / (scores[rel_type]["fp"] + scores[rel_type]["tp"])
+                scores[rel_type]["r"] = 100 * scores[rel_type]["tp"] / (scores[rel_type]["fn"] + scores[rel_type]["tp"])
+            else:
+                scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
+
+            if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
+                scores[rel_type]["f1"] = 2 * scores[rel_type]["p"] * scores[rel_type]["r"] / (
+                    scores[rel_type]["p"] + scores[rel_type]["r"])
+            else:
+                scores[rel_type]["f1"] = 0
+
+        # Compute micro F1 Scores
+        tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
+        fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
+        fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
+
+        if tp:
+            precision = 100 * tp / (tp + fp)
+            recall = 100 * tp / (tp + fn)
+            f1 = 2 * precision * recall / (precision + recall)
+
+        else:
+            precision, recall, f1 = 0, 0, 0
+
+        scores["ALL"]["p"] = precision
+        scores["ALL"]["r"] = recall
+        scores["ALL"]["f1"] = f1
+        scores["ALL"]["tp"] = tp
+        scores["ALL"]["fp"] = fp
+        scores["ALL"]["fn"] = fn
+
+        # Compute Macro F1 Scores
+        scores["ALL"]["Macro_f1"] = np.mean([scores[ent_type]["f1"] for ent_type in relation_types])
+        scores["ALL"]["Macro_p"] = np.mean([scores[ent_type]["p"] for ent_type in relation_types])
+        scores["ALL"]["Macro_r"] = np.mean([scores[ent_type]["r"] for ent_type in relation_types])
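
The new `_compute` expects, for each sentence, a list of relation dicts carrying the fields it reads: `head`, `head_type`, `tail`, `tail_type`, and `type`. Below is a minimal usage sketch, not part of the commit: the Space path `DarrenChensformer/relation_extraction`, the relation label, and the spans are assumptions for illustration, and `_compute` is called directly because this commit does not yet wire the inputs through the module's `compute()` features.

import evaluate

# Assumed module path (the owning Space); adjust if the module lives elsewhere.
metric = evaluate.load("DarrenChensformer/relation_extraction")

# One list of relation dicts per sentence; spans and labels are made up.
gt_relations = [
    [{"head": (0, 2), "head_type": "ORG", "tail": (5, 6), "tail_type": "PER", "type": "org:founded_by"}],
    [],  # second sentence has no gold relation
]
pred_relations = [
    [{"head": (0, 2), "head_type": "ORG", "tail": (5, 6), "tail_type": "PER", "type": "org:founded_by"}],  # correct
    [{"head": (1, 2), "head_type": "ORG", "tail": (3, 4), "tail_type": "PER", "type": "org:founded_by"}],  # spurious
]

# Direct call; relation_types is inferred from the ground truth.
scores = metric._compute(pred_relations, gt_relations, mode="strict")
print(scores["ALL"])  # 1 TP, 1 FP, 0 FN -> p=50.0, r=100.0, f1~66.7

In "boundaries" mode the same call ignores `head_type`/`tail_type`, so a prediction with the correct spans but wrong argument types still counts as a true positive. Note also that the mutable `relation_types=[]` default means inferred labels accumulate across repeated calls on the same metric instance.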