DarrenChensformer committed on
Commit
4d4c4d4
1 Parent(s): 0daf005

Add weight sum result

Files changed (1)
  1. action_generation.py +120 -10
action_generation.py CHANGED
@@ -47,8 +47,8 @@ Examples:
     Examples should be written in doctest format, and should illustrate how
     to use the function.
 
-    >>> my_new_module = evaluate.load("my_new_module")
-    >>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
+    >>> metric = evaluate.load("DarrenChensformer/aciton_generation")
+    >>> results = metric.compute(references=[0, 1], predictions=[0, 1])
     >>> print(results)
     {'accuracy': 1.0}
 """
@@ -56,6 +56,101 @@ Examples:
 # TODO: Define external resources urls if needed
 BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
 
+class BaseEvaluater:
+    eps = 1e-8
+
+    def __call__(self, preds, labels):
+        return self._compute(preds, labels)
+
+    def _compute(self, preds, labels):
+        # calculate precision, recall, f1
+        tp, fp, fn = 0, 0, 0
+        for pred, label in zip(preds, labels):
+            tp += len(set(pred) & set(label))
+            fp += len(set(pred) - set(label))
+            fn += len(set(label) - set(pred))
+        precision = tp / (tp + fp + self.eps)
+        recall = tp / (tp + fn + self.eps)
+        f1 = 2 * precision * recall / (precision + recall)
+
+        return {
+            "precision": round(precision, 4),
+            "recall": round(recall, 4),
+            "f1": round(f1, 4)
+        }
+
+class ClassEvaluater(BaseEvaluater):
+    def __init__(self, valid_labels=None):
+        self.valid_labels = valid_labels
+
+    def __call__(self, preds, labels):
+        preds = map(self.extract_class, preds)
+        labels = map(self.extract_class, labels)
+        # helper function to extract valid tags
+        preds = list(map(self.extract_valid, preds))
+        labels = list(map(self.extract_valid, labels))
+        return self._compute(preds, labels)
+
+    def extract_valid(self, tags):
+        # TODO: if valid_labels is None:
+        tags = list(filter(lambda tag: tag in self.valid_labels, tags))
+        return tags
+
+    def extract_class(self, tags):
+        tags = map(lambda tag: tag.replace("/ ", "/"), tags)
+        tags = list(map(self.batch_extract_class, tags))
+        # deduplicate
+        tags = list(dict.fromkeys(tags))
+        return tags
+
+    def batch_extract_class(self, tag):
+        # filter out invalid tags
+        tag = tag.split('/')
+        if len(tag)==3:
+            _class = '/'.join(tag[:2])
+        elif len(tag)==4:
+            _class = '/'.join(tag[:3])
+        elif len(tag)==1:
+            _class = ''
+        else:
+            _class = None
+        if _class in self.valid_labels:
+            return _class
+        else:
+            return ""
+
+
+class PhraseEvaluater(BaseEvaluater):
+    def __init__(self, valid_labels=None):
+        self.valid_labels = valid_labels
+
+    def __call__(self, preds, labels):
+        preds = map(self.extract_phrase, preds)
+        labels = map(self.extract_phrase, labels)
+        return self._compute(preds, labels)
+
+    def extract_phrase(self, tags):
+        tags = map(lambda tag: tag.replace("/ ", "/"), tags)
+        tags = list(map(self.batch_extract_phrase, tags))
+        # deduplicate
+        tags = list(dict.fromkeys(tags))
+        return tags
+
+    def batch_extract_phrase(self, phrase):
+        # filter out invalid tags
+        tag = phrase.split('/')
+        if len(tag)==3:
+            _class = '/'.join(tag[:2])
+        elif len(tag)==4:
+            _class = '/'.join(tag[:3])
+        elif len(tag)==1:
+            _class = ''
+        else:
+            _class = None
+        if _class in self.valid_labels:
+            return phrase.replace(_class, '')
+        else:
+            return ""
 
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class action_generation(evaluate.Metric):
@@ -71,8 +166,8 @@ class action_generation(evaluate.Metric):
             inputs_description=_KWARGS_DESCRIPTION,
             # This defines the format of each prediction and reference
             features=datasets.Features({
-                'predictions': datasets.Value('int64'),
-                'references': datasets.Value('int64'),
+                'predictions': datasets.Sequence(datasets.Value('string')),
+                'references': datasets.Sequence(datasets.Value('string')),
             }),
             # Homepage of the module for documentation
             homepage="http://module.homepage",
@@ -86,10 +181,25 @@ class action_generation(evaluate.Metric):
         # TODO: Download external resources if needed
         pass
 
-    def _compute(self, predictions, references):
+    def _compute(self, predictions, references,
+                 valid_labels=None, detailed_scores=False,
+                 weights={"class": 0.8, "phrase": 0.2}
+                 ):
         """Returns the scores"""
-        # TODO: Compute the different scores of the module
-        accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
-        return {
-            "accuracy": accuracy,
-        }
+        weights = {"class": 0.8, "phrase": 0.2}
+        class_eval = ClassEvaluater(valid_labels)(predictions, references)
+        phrase_eval = PhraseEvaluater(valid_labels)(predictions, references)
+        weight_sum = {
+            key: round((class_eval[key] * weights["class"]) + (phrase_eval[key] * weights["phrase"]), 4)
+            for key in class_eval
+        }
+        if detailed_scores:
+            results = {
+                "class": class_eval,
+                "phrase": phrase_eval,
+                "weighted_sum": weight_sum
+            }
+        else:
+            results = weight_sum
+
+        return results
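
For reference, a minimal sketch of the weighted-sum step this commit adds ("Add weight sum result"), using made-up class-level and phrase-level score dicts shaped like the output of BaseEvaluater._compute; the 0.8/0.2 weights mirror the defaults in _compute above, and the score values are hypothetical:

# Hypothetical per-level scores in the shape returned by BaseEvaluater._compute
class_eval = {"precision": 0.9, "recall": 0.8, "f1": 0.8471}
phrase_eval = {"precision": 0.5, "recall": 0.4, "f1": 0.4444}
weights = {"class": 0.8, "phrase": 0.2}  # defaults used in _compute

# Same per-key weighted sum as the dict comprehension in _compute
weight_sum = {
    key: round(class_eval[key] * weights["class"] + phrase_eval[key] * weights["phrase"], 4)
    for key in class_eval
}
print(weight_sum)
# {'precision': 0.82, 'recall': 0.72, 'f1': 0.7666}

With detailed_scores=True, _compute nests the per-level dicts and this weighted sum under "class", "phrase", and "weighted_sum"; otherwise only the weighted sum is returned.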