facat committed on
Commit
9827786
1 Parent(s): 6d6787f
Files changed (3) hide show
  1. tasks.py +11 -12
  2. tlem.py +23 -26
  3. utils.py +32 -159
tasks.py CHANGED
@@ -261,18 +261,17 @@ class Metrics:
261
  return responses, answers
262
 
263
  def MATH(responses: list[str], answers: list[str]):
264
- extract_responses = []
265
- for response in responses:
266
- indices = [pos for pos, char in enumerate(response) if char == "$"]
267
- if len(indices) <= 2:
268
- ans = ""
269
- else:
270
- ans = response[indices[-2] + 1 : indices[-1]]
271
- extract_responses.append(strip_string(ans))
272
- extract_answers=[]
273
- for answer in answers:
274
- extract_answers.append(strip_string(get_answer(answer)))
275
- return extract_responses, extract_answers
276
 
277
 
278
  class CMMLU:
 
261
  return responses, answers
262
 
263
  def MATH(responses: list[str], answers: list[str]):
264
+ extract_responses = sync_pipe(get_answer)(responses)
265
+ extract_answers = sync_pipe(get_answer)(answers)
266
+ try:
267
+ from math_equivalence import is_equiv
268
+ except ImportError as e:
269
+ logging.warning(
270
+ "math_equivalence not installed, pip install git+https://github.com/hendrycks/math.git"
271
+ )
272
+ raise e
273
+
274
+ return sync_pipe(is_equiv)(zip(extract_responses, extract_answers))
 
275
 
276
 
277
  class CMMLU:
tlem.py CHANGED
@@ -1,12 +1,9 @@
1
- # %%
2
-
3
  try:
4
  from ipytorch import logging
5
  except Exception as e:
6
  import logging
7
 
8
  from typing import Any, Optional, Protocol, Iterable, Callable
9
- from numpy.lib import extract
10
  from tqdm.auto import tqdm
11
  from evaluate.evaluation_suite import EvaluationSuite
12
  import evaluate
@@ -14,7 +11,7 @@ import numpy as np
14
  import datasets
15
  import pandas as pd
16
  from .tasks import *
17
- from .utils import is_equiv
18
 
19
 
20
  class ReasoningMetric(evaluate.Metric):
@@ -46,27 +43,27 @@ class ReasoningMetric(evaluate.Metric):
46
  reference_urls=["http://path.to.reference.url/new_module"],
47
  )
48
 
49
- def _compute(self, responses, references, verbose=False):
50
- extract_responses, extract_references = getattr(Metrics, self.config_name)(
51
- responses, references
52
- )
53
- df = pd.DataFrame(
54
- {
55
- "responses": responses,
56
- "references": references,
57
- }
58
- )
59
- df["extract_responses"] = extract_responses
60
- df["extract_references"] = extract_references
61
- # print(df)
62
- results = {
63
- "Accuracy": (df["extract_references"] == df["extract_responses"])
64
- .astype(int)
65
- .mean(),
66
- }
67
- logging.info(results)
68
- if verbose:
69
- results["df"] = df
70
  return results
71
 
72
 
@@ -139,7 +136,7 @@ class Suite(EvaluationSuite):
139
  suite = Task(
140
  dataset_name="hendrycks/competition_math",
141
  split="test",
142
- prompt="This is a math problem, please think step by step and slove it: {input_column}",
143
  metric_name=("sustech/tlem", "MATH"),
144
  input_column="problem",
145
  label_column="solution",
 
 
 
1
  try:
2
  from ipytorch import logging
3
  except Exception as e:
4
  import logging
5
 
6
  from typing import Any, Optional, Protocol, Iterable, Callable
 
7
  from tqdm.auto import tqdm
8
  from evaluate.evaluation_suite import EvaluationSuite
9
  import evaluate
 
11
  import datasets
12
  import pandas as pd
13
  from .tasks import *
14
+ from .utils import *
15
 
16
 
17
  class ReasoningMetric(evaluate.Metric):
 
43
  reference_urls=["http://path.to.reference.url/new_module"],
44
  )
45
 
46
+ def _compute(self, responses, references):
47
+ return_value = getattr(Metrics, self.config_name)(responses, references)
48
+ match return_value:
49
+ case tuple():
50
+ extract_responses, extract_references = return_value
51
+ results = {
52
+ self.config_name: np.mean(
53
+ sync_pipe(lambda x, y: x == y)(
54
+ zip(extract_responses, extract_references)
55
+ )
56
+ )
57
+ }
58
+ case dict():
59
+ results = return_value
60
+
61
+ case list():
62
+ results = {self.config_name: np.mean(return_value)}
63
+
64
+ case _:
65
+ raise NotImplementedError
66
+
67
  return results
68
 
69
 
 
136
  suite = Task(
137
  dataset_name="hendrycks/competition_math",
138
  split="test",
139
+ prompt="This is a math problem, please think step by step and slove it: {input_column}, simplify your final answer as much as possible and surround them with $ in TeX form",
140
  metric_name=("sustech/tlem", "MATH"),
141
  input_column="problem",
142
  label_column="solution",
utils.py CHANGED
@@ -2,6 +2,8 @@ import logging
2
  import re
3
  import numpy as np
4
  from typing import Any
 
 
5
 
6
  NUMERIC_IN_EN = r"(?:[\s=+-/<>($:\.\*\\])(?=\S)((?:0|(?:\d{1,3}(?:,\d{3})+(?=\D|$))|(?:\d+))(?:\.\d+)?%?)(?:(?![^\s=+-/>)$:\.\*\\])|(?=, ))"
7
  NUMERIC_IN_ZH = (
@@ -9,6 +11,28 @@ NUMERIC_IN_ZH = (
9
  )
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def return_blank_if_exception(func):
13
  def wrapper(*args, **kwargs):
14
  try:
@@ -155,166 +179,15 @@ def last_boxed_only_string(string):
155
  return retval
156
 
157
 
158
- def fix_sqrt(string):
159
- if "\\sqrt" not in string:
160
- return string
161
- splits = string.split("\\sqrt")
162
- new_string = splits[0]
163
- for split in splits[1:]:
164
- if split[0] != "{":
165
- a = split[0]
166
- new_substr = "\\sqrt{" + a + "}" + split[1:]
167
- else:
168
- new_substr = "\\sqrt" + split
169
- new_string += new_substr
170
- return new_string
171
-
172
-
173
- def remove_right_units(string):
174
- # "\\text{ " only ever occurs (at least in the val set) when describing units
175
- if "\\text{ " in string:
176
- splits = string.split("\\text{ ")
177
- # assert len(splits) == 2
178
- return splits[0]
179
- else:
180
- return string
181
-
182
-
183
- def fix_fracs(string):
184
- substrs = string.split("\\frac")
185
- new_str = substrs[0]
186
- if len(substrs) > 1:
187
- substrs = substrs[1:]
188
- for substr in substrs:
189
- new_str += "\\frac"
190
- if substr[0] == "{":
191
- new_str += substr
192
- else:
193
- try:
194
- assert len(substr) >= 2
195
- except AssertionError:
196
- return string
197
- a = substr[0]
198
- b = substr[1]
199
- if b != "{":
200
- if len(substr) > 2:
201
- post_substr = substr[2:]
202
- new_str += "{" + a + "}{" + b + "}" + post_substr
203
- else:
204
- new_str += "{" + a + "}{" + b + "}"
205
- else:
206
- if len(substr) > 2:
207
- post_substr = substr[2:]
208
- new_str += "{" + a + "}" + b + post_substr
209
- else:
210
- new_str += "{" + a + "}" + b
211
- string = new_str
212
- return string
213
-
214
-
215
- def fix_a_slash_b(string):
216
- if len(string.split("/")) != 2:
217
- return string
218
- a = string.split("/")[0]
219
- b = string.split("/")[1]
220
- try:
221
- a = int(a)
222
- b = int(b)
223
- assert string == "{}/{}".format(a, b)
224
- new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
225
- return new_string
226
- except Exception as e:
227
- return string
228
-
229
-
230
- def strip_string(string):
231
- # linebreaks
232
- string = string.replace("\n", "")
233
-
234
- # remove inverse spaces
235
- string = string.replace("\\!", "")
236
-
237
- # replace \\ with \
238
- string = string.replace("\\\\", "\\")
239
-
240
- # replace tfrac and dfrac with frac
241
- string = string.replace("tfrac", "frac")
242
- string = string.replace("dfrac", "frac")
243
-
244
- # remove \left and \right
245
- string = string.replace("\\left", "")
246
- string = string.replace("\\right", "")
247
-
248
- # Remove circ (degrees)
249
- string = string.replace("^{\\circ}", "")
250
- string = string.replace("^\\circ", "")
251
-
252
- # remove dollar signs
253
- string = string.replace("\\$", "")
254
-
255
- # remove units (on the right)
256
- string = remove_right_units(string)
257
-
258
- # remove percentage
259
- string = string.replace("\\%", "")
260
- string = string.replace("\%", "") # noqa: W605
261
-
262
- # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
263
- string = string.replace(" .", " 0.")
264
- string = string.replace("{.", "{0.")
265
- # if empty, return empty string
266
- if len(string) == 0:
267
- return string
268
- if string[0] == ".":
269
- string = "0" + string
270
-
271
- # to consider: get rid of e.g. "k = " or "q = " at beginning
272
- if len(string.split("=")) == 2:
273
- if len(string.split("=")[0]) <= 2:
274
- string = string.split("=")[1]
275
-
276
- # fix sqrt3 --> sqrt{3}
277
- string = fix_sqrt(string)
278
-
279
- # remove spaces
280
- string = string.replace(" ", "")
281
-
282
- # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
283
- string = fix_fracs(string)
284
-
285
- # manually change 0.5 --> \frac{1}{2}
286
- if string == "0.5":
287
- string = "\\frac{1}{2}"
288
-
289
- # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
290
- # string = fix_a_slash_b(string)
291
- string = string.split("=")[-1]
292
- while string.startswith("\\boxed{") and string.endswith("}"):
293
- string = string[7:-1]
294
- string = string.split("=")[-1]
295
- return string
296
-
297
-
298
  def get_answer(string):
299
- answer = remove_boxed(last_boxed_only_string(string))
300
- return answer
301
-
302
-
303
- def is_equiv(str1, str2, verbose=False):
304
- if str1 is None and str2 is None:
305
- print("WARNING: Both None")
306
- return False
307
- if str1 is None or str2 is None:
308
- return False
309
-
310
- try:
311
- ss1 = strip_string(str1)
312
- ss2 = strip_string(str2)
313
- if verbose:
314
- print(ss1, ss2)
315
- return ss1 == ss2
316
- except Exception:
317
- return str1 == str2
318
 
319
 
320
  def first_option_postprocess(text: str, options: str) -> str:
 
2
  import re
3
  import numpy as np
4
  from typing import Any
5
+ from tqdm.auto import tqdm
6
+ import asyncio
7
 
8
  NUMERIC_IN_EN = r"(?:[\s=+-/<>($:\.\*\\])(?=\S)((?:0|(?:\d{1,3}(?:,\d{3})+(?=\D|$))|(?:\d+))(?:\.\d+)?%?)(?:(?![^\s=+-/>)$:\.\*\\])|(?=, ))"
9
  NUMERIC_IN_ZH = (
 
11
  )
12
 
13
 
14
+ def async_pipe(func):
15
+ async def sync_function(samples):
16
+ if not isinstance(samples, list):
17
+ samples = [samples]
18
+ return await tqdm.gather(*[func(sample) for sample in samples], leave=False)
19
+
20
+ def sync_func(samples):
21
+ return asyncio.run(sync_function(samples))
22
+
23
+ return sync_func
24
+
25
+
26
+ def sync_pipe(func, progress=False):
27
+ def sync_func(samples):
28
+ return [
29
+ func(*sample) if isinstance(sample, tuple) else func(sample)
30
+ for sample in tqdm(samples, disable=not progress, leave=False)
31
+ ]
32
+
33
+ return sync_func
34
+
35
+
36
  def return_blank_if_exception(func):
37
  def wrapper(*args, **kwargs):
38
  try:
 
179
  return retval
180
 
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  def get_answer(string):
183
+ if boxed := last_boxed_only_string(string):
184
+ return remove_boxed(boxed)
185
+ else:
186
+ indices = [pos for pos, char in enumerate(string) if char == "$"]
187
+ if len(indices) < 2:
188
+ return extract_numeric(string)
189
+ string = string[indices[-2] + 1 : indices[-1]]
190
+ return string.split("=")[-1]
 
 
 
 
 
 
 
 
 
 
 
191
 
192
 
193
  def first_option_postprocess(text: str, options: str) -> str: