JeffYang52415 committed
Commit 3aaa6f0 · unverified · 1 Parent(s): 0772011

refactor: mgsm parser

llmdataparser/mgsm_parser.py CHANGED
@@ -1,7 +1,12 @@
 from dataclasses import dataclass
 from typing import Any, ClassVar

-from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
+from llmdataparser.base_parser import (
+    DatasetDescription,
+    EvaluationMetric,
+    HuggingFaceDatasetParser,
+    HuggingFaceParseEntry,
+)
 from llmdataparser.prompts import MGSM_SYSTEM_PROMPT


@@ -93,6 +98,86 @@ class MGSMDatasetParser(HuggingFaceDatasetParser[MGSMParseEntry]):
             language=task,
         )

+    def get_dataset_description(self) -> DatasetDescription:
+        """Returns a description of the Multilingual Grade School Math dataset."""
+        return DatasetDescription.create(
+            name="Multilingual Grade School Math (MGSM)",
+            purpose="Evaluate multilingual chain-of-thought reasoning capabilities in mathematical problem solving",
+            source="https://huggingface.co/datasets/juletxara/mgsm",
+            language="Multilingual (11 languages)",
+            format="Word problems with numerical answers and solution steps",
+            characteristics=(
+                "Human-translated version of 250 GSM8K problems into 10 additional languages. "
+                "Each problem includes the original question from GSM8K, its translations, "
+                "numerical answer, and solution steps. The benchmark is designed to evaluate "
+                "language models' ability to perform mathematical reasoning across different languages."
+            ),
+            citation="""@misc{shi2022language,
+    title={Language Models are Multilingual Chain-of-Thought Reasoners},
+    author={Freda Shi and Mirac Suzgun and Markus Freitag and Xuezhi Wang and Suraj Srivats and Soroush Vosoughi and Hyung Won Chung and Yi Tay and Sebastian Ruder and Denny Zhou and Dipanjan Das and Jason Wei},
+    year={2022},
+    eprint={2210.03057},
+    archivePrefix={arXiv},
+    primaryClass={cs.CL}
+}
+@article{cobbe2021gsm8k,
+    title={Training Verifiers to Solve Math Word Problems},
+    author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Chen, Mark and Jun, Heewoo and Kaiser, Lukasz and Plappert, Matthias and Tworek, Jerry and Hilton, Jacob and Nakano, Reiichiro and Hesse, Christopher and Schulman, John},
+    journal={arXiv preprint arXiv:2110.14168},
+    year={2021}
+}""",
+            additional_info={
+                "languages": [
+                    "Bengali",
+                    "German",
+                    "English",
+                    "Spanish",
+                    "French",
+                    "Japanese",
+                    "Russian",
+                    "Swahili",
+                    "Telugu",
+                    "Thai",
+                    "Chinese",
+                ],
+                "size": "250 problems translated into each language",
+                "base_dataset": "GSM8K (Grade School Math 8K)",
+            },
+        )
+
+    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
+        """Returns the recommended evaluation metrics for MGSM dataset."""
+        return [
+            EvaluationMetric.create(
+                name="exact_match",
+                type="string",
+                description="Exact match comparison between predicted and correct numerical answers",
+                implementation="custom_exact_match",
+                primary=True,
+            ),
+            EvaluationMetric.create(
+                name="solution_validity",
+                type="text",
+                description="Assessment of whether the solution steps are mathematically valid and complete",
+                implementation="custom_solution_validator",
+                primary=True,
+            ),
+            EvaluationMetric.create(
+                name="step_accuracy",
+                type="numerical",
+                description="Accuracy of intermediate calculation steps (e.g., <<48/2=24>>)",
+                implementation="custom_step_accuracy",
+                primary=True,
+            ),
+            EvaluationMetric.create(
+                name="cross_lingual_consistency",
+                type="comparison",
+                description="Consistency of model performance across different language versions of the same problem",
+                implementation="custom_language_comparator",
+                primary=False,
+            ),
+        ]
+

 if __name__ == "__main__":
     from pprint import pprint
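
A minimal usage sketch (not part of this commit), assuming MGSMDatasetParser constructs with its defaults, in the spirit of the pprint-based __main__ demo at the bottom of the module:

# Usage sketch only; assumes MGSMDatasetParser() needs no required arguments.
from pprint import pprint

from llmdataparser.mgsm_parser import MGSMDatasetParser

parser = MGSMDatasetParser()

# The two methods added in this commit return structured metadata objects.
pprint(parser.get_dataset_description())
pprint(parser.get_evaluation_metrics())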
tests/test_mgsm_parser.py CHANGED
@@ -182,3 +182,89 @@ def test_system_prompt_override(mgsm_parser):

     entry = parser.process_entry(test_entry, task_name="en")
     assert custom_prompt in entry.prompt
+
+
+def test_get_dataset_description(mgsm_parser):
+    """Test dataset description generation."""
+    description = mgsm_parser.get_dataset_description()
+
+    assert description.name == "Multilingual Grade School Math (MGSM)"
+    assert "multilingual chain-of-thought reasoning" in description.purpose.lower()
+    assert "juletxara/mgsm" in description.source
+    assert description.language == "Multilingual (11 languages)"
+    assert "word problems" in description.format.lower()
+    assert "numerical answers" in description.format.lower()
+    assert "solution steps" in description.format.lower()
+
+    # Check characteristics
+    assert "250" in description.characteristics
+    assert "gsm8k" in description.characteristics.lower()
+    assert "translations" in description.characteristics.lower()
+    assert "mathematical reasoning" in description.characteristics.lower()
+
+    # Check citations
+    assert "shi2022language" in description.citation
+    assert "cobbe2021gsm8k" in description.citation
+    assert (
+        "Language Models are Multilingual Chain-of-Thought Reasoners"
+        in description.citation
+    )
+    assert "Training Verifiers to Solve Math Word Problems" in description.citation
+
+    # Check additional info
+    assert description.additional_info is not None
+    assert len(description.additional_info["languages"]) == 11
+    assert "English" in description.additional_info["languages"]
+    assert "Chinese" in description.additional_info["languages"]
+    assert (
+        description.additional_info["size"]
+        == "250 problems translated into each language"
+    )
+    assert description.additional_info["base_dataset"] == "GSM8K (Grade School Math 8K)"
+
+
+def test_get_evaluation_metrics(mgsm_parser):
+    """Test evaluation metrics generation."""
+    metrics = mgsm_parser.get_evaluation_metrics()
+
+    # Check total number of metrics
+    assert len(metrics) == 4
+
+    # Check primary metrics
+    primary_metrics = [m for m in metrics if m.primary]
+    assert len(primary_metrics) == 3
+
+    # Verify specific metrics exist with correct properties
+    metric_names = {m.name for m in metrics}
+    assert "exact_match" in metric_names
+    assert "solution_validity" in metric_names
+    assert "step_accuracy" in metric_names
+    assert "cross_lingual_consistency" in metric_names
+
+    # Check specific metric properties
+    exact_match_metric = next(m for m in metrics if m.name == "exact_match")
+    assert exact_match_metric.type == "string"
+    assert exact_match_metric.primary is True
+    assert "numerical answers" in exact_match_metric.description.lower()
+    assert "custom_exact_match" in exact_match_metric.implementation
+
+    solution_metric = next(m for m in metrics if m.name == "solution_validity")
+    assert solution_metric.type == "text"
+    assert solution_metric.primary is True
+    assert "mathematically valid" in solution_metric.description.lower()
+    assert "custom_solution_validator" in solution_metric.implementation
+
+    step_metric = next(m for m in metrics if m.name == "step_accuracy")
+    assert step_metric.type == "numerical"
+    assert step_metric.primary is True
+    assert "calculation steps" in step_metric.description.lower()
+    assert "custom_step_accuracy" in step_metric.implementation
+
+    # Check cross-lingual metric specifically
+    cross_lingual_metric = next(
+        m for m in metrics if m.name == "cross_lingual_consistency"
+    )
+    assert cross_lingual_metric.type == "comparison"
+    assert cross_lingual_metric.primary is False
+    assert "different language versions" in cross_lingual_metric.description.lower()
+    assert "custom_language_comparator" in cross_lingual_metric.implementation