facat committed on
Commit
8af54b8
1 Parent(s): 507319c
Files changed (4) hide show
  1. .gitignore +2 -0
  2. pyproject.toml +14 -0
  3. tlem.py +225 -0
  4. utils.py +257 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ tlem.ju.py
pyproject.toml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "tlem"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = ["fecet <[email protected]>"]
6
+ readme = "README.md"
7
+
8
+ [tool.poetry.dependencies]
9
+ python = "3.10"
10
+
11
+
12
+ [build-system]
13
+ requires = ["poetry-core"]
14
+ build-backend = "poetry.core.masonry.api"
tlem.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+
3
+ try:
4
+ from ipytorch import logging
5
+ except Exception as e:
6
+ import logging
7
+
8
+ from typing import Any, Optional, Protocol, Iterable, Callable
9
+
10
+ # %%
11
+
12
+ # %cd ../tlem
13
+
14
+ # %load_ext ipytorch
15
+ # %ls
16
+ from utils import (
17
+ NUMERIC_IN_ZH,
18
+ extract_choice_ans,
19
+ extract_numeric,
20
+ get_answer,
21
+ is_equiv,
22
+ )
23
+
24
+
25
+ from dataclasses import dataclass, field
26
+ from datasets import load_dataset, Dataset
27
+ from functools import cached_property
28
+
29
+
30
+ TextGenerationPipeline = Callable[[Iterable[str]], list[str]]
31
+
32
+
33
+ from evaluate import EvaluationModule, Evaluator, evaluator, load
34
+
35
+
36
+ @dataclass
37
+ class Task:
38
+ dataset_name: str = "gsm8k"
39
+ dataset_params: dict = field(default_factory=dict)
40
+ # metrics: list[str] = field(default_factory=list)
41
+ metric_name: str | tuple[str, str] = "gsm8k"
42
+ input_column: str = "question"
43
+ label_column: str
44
+ prompt: Optional[Callable | str] = None
45
+
46
+ @cached_property
47
+ def samples(self):
48
+ return self.dataset[self.input_column]
49
+
50
+ @cached_property
51
+ def dataset(self):
52
+ ds = load_dataset(self.dataset_name, **self.dataset_params)
53
+ if self.prompt is not None:
54
+ ds = ds.map(
55
+ lambda example: {
56
+ self.input_column: self.prompt.format(
57
+ input_column=example[self.input_column]
58
+ )
59
+ }
60
+ if isinstance(self.prompt, str)
61
+ else self.prompt(example),
62
+ )
63
+
64
+ return ds
65
+
66
+ @cached_property
67
+ def metric(self):
68
+ metric = (
69
+ load(self.metric_name)
70
+ if isinstance(self.metric_name, str)
71
+ else load(*self.metric_name)
72
+ )
73
+ return metric
74
+
75
+ def run(self, pipeline: TextGenerationPipeline):
76
+ outputs = pipeline(self.samples)
77
+ return self.metric.compute(outputs, self.dataset[self.label_column])
78
+
79
+
80
class Metrics:
    """Namespace of per-benchmark scorers.

    Every scorer takes the model responses and the reference answers and
    returns one score (1.0 for a match, 0.0 otherwise) per pair. The scorers
    are static so they work both via ``getattr(Metrics, name)`` and on an
    instance.
    """

    @staticmethod
    def gsm8k(responses: list[str], answers: list[str | int]):
        """Compare the last number in each response with the gold number."""
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response)
            # Non-string golds (e.g. ints) are compared via their str() form.
            gold = extract_numeric(answer) if isinstance(answer, str) else str(answer)
            scores.append(1.0 * (pred == gold))
        return scores

    @staticmethod
    def MATH(responses: list[str], answers: list[str]):
        """Compare the last ``$...$`` span of each response with the gold answer."""
        scores = []
        for response, answer in zip(responses, answers):
            dollars = [pos for pos, char in enumerate(response) if char == "$"]
            # Two delimiters are enough to enclose an answer; only responses
            # with fewer than two cannot contain a ``$...$`` span.
            if len(dollars) < 2:
                scores.append(0.0)
                continue
            result = response[dollars[-2] + 1 : dollars[-1]]
            gold = get_answer(answer)
            scores.append(1.0 * is_equiv(result, gold))
        return scores

    @staticmethod
    def math23k(responses: list[str], answers: list[str]):
        """Compare numbers extracted with ``NUMERIC_IN_ZH`` from both sides."""
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = extract_numeric(answer, pattern=NUMERIC_IN_ZH)
            scores.append(1.0 * (pred == gold))
        return scores

    @staticmethod
    def gsm8k_zh(responses: list[str], answers: list[str]):
        """Extract the response with ``NUMERIC_IN_ZH`` and the gold with the default pattern."""
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = extract_numeric(answer)
            scores.append(1.0 * (pred == gold))
        return scores

    @staticmethod
    def svamp(responses: list[str], answers: list[float]):
        """Compare the extracted number, as a float, with the numeric gold answer."""
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            try:
                matched = float(pred) == answer
            except ValueError:
                # extract_numeric hands back the raw text when it finds no
                # number; score that as a miss instead of aborting the run.
                matched = False
            scores.append(1.0 * matched)
        return scores

    @staticmethod
    def mmlu(responses, answers):
        """Compare the extracted choice letter with the lower-cased gold letter."""
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_choice_ans(response)
            gold = answer.lower()
            scores.append(1.0 * (pred == gold))
        return scores
135
+
136
+
137
+ import evaluate
138
+ import numpy as np
139
+
140
+ import datasets
141
+
142
+
143
+ # TODO: Add BibTeX citation
144
+ _CITATION = """\
145
+ @InProceedings{huggingface:module,
146
+ title = {A great new module},
147
+ authors={huggingface, Inc.},
148
+ year={2020}
149
+ }
150
+ """
151
+
152
# Shown as the module description; replaces the template text that described
# an unrelated "element count" measurement.
_DESCRIPTION = """\
Accuracy metric for reasoning benchmarks: the final answer is extracted from each model response and compared with the reference answer.
"""


# Documents the inputs and outputs of ReasoningMetric._compute.
_KWARGS_DESCRIPTION = """
Scores model responses against reference answers.
The config name selects the scorer: gsm8k, MATH, math23k, gsm8k_zh, svamp or mmlu.
Args:
    responses: list of model outputs (strings).
    references: list of reference answers (strings; floats for the svamp config).
    verbose: if True, the references and responses are included in the result.
Returns:
    accuracy: mean of the per-sample scores,
    scores: list with one score per sample (1.0 for a match, 0.0 otherwise),
"""
170
+
171
+ # TODO: Define external resources urls if needed
172
+ BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
173
+
174
+
175
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ReasoningMetric(evaluate.Metric):
    """Metric that scores responses with the ``Metrics`` scorer named by the config."""

    def _info(self):
        """Describe the module and declare the expected input features."""
        # Only the svamp config uses float references; all others use strings.
        reference_dtype = "float" if self.config_name == "svamp" else "string"
        features = datasets.Features(
            {
                "responses": datasets.Value("string"),
                "references": datasets.Value(reference_dtype),
            }
        )

        # TODO: Specifies the evaluate.EvaluationModuleInfo object
        return evaluate.EvaluationModuleInfo(
            # module_type="measurement",
            # Description shown on the modules page.
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Format of each response and reference.
            features=features,
            # Homepage of the module for documentation.
            homepage="http://module.homepage",
            # Additional links to the codebase or references.
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"],
        )

    def _compute(self, responses, references, verbose=False):
        """Score every response and return the mean accuracy with per-sample scores.

        With ``verbose`` the inputs are echoed back under ``references`` and
        ``answers``.
        """
        scorer = getattr(Metrics, self.config_name)
        scores = scorer(responses, references)
        results = {
            "accuracy": np.asarray(scores).mean(),
            "scores": scores,
        }
        if verbose:
            results.update(references=references, answers=responses)
        return results
utils.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import re
3
+
4
# Capturing group shared by both patterns: a lone "0", or 1-3 digits followed
# by comma-separated groups of three digits, or a plain run of digits; then an
# optional decimal fraction and an optional trailing "%".
_NUMBER_GROUP = r"((?:0|(?:\d{1,3}(?:,\d{3})+(?=\D|$))|(?:\d+))(?:\.\d+)?%?)"

# The number must be preceded by one character from the leading class and
# followed by end-of-string, a character from the trailing class, or ", ".
# Note that "+-/" inside a class is a range (the characters + , - . /).
NUMERIC_IN_EN = (
    r"(?:[\s=+-/<>($:\.\*\\])(?=\S)"
    + _NUMBER_GROUP
    + r"(?:(?![^\s=+-/>)$:\.\*\\])|(?=, ))"
)

# Looser variant: the number only has to be delimited by non-digits or by the
# string boundaries.
NUMERIC_IN_ZH = r"(?:\D|^)" + _NUMBER_GROUP + r"(?=\D|$)"
8
+
9
+
10
def extract_choice_ans(text):
    """Return the chosen option letter (lower-cased) found in *text*, or "_".

    Stand-alone letters A-D (either case) and parenthesised letters such as
    "(B)" are collected; the parenthesised matches are listed after the bare
    ones, and the last entry of that combined list is the answer.
    """
    bare_letters = re.findall(r"\b[ABCDabcd]\b", text)
    wrapped_letters = re.findall(r"\([ABCDabcd]\)", text)
    candidates = bare_letters + wrapped_letters
    if not candidates:
        return "_"

    final = candidates[-1]
    # A three-character match is "(X)"; keep only the letter in the middle.
    letter = final if len(final) == 1 else final[1]
    return letter.lower()
21
+
22
+
23
def extract_numeric(string, pattern=NUMERIC_IN_EN) -> str:
    """Return the last number found in *string* in a normalised text form.

    Args:
        string: text to search.
        pattern: regular expression with one capturing group around the number.

    Returns:
        The last match with thousands separators removed, trailing zeros of
        the fraction stripped and percentages converted to a fraction
        (e.g. "50%" -> "0.5"). When nothing matches, *string* itself is
        returned unchanged.
    """

    def standardize(x):
        y = x.replace(",", "")
        if "." in y:
            # Drop trailing zeros of the fraction, then a dangling ".".
            y = y.rstrip("0")
            if y.endswith("."):
                y = y[:-1]
        if not y:
            # Everything was stripped (e.g. ".0"): the value is zero.
            return "0"
        if y.startswith("."):
            y = "0" + y
        if y.endswith("%"):
            # float() instead of eval(): no code execution on model output and
            # no SyntaxError for digits with leading zeros such as "007%".
            y = str(float(y[:-1]) / 100)
        return y

    candidates = [
        found
        for found in re.findall(pattern, string)
        if found.strip() and found != "%"
    ]
    if not candidates:
        logging.debug("No numeric value found in string: %s", string)
        return string
    return standardize(candidates[-1].strip())
46
+
47
+
48
def remove_boxed(s):
    """Strip a leading ``\\boxed `` or an enclosing ``\\boxed{...}`` from *s*.

    Args:
        s: string expected to start with ``\\boxed `` or to be exactly one
            ``\\boxed{...}`` expression.

    Returns:
        The content after ``\\boxed `` or between the outer braces.

    Raises:
        AssertionError: if *s* does not have one of the two expected shapes.
            Raised explicitly rather than with ``assert`` so the check still
            runs under ``python -O``.
    """
    spaced = "\\boxed "
    if spaced in s:
        if not s.startswith(spaced):
            raise AssertionError("expected string to start with '\\boxed '")
        return s[len(spaced) :]

    braced = "\\boxed{"
    if not s.startswith(braced):
        raise AssertionError("expected string to start with '\\boxed{'")
    if not s.endswith("}"):
        raise AssertionError("expected string to end with '}'")
    return s[len(braced) : -1]
60
+
61
+
62
def last_boxed_only_string(string):
    """Return the last ``\\boxed``/``\\fbox`` expression in *string*, or None.

    For the ``\\boxed `` (space) form the result runs from the last such
    marker up to the next ``$``. Otherwise the result is the command together
    with its brace-balanced argument; None is returned when no command is
    present or its braces never balance.
    """
    start = string.rfind("\\boxed")
    if "\\boxed " in string:
        tail = string.split("\\boxed ")[-1]
        return "\\boxed " + tail.split("$")[0]

    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    # Walk forward from the command until its opening brace is closed again.
    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]
    return None
90
+
91
+
92
def fix_sqrt(string):
    """Wrap a bare single-character ``\\sqrt`` argument in braces.

    ``\\sqrt3`` becomes ``\\sqrt{3}``; arguments that are already braced are
    left alone. A ``\\sqrt`` with nothing after it is kept as-is instead of
    raising IndexError.
    """
    if "\\sqrt" not in string:
        return string

    head, *tails = string.split("\\sqrt")
    pieces = [head]
    for tail in tails:
        if tail and tail[0] != "{":
            # Only the first character becomes the argument.
            pieces.append("\\sqrt{" + tail[0] + "}" + tail[1:])
        else:
            pieces.append("\\sqrt" + tail)
    return "".join(pieces)
105
+
106
+
107
def remove_right_units(string):
    """Drop a trailing ``\\text{ ...`` unit annotation from *string*.

    Strings without the marker are returned unchanged; the marker is expected
    to occur at most once.
    """
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    marker = "\\text{ "
    if marker not in string:
        return string
    pieces = string.split(marker)
    assert len(pieces) == 2
    return pieces[0]
115
+
116
+
117
def fix_fracs(string):
    """Add the missing braces to shorthand ``\\frac`` arguments.

    ``\\frac12`` becomes ``\\frac{1}{2}`` and ``\\frac1{72}`` becomes
    ``\\frac{1}{72}``; fully braced fractions are left alone. If a ``\\frac``
    is followed by fewer than two characters (including none at all, which
    used to raise IndexError) the input is returned unchanged.
    """
    head, *tails = string.split("\\frac")
    pieces = [head]
    for tail in tails:
        pieces.append("\\frac")
        if tail.startswith("{"):
            # Numerator already braced: copy the rest verbatim.
            pieces.append(tail)
            continue
        if len(tail) < 2:
            # Malformed shorthand; give up and keep the original text.
            return string
        numerator, second, rest = tail[0], tail[1], tail[2:]
        if second != "{":
            pieces.append("{" + numerator + "}{" + second + "}" + rest)
        else:
            # Denominator is already braced; only wrap the numerator.
            pieces.append("{" + numerator + "}" + second + rest)
    return "".join(pieces)
147
+
148
+
149
def fix_a_slash_b(string):
    """Rewrite a plain integer ratio ``a/b`` as ``\\frac{a}{b}``.

    The input is returned unchanged unless it consists of exactly two
    integers separated by one "/" and is already in canonical form (no
    spaces, no leading zeros).
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        a, b = int(parts[0]), int(parts[1])
    except ValueError:
        return string
    # Explicit comparison rather than ``assert`` so that non-canonical input
    # such as "1 / 2" is still rejected under ``python -O``.
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"
162
+
163
+
164
def strip_string(string):
    """Normalise a LaTeX answer string so equivalent answers compare equal.

    Removes line breaks, spacing and sizing commands, degree signs, escaped
    dollar and percent signs, right-hand units and spaces; restores leading
    zeros; drops a short ``x =`` prefix; and braces ``\\sqrt`` / ``\\frac``
    shorthand. ``0.5`` is mapped to ``\\frac{1}{2}``.
    """
    # Order matters: e.g. "\\\\" must collapse to "\\" before the commands
    # below are matched.
    simple_replacements = (
        ("\n", ""),  # linebreaks
        ("\\!", ""),  # inverse spaces
        ("\\\\", "\\"),  # replace \\ with \
        ("tfrac", "frac"),  # tfrac and dfrac become frac
        ("dfrac", "frac"),
        ("\\left", ""),  # remove \left and \right
        ("\\right", ""),
        ("^{\\circ}", ""),  # remove circ (degrees)
        ("^\\circ", ""),
        ("\\$", ""),  # remove dollar signs
    )
    for old, new in simple_replacements:
        string = string.replace(old, new)

    # remove units (on the right)
    string = remove_right_units(string)

    # Remove escaped percent signs. The original repeated this with the
    # invalid escape "\%", which is the same two characters as "\\%".
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{."
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    # add "0" if "." is the start of the string
    if string[0] == ".":
        string = "0" + string

    # get rid of e.g. "k = " or "q = " at the beginning
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with
    # \frac1{72} (but not \frac{72}1).
    string = fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    # string = fix_a_slash_b(string)

    return string
227
+
228
+
229
def get_answer(string):
    """Return the content of the last boxed expression in *string*.

    If extraction fails for any reason the input is returned unchanged.
    """
    try:
        # answer = strip_string(answer)
        return remove_boxed(last_boxed_only_string(string))
    except Exception:
        return string
236
+
237
+
238
def is_equiv(str1, str2, verbose=False):
    """Tell whether two answer strings are equal after normalisation.

    A None on either side is never equivalent (a warning is printed when both
    are None). If normalisation raises, the raw strings are compared instead.
    With *verbose* the normalised strings are printed.
    """
    if str1 is None or str2 is None:
        if str1 is None and str2 is None:
            print("WARNING: Both None")
        return False

    try:
        left = strip_string(str1)
        right = strip_string(str2)
        if verbose:
            print(left, right)
        return left == right
    except Exception:
        return str1 == str2
253
+
254
+
255
+ if __name__ == "__main__":
256
+ num = extract_numeric("the answer is -1.5")
257
+ print(num)