JeffYang52415 commited on
Commit
952a3b5
·
unverified ·
1 Parent(s): b65e855

feat: add mgsm parser

Browse files
llmdataparser/mgsm_parser.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, ClassVar
3
+
4
+ from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
5
+ from llmdataparser.prompts import MGSM_SYSTEM_PROMPT
6
+
7
+
8
+ @dataclass(frozen=True, kw_only=True, slots=True)
9
+ class MGSMParseEntry(HuggingFaceParseEntry):
10
+ """Custom entry class for MGSM, with fields specific to this dataset parser."""
11
+
12
+ numerical_answer: int | float
13
+ equation_solution: str | None
14
+ language: str
15
+
16
+ @classmethod
17
+ def create(
18
+ cls,
19
+ prompt: str,
20
+ answer: str,
21
+ raw_question: str,
22
+ raw_answer: str,
23
+ numerical_answer: int | float,
24
+ equation_solution: str | None,
25
+ task_name: str,
26
+ language: str,
27
+ ) -> "MGSMParseEntry":
28
+ return cls(
29
+ prompt=prompt,
30
+ answer=answer,
31
+ raw_question=raw_question,
32
+ raw_answer=raw_answer,
33
+ numerical_answer=numerical_answer,
34
+ equation_solution=equation_solution,
35
+ task_name=task_name,
36
+ language=language,
37
+ )
38
+
39
+
40
+ class MGSMDatasetParser(HuggingFaceDatasetParser[MGSMParseEntry]):
41
+ """Parser for the MGSM (Multilingual Grade School Math) dataset."""
42
+
43
+ _data_source: ClassVar[str] = "juletxara/mgsm"
44
+ _default_task: ClassVar[str] = "en"
45
+ _task_names: ClassVar[list[str]] = [
46
+ "bn",
47
+ "de",
48
+ "en",
49
+ "es",
50
+ "fr",
51
+ "ja",
52
+ "ru",
53
+ "sw",
54
+ "te",
55
+ "th",
56
+ "zh",
57
+ ]
58
+ _default_system_prompt: ClassVar[str] = MGSM_SYSTEM_PROMPT
59
+
60
+ def process_entry(
61
+ self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
62
+ ) -> MGSMParseEntry:
63
+ """
64
+ Process a single MGSM entry.
65
+
66
+ Args:
67
+ row: Dictionary containing the MGSM entry fields
68
+ task_name: Language code for the current task
69
+
70
+ Returns:
71
+ MGSMParseEntry: Processed entry with prompt, answer, and metadata
72
+ """
73
+ task = task_name or self._get_current_task(row)
74
+ raw_question = row["question"]
75
+ raw_answer = row["answer"] if row["answer"] else ""
76
+ numerical_answer = row["answer_number"]
77
+ equation_solution = row["equation_solution"]
78
+
79
+ # Construct the prompt with the system prompt and question
80
+ prompt = f"{self._system_prompt}\n{raw_question}"
81
+
82
+ # Use numerical answer as string for the answer field if no detailed answer is provided
83
+ answer = raw_answer if raw_answer else str(numerical_answer)
84
+
85
+ return MGSMParseEntry.create(
86
+ prompt=prompt,
87
+ answer=answer,
88
+ raw_question=raw_question,
89
+ raw_answer=raw_answer,
90
+ numerical_answer=numerical_answer,
91
+ equation_solution=equation_solution,
92
+ task_name=task,
93
+ language=task,
94
+ )
95
+
96
+
97
+ if __name__ == "__main__":
98
+ from pprint import pprint
99
+
100
+ parser = MGSMDatasetParser()
101
+ parser.load(task_name="en") # Load French dataset
102
+ parser.parse()
103
+
104
+ parsed_data = parser.get_parsed_data
105
+ pprint(parsed_data[0].prompt)
106
+ pprint(parsed_data[0].answer)
107
+ pprint(parsed_data[0].raw_question)
108
+ pprint(parsed_data[0].numerical_answer)
109
+ pprint(parsed_data[0].language)
tests/test_mgsm_parser.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from llmdataparser.mgsm_parser import MGSMDatasetParser, MGSMParseEntry
4
+
5
+
6
+ @pytest.fixture
7
+ def mgsm_parser():
8
+ """Create a MGSM parser instance for testing."""
9
+ return MGSMDatasetParser()
10
+
11
+
12
+ @pytest.fixture
13
+ def loaded_mgsm_parser(mgsm_parser):
14
+ """Create and load a MGSM parser instance with test split."""
15
+ mgsm_parser.load(task_name="en", split="test")
16
+ return mgsm_parser
17
+
18
+
19
+ @pytest.fixture
20
+ def sample_mgsm_entries():
21
+ """Create sample MGSM dataset entries for testing."""
22
+ return [
23
+ {
24
+ "question": "John has 5 apples and buys 3 more. How many apples does he have now?",
25
+ "answer": "Let's solve step by step:\n1) Initial apples = 5\n2) Bought apples = 3\n3) Total = 5 + 3 = 8\nJohn has 8 apples now.",
26
+ "answer_number": 8,
27
+ "equation_solution": "5 + 3 = 8",
28
+ "language": "en",
29
+ },
30
+ {
31
+ "question": "Juan tiene 5 manzanas y compra 3 más. ¿Cuántas manzanas tiene ahora?",
32
+ "answer": "Resolvamos paso a paso:\n1) Manzanas iniciales = 5\n2) Manzanas compradas = 3\n3) Total = 5 + 3 = 8\nJuan tiene 8 manzanas ahora.",
33
+ "answer_number": 8,
34
+ "equation_solution": "5 + 3 = 8",
35
+ "language": "es",
36
+ },
37
+ {
38
+ "question": "ジョンはリンゴを5個持っていて、さらに3個買います。今何個持っていますか?",
39
+ "answer": None, # Testing case with missing detailed answer
40
+ "answer_number": 8,
41
+ "equation_solution": "5 + 3 = 8",
42
+ "language": "ja",
43
+ },
44
+ ]
45
+
46
+
47
+ def test_mgsm_parse_entry_creation_valid():
48
+ """Test valid creation of MGSMParseEntry with all fields."""
49
+ entry = MGSMParseEntry.create(
50
+ prompt="Test prompt",
51
+ answer="Test answer",
52
+ raw_question="Test question",
53
+ raw_answer="Test answer",
54
+ numerical_answer=42,
55
+ equation_solution="21 * 2 = 42",
56
+ task_name="en",
57
+ language="en",
58
+ )
59
+
60
+ assert isinstance(entry, MGSMParseEntry)
61
+ assert entry.prompt == "Test prompt"
62
+ assert entry.answer == "Test answer"
63
+ assert entry.raw_question == "Test question"
64
+ assert entry.raw_answer == "Test answer"
65
+ assert entry.numerical_answer == 42
66
+ assert entry.equation_solution == "21 * 2 = 42"
67
+ assert entry.task_name == "en"
68
+ assert entry.language == "en"
69
+
70
+
71
+ def test_process_entry_with_detailed_answer(mgsm_parser, sample_mgsm_entries):
72
+ """Test processing entry with detailed answer in English."""
73
+ entry = mgsm_parser.process_entry(sample_mgsm_entries[0], task_name="en")
74
+
75
+ assert isinstance(entry, MGSMParseEntry)
76
+ assert entry.numerical_answer == 8
77
+ assert entry.equation_solution == "5 + 3 = 8"
78
+ assert "step by step" in entry.answer
79
+ assert entry.language == "en"
80
+ assert entry.task_name == "en"
81
+
82
+
83
+ def test_process_entry_without_detailed_answer(mgsm_parser, sample_mgsm_entries):
84
+ """Test processing entry without detailed answer (Japanese)."""
85
+ entry = mgsm_parser.process_entry(sample_mgsm_entries[2], task_name="ja")
86
+
87
+ assert isinstance(entry, MGSMParseEntry)
88
+ assert entry.numerical_answer == 8
89
+ assert entry.equation_solution == "5 + 3 = 8"
90
+ assert entry.answer == "8" # Should use numerical_answer as string
91
+ assert entry.language == "ja"
92
+ assert entry.task_name == "ja"
93
+
94
+
95
+ def test_process_entry_spanish(mgsm_parser, sample_mgsm_entries):
96
+ """Test processing Spanish entry."""
97
+ entry = mgsm_parser.process_entry(sample_mgsm_entries[1], task_name="es")
98
+
99
+ assert isinstance(entry, MGSMParseEntry)
100
+ assert entry.numerical_answer == 8
101
+ assert entry.equation_solution == "5 + 3 = 8"
102
+ assert "paso a paso" in entry.answer # Spanish for "step by step"
103
+ assert entry.language == "es"
104
+ assert entry.task_name == "es"
105
+
106
+
107
+ def test_mgsm_parser_initialization(mgsm_parser):
108
+ """Test MGSM parser initialization and properties."""
109
+ assert isinstance(mgsm_parser.task_names, list)
110
+ assert len(mgsm_parser.task_names) == 11 # 11 supported languages
111
+ assert mgsm_parser._data_source == "juletxara/mgsm"
112
+ assert mgsm_parser._default_task == "en"
113
+ assert all(lang in mgsm_parser.task_names for lang in ["en", "es", "ja", "zh"])
114
+ assert (
115
+ mgsm_parser.get_huggingface_link
116
+ == "https://huggingface.co/datasets/juletxara/mgsm"
117
+ )
118
+
119
+
120
+ @pytest.mark.integration
121
+ def test_load_dataset(loaded_mgsm_parser):
122
+ """Test loading the MGSM dataset."""
123
+ assert loaded_mgsm_parser.raw_data is not None
124
+ assert loaded_mgsm_parser.split_names == ["test"]
125
+ assert loaded_mgsm_parser._current_task == "en"
126
+
127
+
128
+ def test_parser_string_representation(loaded_mgsm_parser):
129
+ """Test string representation of MGSM parser."""
130
+ repr_str = str(loaded_mgsm_parser)
131
+ assert "MGSMDatasetParser" in repr_str
132
+ assert "juletxara/mgsm" in repr_str
133
+ assert "en" in repr_str
134
+ assert "loaded" in repr_str
135
+
136
+
137
+ @pytest.mark.integration
138
+ def test_different_languages_parsing(mgsm_parser):
139
+ """Test parsing different language versions."""
140
+ # Load and parse English
141
+ mgsm_parser.load(task_name="en", split="test")
142
+ mgsm_parser.parse(split_names="test", force=True)
143
+ en_count = len(mgsm_parser.get_parsed_data)
144
+
145
+ # Load and parse Spanish
146
+ mgsm_parser.load(task_name="es", split="test")
147
+ mgsm_parser.parse(split_names="test", force=True)
148
+ es_count = len(mgsm_parser.get_parsed_data)
149
+
150
+ assert en_count > 0
151
+ assert es_count > 0
152
+ assert en_count == es_count # Should have same number of problems in each language
153
+
154
+
155
+ @pytest.mark.parametrize("language", ["en", "es", "ja", "zh", "ru"])
156
+ def test_supported_languages(mgsm_parser, language):
157
+ """Test that each supported language can be processed."""
158
+ test_entry = {
159
+ "question": f"Test question in {language}",
160
+ "answer": f"Test answer in {language}",
161
+ "answer_number": 42,
162
+ "equation_solution": "21 * 2 = 42",
163
+ }
164
+
165
+ entry = mgsm_parser.process_entry(test_entry, task_name=language)
166
+ assert entry.language == language
167
+ assert entry.task_name == language
168
+ assert entry.numerical_answer == 42
169
+
170
+
171
+ def test_system_prompt_override(mgsm_parser):
172
+ """Test overriding the default system prompt."""
173
+ custom_prompt = "Custom system prompt for testing"
174
+ parser = MGSMDatasetParser(system_prompt=custom_prompt)
175
+
176
+ test_entry = {
177
+ "question": "Test question",
178
+ "answer": "Test answer",
179
+ "answer_number": 42,
180
+ "equation_solution": "42",
181
+ }
182
+
183
+ entry = parser.process_entry(test_entry, task_name="en")
184
+ assert custom_prompt in entry.prompt