DarrenChensformer commited on
Commit
749b801
1 Parent(s): ba6a59b

Fix nothing return

Browse files
Files changed (1) hide show
  1. relation_extraction.py +9 -7
relation_extraction.py CHANGED
@@ -87,7 +87,7 @@ class relation_extraction(evaluate.Metric):
87
  # TODO: Download external resources if needed
88
  pass
89
 
90
- def _compute(self, pred_relations, gt_relations, mode="strict", relation_types=[]):
91
  """Returns the scores"""
92
  # TODO: Compute the different scores of the module
93
 
@@ -95,7 +95,7 @@ class relation_extraction(evaluate.Metric):
95
 
96
  # construct relation_types from ground truth if not given
97
  if len(relation_types) == 0:
98
- for triplets in gt_relations:
99
  for triplet in triplets:
100
  relation = triplet["type"]
101
  if relation not in relation_types:
@@ -104,12 +104,12 @@ class relation_extraction(evaluate.Metric):
104
  scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]}
105
 
106
  # Count GT relations and Predicted relations
107
- n_sents = len(gt_relations)
108
- n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
109
- n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
110
 
111
  # Count TP, FP and FN per type
112
- for pred_sent, gt_sent in zip(pred_relations, gt_relations):
113
  for rel_type in relation_types:
114
  # strict mode takes argument types into account
115
  if mode == "strict":
@@ -164,4 +164,6 @@ class relation_extraction(evaluate.Metric):
164
  # Compute Macro F1 Scores
165
  scores["ALL"]["Macro_f1"] = np.mean([scores[ent_type]["f1"] for ent_type in relation_types])
166
  scores["ALL"]["Macro_p"] = np.mean([scores[ent_type]["p"] for ent_type in relation_types])
167
- scores["ALL"]["Macro_r"] = np.mean([scores[ent_type]["r"] for ent_type in relation_types])
 
 
 
87
  # TODO: Download external resources if needed
88
  pass
89
 
90
+ def _compute(self, predictions, references, mode="strict", relation_types=[]):
91
  """Returns the scores"""
92
  # TODO: Compute the different scores of the module
93
 
 
95
 
96
  # construct relation_types from ground truth if not given
97
  if len(relation_types) == 0:
98
+ for triplets in references:
99
  for triplet in triplets:
100
  relation = triplet["type"]
101
  if relation not in relation_types:
 
104
  scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]}
105
 
106
  # Count GT relations and Predicted relations
107
+ n_sents = len(references)
108
+ n_rels = sum([len([rel for rel in sent]) for sent in references])
109
+ n_found = sum([len([rel for rel in sent]) for sent in predictions])
110
 
111
  # Count TP, FP and FN per type
112
+ for pred_sent, gt_sent in zip(predictions, references):
113
  for rel_type in relation_types:
114
  # strict mode takes argument types into account
115
  if mode == "strict":
 
164
  # Compute Macro F1 Scores
165
  scores["ALL"]["Macro_f1"] = np.mean([scores[ent_type]["f1"] for ent_type in relation_types])
166
  scores["ALL"]["Macro_p"] = np.mean([scores[ent_type]["p"] for ent_type in relation_types])
167
+ scores["ALL"]["Macro_r"] = np.mean([scores[ent_type]["r"] for ent_type in relation_types])
168
+
169
+ return scores