DarrenChensformer committed on
Commit ba6a59b
1 Parent(s): 3360bee

Add main evaluation method

Files changed (1)
  1. relation_extraction.py +77 -5
relation_extraction.py CHANGED
@@ -15,6 +15,7 @@
 
 import evaluate
 import datasets
+import numpy as np
 
 
 # TODO: Add BibTeX citation
@@ -86,10 +87,81 @@ class relation_extraction(evaluate.Metric):
         # TODO: Download external resources if needed
         pass
 
-    def _compute(self, predictions, references):
+    def _compute(self, pred_relations, gt_relations, mode="strict", relation_types=[]):
         """Returns the scores"""
         # TODO: Compute the different scores of the module
-        accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
-        return {
-            "accuracy": accuracy,
-        }
+
+        assert mode in ["strict", "boundaries"]
+
+        # construct relation_types from ground truth if not given
+        if len(relation_types) == 0:
+            for triplets in gt_relations:
+                for triplet in triplets:
+                    relation = triplet["type"]
+                    if relation not in relation_types:
+                        relation_types.append(relation)
+
+        scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]}
+
+        # Count GT relations and Predicted relations
+        n_sents = len(gt_relations)
+        n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
+        n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
+
+        # Count TP, FP and FN per type
+        for pred_sent, gt_sent in zip(pred_relations, gt_relations):
+            for rel_type in relation_types:
+                # strict mode takes argument types into account
+                if mode == "strict":
+                    pred_rels = {(rel["head"], rel["head_type"], rel["tail"], rel["tail_type"]) for rel in pred_sent if
+                                 rel["type"] == rel_type}
+                    gt_rels = {(rel["head"], rel["head_type"], rel["tail"], rel["tail_type"]) for rel in gt_sent if
+                               rel["type"] == rel_type}
+
+                # boundaries mode only takes argument spans into account
+                elif mode == "boundaries":
+                    pred_rels = {(rel["head"], rel["tail"]) for rel in pred_sent if rel["type"] == rel_type}
+                    gt_rels = {(rel["head"], rel["tail"]) for rel in gt_sent if rel["type"] == rel_type}
+
+                scores[rel_type]["tp"] += len(pred_rels & gt_rels)
+                scores[rel_type]["fp"] += len(pred_rels - gt_rels)
+                scores[rel_type]["fn"] += len(gt_rels - pred_rels)
+
+        # Compute per entity Precision / Recall / F1
+        for rel_type in scores.keys():
+            if scores[rel_type]["tp"]:
+                scores[rel_type]["p"] = 100 * scores[rel_type]["tp"] / (scores[rel_type]["fp"] + scores[rel_type]["tp"])
+                scores[rel_type]["r"] = 100 * scores[rel_type]["tp"] / (scores[rel_type]["fn"] + scores[rel_type]["tp"])
+            else:
+                scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
+
+            if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
+                scores[rel_type]["f1"] = 2 * scores[rel_type]["p"] * scores[rel_type]["r"] / (
+                        scores[rel_type]["p"] + scores[rel_type]["r"])
+            else:
+                scores[rel_type]["f1"] = 0
+
+        # Compute micro F1 Scores
+        tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
+        fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
+        fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
+
+        if tp:
+            precision = 100 * tp / (tp + fp)
+            recall = 100 * tp / (tp + fn)
+            f1 = 2 * precision * recall / (precision + recall)
+
+        else:
+            precision, recall, f1 = 0, 0, 0
+
+        scores["ALL"]["p"] = precision
+        scores["ALL"]["r"] = recall
+        scores["ALL"]["f1"] = f1
+        scores["ALL"]["tp"] = tp
+        scores["ALL"]["fp"] = fp
+        scores["ALL"]["fn"] = fn
+
+        # Compute Macro F1 Scores
+        scores["ALL"]["Macro_f1"] = np.mean([scores[ent_type]["f1"] for ent_type in relation_types])
+        scores["ALL"]["Macro_p"] = np.mean([scores[ent_type]["p"] for ent_type in relation_types])
+        scores["ALL"]["Macro_r"] = np.mean([scores[ent_type]["r"] for ent_type in relation_types])