refactor: mgsm parser
llmdataparser/mgsm_parser.py  +86 -1
tests/test_mgsm_parser.py  +86 -0
llmdataparser/mgsm_parser.py
CHANGED
@@ -1,7 +1,12 @@
 from dataclasses import dataclass
 from typing import Any, ClassVar
 
-from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
+from llmdataparser.base_parser import (
+    DatasetDescription,
+    EvaluationMetric,
+    HuggingFaceDatasetParser,
+    HuggingFaceParseEntry,
+)
 from llmdataparser.prompts import MGSM_SYSTEM_PROMPT
 
 
@@ -93,6 +98,86 @@ class MGSMDatasetParser(HuggingFaceDatasetParser[MGSMParseEntry]):
             language=task,
         )
 
+    def get_dataset_description(self) -> DatasetDescription:
+        """Returns a description of the Multilingual Grade School Math dataset."""
+        return DatasetDescription.create(
+            name="Multilingual Grade School Math (MGSM)",
+            purpose="Evaluate multilingual chain-of-thought reasoning capabilities in mathematical problem solving",
+            source="https://huggingface.co/datasets/juletxara/mgsm",
+            language="Multilingual (11 languages)",
+            format="Word problems with numerical answers and solution steps",
+            characteristics=(
+                "Human-translated version of 250 GSM8K problems into 10 additional languages. "
+                "Each problem includes the original question from GSM8K, its translations, "
+                "numerical answer, and solution steps. The benchmark is designed to evaluate "
+                "language models' ability to perform mathematical reasoning across different languages."
+            ),
+            citation="""@misc{shi2022language,
+    title={Language Models are Multilingual Chain-of-Thought Reasoners},
+    author={Freda Shi and Mirac Suzgun and Markus Freitag and Xuezhi Wang and Suraj Srivats and Soroush Vosoughi and Hyung Won Chung and Yi Tay and Sebastian Ruder and Denny Zhou and Dipanjan Das and Jason Wei},
+    year={2022},
+    eprint={2210.03057},
+    archivePrefix={arXiv},
+    primaryClass={cs.CL}
+}
+@article{cobbe2021gsm8k,
+    title={Training Verifiers to Solve Math Word Problems},
+    author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Chen, Mark and Jun, Heewoo and Kaiser, Lukasz and Plappert, Matthias and Tworek, Jerry and Hilton, Jacob and Nakano, Reiichiro and Hesse, Christopher and Schulman, John},
+    journal={arXiv preprint arXiv:2110.14168},
+    year={2021}
+}""",
+            additional_info={
+                "languages": [
+                    "Bengali",
+                    "German",
+                    "English",
+                    "Spanish",
+                    "French",
+                    "Japanese",
+                    "Russian",
+                    "Swahili",
+                    "Telugu",
+                    "Thai",
+                    "Chinese",
+                ],
+                "size": "250 problems translated into each language",
+                "base_dataset": "GSM8K (Grade School Math 8K)",
+            },
+        )
+
+    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
+        """Returns the recommended evaluation metrics for MGSM dataset."""
+        return [
+            EvaluationMetric.create(
+                name="exact_match",
+                type="string",
+                description="Exact match comparison between predicted and correct numerical answers",
+                implementation="custom_exact_match",
+                primary=True,
+            ),
+            EvaluationMetric.create(
+                name="solution_validity",
+                type="text",
+                description="Assessment of whether the solution steps are mathematically valid and complete",
+                implementation="custom_solution_validator",
+                primary=True,
+            ),
+            EvaluationMetric.create(
+                name="step_accuracy",
+                type="numerical",
+                description="Accuracy of intermediate calculation steps (e.g., <<48/2=24>>)",
+                implementation="custom_step_accuracy",
+                primary=True,
+            ),
+            EvaluationMetric.create(
+                name="cross_lingual_consistency",
+                type="comparison",
+                description="Consistency of model performance across different language versions of the same problem",
+                implementation="custom_language_comparator",
+                primary=False,
+            ),
+        ]
+
 
 if __name__ == "__main__":
     from pprint import pprint
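Note: the EvaluationMetric entries above only register implementation names (custom_exact_match, custom_solution_validator, custom_step_accuracy, custom_language_comparator); the implementations themselves are not part of this commit. As a rough sketch of two of them, assuming answers are plain numbers and solutions carry GSM8K-style <<48/2=24>> calculator annotations as the step_accuracy description suggests (names and signatures here are illustrative, not the repo's actual API):

import re


def custom_exact_match(predicted: str, expected: str) -> bool:
    """Exact match on numerical answers, ignoring whitespace and thousands separators."""

    def normalize(text: str) -> str:
        return text.strip().replace(",", "")

    return normalize(predicted) == normalize(expected)


def custom_step_accuracy(solution: str) -> float:
    """Fraction of <<expr=result>> annotations whose arithmetic checks out."""
    steps = re.findall(r"<<([^=>]+)=([^>]+)>>", solution)
    if not steps:
        return 0.0
    correct = 0
    for expr, result in steps:
        try:
            # eval() is tolerable for trusted benchmark annotations;
            # never use it on raw model output.
            if abs(eval(expr) - float(result)) < 1e-6:
                correct += 1
        except Exception:
            pass  # malformed step counts as incorrect
    return correct / len(steps)

For example, custom_step_accuracy("She keeps <<48/2=24>> clips.") returns 1.0, while a solution whose annotation claims <<48/2=25>> scores 0.0.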
tests/test_mgsm_parser.py
CHANGED
@@ -182,3 +182,89 @@ def test_system_prompt_override(mgsm_parser):
 
     entry = parser.process_entry(test_entry, task_name="en")
     assert custom_prompt in entry.prompt
+
+
+def test_get_dataset_description(mgsm_parser):
+    """Test dataset description generation."""
+    description = mgsm_parser.get_dataset_description()
+
+    assert description.name == "Multilingual Grade School Math (MGSM)"
+    assert "multilingual chain-of-thought reasoning" in description.purpose.lower()
+    assert "juletxara/mgsm" in description.source
+    assert description.language == "Multilingual (11 languages)"
+    assert "word problems" in description.format.lower()
+    assert "numerical answers" in description.format.lower()
+    assert "solution steps" in description.format.lower()
+
+    # Check characteristics
+    assert "250" in description.characteristics
+    assert "gsm8k" in description.characteristics.lower()
+    assert "translations" in description.characteristics.lower()
+    assert "mathematical reasoning" in description.characteristics.lower()
+
+    # Check citations
+    assert "shi2022language" in description.citation
+    assert "cobbe2021gsm8k" in description.citation
+    assert (
+        "Language Models are Multilingual Chain-of-Thought Reasoners"
+        in description.citation
+    )
+    assert "Training Verifiers to Solve Math Word Problems" in description.citation
+
+    # Check additional info
+    assert description.additional_info is not None
+    assert len(description.additional_info["languages"]) == 11
+    assert "English" in description.additional_info["languages"]
+    assert "Chinese" in description.additional_info["languages"]
+    assert (
+        description.additional_info["size"]
+        == "250 problems translated into each language"
+    )
+    assert description.additional_info["base_dataset"] == "GSM8K (Grade School Math 8K)"
+
+
+def test_get_evaluation_metrics(mgsm_parser):
+    """Test evaluation metrics generation."""
+    metrics = mgsm_parser.get_evaluation_metrics()
+
+    # Check total number of metrics
+    assert len(metrics) == 4
+
+    # Check primary metrics
+    primary_metrics = [m for m in metrics if m.primary]
+    assert len(primary_metrics) == 3
+
+    # Verify specific metrics exist with correct properties
+    metric_names = {m.name for m in metrics}
+    assert "exact_match" in metric_names
+    assert "solution_validity" in metric_names
+    assert "step_accuracy" in metric_names
+    assert "cross_lingual_consistency" in metric_names
+
+    # Check specific metric properties
+    exact_match_metric = next(m for m in metrics if m.name == "exact_match")
+    assert exact_match_metric.type == "string"
+    assert exact_match_metric.primary is True
+    assert "numerical answers" in exact_match_metric.description.lower()
+    assert "custom_exact_match" in exact_match_metric.implementation
+
+    solution_metric = next(m for m in metrics if m.name == "solution_validity")
+    assert solution_metric.type == "text"
+    assert solution_metric.primary is True
+    assert "mathematically valid" in solution_metric.description.lower()
+    assert "custom_solution_validator" in solution_metric.implementation
+
+    step_metric = next(m for m in metrics if m.name == "step_accuracy")
+    assert step_metric.type == "numerical"
+    assert step_metric.primary is True
+    assert "calculation steps" in step_metric.description.lower()
+    assert "custom_step_accuracy" in step_metric.implementation
+
+    # Check cross-lingual metric specifically
+    cross_lingual_metric = next(
+        m for m in metrics if m.name == "cross_lingual_consistency"
+    )
+    assert cross_lingual_metric.type == "comparison"
+    assert cross_lingual_metric.primary is False
+    assert "different language versions" in cross_lingual_metric.description.lower()
+    assert "custom_language_comparator" in cross_lingual_metric.implementation
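Note: for a quick manual check of the two new accessors, mirroring the module's __main__ block, something like the following works. It assumes MGSMDatasetParser can be constructed with no arguments, as the tests' mgsm_parser fixture (defined earlier in the test file, outside this diff) appears to do; that assumption may not match the fixture's actual setup.

from pprint import pprint

from llmdataparser.mgsm_parser import MGSMDatasetParser

# Assumed no-arg construction; adjust if the parser requires configuration.
parser = MGSMDatasetParser()
pprint(parser.get_dataset_description())
for metric in parser.get_evaluation_metrics():
    print(f"{metric.name} (primary={metric.primary}): {metric.implementation}")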