yuchenlin committed
Commit 1a0cf07
1 Parent(s): 55ce5e3

Upload 14 files
README.md CHANGED
@@ -1,13 +1,9 @@
- ---
- title: SwiftSage
- emoji: 🔥
- colorFrom: purple
- colorTo: pink
- sdk: gradio
- sdk_version: 4.44.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ## 🤖 SwiftSage (v2)
+
+ > [!IMPORTANT]
+ > The code of SwiftSage v1 (for the experiments in NeurIPS 2023) is archived in the [`science_world`](https://github.com/SwiftSage/SwiftSage/tree/science_world) branch.
+
+ <!-- GitHub README important-callout note -->
app.py ADDED
@@ -0,0 +1,97 @@
+ import gradio as gr
+ import os
+ import json
+ import logging
+ import numpy as np
+ from utils import (PromptTemplate, api_configs, setup_logging)
+ from data_loader import load_data
+ from evaluate import evaluate
+ from main import SwiftSage, run_test, run_benchmark
+ import multiprocessing
+
+
+ def solve_problem(problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, reward_model_id, use_retrieval, start_with_sage):
+     # Configuration for each LLM
+     max_iterations = int(max_iterations)
+     reward_threshold = int(reward_threshold)
+
+     swift_config = {
+         "model_id": swift_model_id,
+         "api_config": api_configs['Together']
+     }
+
+     reward_config = {
+         "model_id": reward_model_id,
+         "api_config": api_configs['Together']
+     }
+
+     sage_config = {
+         "model_id": sage_model_id,
+         "api_config": api_configs['Together']
+     }
+
+     # specify the path to the prompt templates
+     prompt_template_dir = './prompt_templates'
+     dataset = []
+     embeddings = []  # TODO: for retrieval augmentation (not implemented yet)
+     s2 = SwiftSage(
+         dataset,
+         embeddings,
+         prompt_template_dir,
+         swift_config,
+         sage_config,
+         reward_config,
+         use_retrieval=use_retrieval,
+         start_with_sage=start_with_sage,
+     )
+
+     reasoning, solution = s2.solve(problem, max_iterations, reward_threshold)
+     solution = solution.replace("Answer (from running the code):\n ", " ")
+     return reasoning, solution
+
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     # gr.Markdown("## SwiftSage: A Multi-Agent Framework for Reasoning")
+     # use HTML to center the title
+     gr.HTML("<h1 style='text-align: center;'>SwiftSage: A Multi-Agent Framework for Reasoning</h1>")
+
+     with gr.Row():
+         swift_model_id = gr.Textbox(label="😄 Swift Model ID", value="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo")
+         reward_model_id = gr.Textbox(label="🤔 Feedback Model ID", value="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo")
+         sage_model_id = gr.Textbox(label="😎 Sage Model ID", value="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo")
+         # the following two should have a smaller width
+
+     with gr.Accordion(label="⚙️ Advanced Options", open=False):
+         with gr.Row():
+             with gr.Column():
+                 max_iterations = gr.Textbox(label="Max Iterations", value="5")
+                 reward_threshold = gr.Textbox(label="Reward Threshold", value="8")
+                 # TODO: add top-p and temperature controls for each module
+             with gr.Column():
+                 top_p_swift = gr.Textbox(label="Top-p for Swift", value="0.9")
+                 temperature_swift = gr.Textbox(label="Temperature for Swift", value="0.7")
+             with gr.Column():
+                 top_p_sage = gr.Textbox(label="Top-p for Sage", value="0.9")
+                 temperature_sage = gr.Textbox(label="Temperature for Sage", value="0.7")
+             with gr.Column():
+                 top_p_reward = gr.Textbox(label="Top-p for Feedback", value="0.9")
+                 temperature_reward = gr.Textbox(label="Temperature for Feedback", value="0.7")
+
+     use_retrieval = gr.Checkbox(label="Use Retrieval Augmentation", value=False, visible=False)
+     start_with_sage = gr.Checkbox(label="Start with Sage", value=False, visible=False)
+
+     problem = gr.Textbox(label="Input your problem", value="How many letter r are there in the sentence 'My strawberry is so ridiculously red.'?", lines=2)
+
+     solve_button = gr.Button("🚀 Solve Problem")
+     reasoning_output = gr.Textbox(label="Reasoning steps with code", interactive=False)
+     solution_output = gr.Textbox(label="Final answer", interactive=False)
+
+     solve_button.click(
+         solve_problem,
+         inputs=[problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, reward_model_id, use_retrieval, start_with_sage],
+         outputs=[reasoning_output, solution_output]
+     )
+
+ if __name__ == '__main__':
+     multiprocessing.set_start_method('spawn')
+     demo.launch(share=False, show_api=False)
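For reference, `solve_problem` can also be called directly, bypassing the Gradio UI. A minimal sketch (illustrative only; it assumes valid Together API credentials are configured in `utils.api_configs`):

```python
# Direct call into the app's solve_problem helper, without launching the UI.
from app import solve_problem

reasoning, answer = solve_problem(
    problem="Solve the equation: 2x + 5 = 13",
    max_iterations="5",   # the UI passes strings; solve_problem casts them to int
    reward_threshold="8",
    swift_model_id="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
    sage_model_id="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    reward_model_id="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    use_retrieval=False,
    start_with_sage=False,
)
print(answer)
```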
code_executor.py ADDED
@@ -0,0 +1,111 @@
+ """
+ Source and credits: https://github.com/ZubinGou/math-evaluation-harness/blob/main/python_executor.py
+
+ We modified it to be simpler.
+ """
+
+ import io
+ import pickle
+ import traceback
+ from concurrent.futures import ProcessPoolExecutor, TimeoutError
+ from contextlib import redirect_stdout
+ from typing import Any
+
+
+ class GenericRuntime:
+     GLOBAL_DICT = {}
+     LOCAL_DICT = None
+     HEADERS = []
+
+     def __init__(self):
+         self._global_vars = self.GLOBAL_DICT.copy()
+         self._local_vars = self.LOCAL_DICT.copy() if self.LOCAL_DICT else None
+
+         for c in self.HEADERS:
+             self.exec_code(c)
+
+     def exec_code(self, code_piece: str) -> None:
+         exec(code_piece, self._global_vars)
+
+     def eval_code(self, expr: str) -> Any:
+         return eval(expr, self._global_vars)
+
+     def inject(self, var_dict):
+         self._global_vars.update(var_dict)
+
+     @property
+     def answer(self):
+         return self._global_vars['answer']
+
+
+ class PythonExecutor:
+     def __init__(
+         self,
+         runtime=None,
+         get_answer_symbol=None,
+         get_answer_expr=None,
+         get_answer_from_stdout=False,
+         timeout_length=5,
+     ):
+         self.runtime = runtime if runtime else GenericRuntime()
+         self.answer_symbol = get_answer_symbol
+         self.get_answer_expr = get_answer_expr
+         self.get_answer_from_stdout = get_answer_from_stdout
+         self.timeout_length = timeout_length
+
+     def execute(self, code):
+         try:
+             if self.get_answer_from_stdout:
+                 program_io = io.StringIO()
+                 with redirect_stdout(program_io):
+                     self.runtime.exec_code('\n'.join(code))
+                 program_io.seek(0)
+                 result = program_io.read()
+             elif self.answer_symbol:
+                 self.runtime.exec_code('\n'.join(code))
+                 result = self.runtime._global_vars[self.answer_symbol]
+             elif self.get_answer_expr:
+                 self.runtime.exec_code('\n'.join(code))
+                 result = self.runtime.eval_code(self.get_answer_expr)
+             else:
+                 self.runtime.exec_code('\n'.join(code[:-1]))
+                 result = self.runtime.eval_code(code[-1])
+
+             report = "Done"
+             pickle.dumps(result)  # serialization check
+         except Exception as e:
+             result = ''
+             report = str(e)
+
+         return result, report
+
+     def apply(self, code):
+         code_snippet = code.split('\n')
+
+         # Use ProcessPoolExecutor to enforce a timeout
+         with ProcessPoolExecutor() as executor:
+             future = executor.submit(self.execute, code_snippet)
+             try:
+                 result, report = future.result(timeout=self.timeout_length)
+             except TimeoutError:
+                 result, report = "", "Timeout Error"
+
+         # result may not be a string (e.g., in answer-symbol mode), so coerce it
+         return str(result).strip(), report.strip()
+
+
+ # Example usage
+ if __name__ == "__main__":
+     executor = PythonExecutor(get_answer_from_stdout=True)
+     code = """
+ from sympy import Matrix
+
+ def null_space_basis():
+     A = Matrix([[3, 3, -1, -6], [9, -1, -8, -1], [7, 4, -2, -9]])
+     basis = A.nullspace()
+     return [v.evalf(3) for v in basis]
+
+ result = null_space_basis()
+ print(result)
+ """
+     result, report = executor.apply(code)
+     print("Result:", result)
+     print("Report:", report)
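Besides the stdout mode used by SwiftSage, `execute` above also supports reading the result from a named variable. A quick sketch of that mode:

```python
# Sketch of the get_answer_symbol mode: the result is read from a variable
# in the executed code's globals rather than from captured stdout.
from code_executor import PythonExecutor

executor = PythonExecutor(get_answer_symbol="answer")
result, report = executor.apply("answer = 2 + 3")
print(result, report)  # expected: 5 Done
```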
data_loader.py ADDED
@@ -0,0 +1,78 @@
+ import json
+ import os
+ import re
+ import random
+ from typing import Any, Iterable, Union
+
+ from datasets import Dataset, concatenate_datasets, load_dataset
+
+ from data_utils import (
+     lower_keys,
+     parse_question,
+     parse_ground_truth,
+ )
+
+
+ def load_jsonl(file):
+     with open(file, "r", encoding="utf-8") as f:
+         for line in f:
+             try:
+                 yield json.loads(line)
+             except json.JSONDecodeError:
+                 print("Error in loading:", line)
+                 exit()
+
+
+ def load_data(
+     data_name,
+     split='test',
+     data_dir='./data',
+     num_test_sample=-1,
+ ):
+     if data_name.lower() == "math":
+         data_name = 'MATH'  # we use the 500-problem test split from "Let's Verify Step by Step"
+     data_file = f"{data_dir}/{data_name}/{split}.jsonl"
+     if os.path.exists(data_file):
+         examples = list(load_jsonl(data_file))
+     else:
+         if data_name == "mmlu_stem":
+             dataset = load_dataset("hails/mmlu_no_train", 'all', split='test')
+             # only keep STEM subjects
+             stem_subjects = ['abstract_algebra', 'astronomy', 'college_biology', 'college_chemistry',
+                              'college_computer_science', 'college_mathematics', 'college_physics', 'computer_security',
+                              'conceptual_physics', 'electrical_engineering', 'elementary_mathematics', 'high_school_biology',
+                              'high_school_chemistry', 'high_school_computer_science', 'high_school_mathematics',
+                              'high_school_physics', 'high_school_statistics', 'machine_learning']
+             dataset = dataset.rename_column("subject", "type")
+             dataset = dataset.filter(lambda x: x['type'] in stem_subjects)
+         elif data_name == "mathvista":
+             raise NotImplementedError(data_name)
+         elif data_name == "gpqa":
+             dataset = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train")
+         elif data_name == "codeforces":
+             raise NotImplementedError(data_name)
+         else:
+             raise NotImplementedError(data_name)
+
+         examples = list(dataset)
+         examples = [lower_keys(example) for example in examples]
+         dataset = Dataset.from_list(examples)
+         os.makedirs(f"{data_dir}/{data_name}", exist_ok=True)
+         dataset.to_json(data_file)
+
+     # add 'idx' as the first column
+     if 'idx' not in examples[0]:
+         examples = [{'idx': i, **example} for i, example in enumerate(examples)]
+
+     # deduplicate & sort
+     examples = sorted(examples, key=lambda x: x['idx'])
+
+     if num_test_sample > 0:
+         examples = examples[:num_test_sample]
+
+     return examples
+
+
+ if __name__ == "__main__":
+     examples = load_data("gpqa", "test")
+     print(len(examples))
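A small usage sketch (illustrative; it assumes `./data/MATH/test.jsonl` already exists locally, which is the cached-file path the loader checks first for MATH):

```python
# Load a few MATH test problems from the local cache.
from data_loader import load_data

examples = load_data("math", split="test", num_test_sample=3)
for ex in examples:
    print(ex["idx"], list(ex.keys())[:5])
```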
data_utils.py ADDED
@@ -0,0 +1,358 @@
+ """
+ Source and credits: https://github.com/ZubinGou/math-evaluation-harness/blob/main/python_executor.py
+ """
+ import re
+ import regex
+ import sympy
+ from typing import TypeVar, Iterable, List, Union, Any, Dict
+ from word2number import w2n
+ from utils import *
+
+
+ def lower_keys(example):
+     new_example = {}
+     for key, value in example.items():
+         if key != key.lower():
+             new_key = key.lower()
+             new_example[new_key] = value
+         else:
+             new_example[key] = value
+     return new_example
+
+
+ def _fix_fracs(string):
+     substrs = string.split("\\frac")
+     new_str = substrs[0]
+     if len(substrs) > 1:
+         substrs = substrs[1:]
+         for substr in substrs:
+             new_str += "\\frac"
+             if len(substr) > 0 and substr[0] == "{":
+                 new_str += substr
+             else:
+                 try:
+                     assert len(substr) >= 2
+                 except AssertionError:
+                     return string
+                 a = substr[0]
+                 b = substr[1]
+                 if b != "{":
+                     if len(substr) > 2:
+                         post_substr = substr[2:]
+                         new_str += "{" + a + "}{" + b + "}" + post_substr
+                     else:
+                         new_str += "{" + a + "}{" + b + "}"
+                 else:
+                     if len(substr) > 2:
+                         post_substr = substr[2:]
+                         new_str += "{" + a + "}" + b + post_substr
+                     else:
+                         new_str += "{" + a + "}" + b
+     string = new_str
+     return string
+
+
+ def _fix_a_slash_b(string):
+     if len(string.split("/")) != 2:
+         return string
+     a = string.split("/")[0]
+     b = string.split("/")[1]
+     try:
+         if "sqrt" not in a:
+             a = int(a)
+         if "sqrt" not in b:
+             b = int(b)
+         assert string == "{}/{}".format(a, b)
+         new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
+         return new_string
+     except Exception:
+         return string
+
+
+ def _fix_sqrt(string):
+     _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
+     return _string
+
+
+ def convert_word_number(text: str) -> str:
+     try:
+         text = str(w2n.word_to_num(text))
+     except Exception:
+         pass
+     return text
+
+
+ # units mainly from MathQA (kept verbatim, including the dataset's own typos)
+ unit_texts = [
+     "east", "degree", "mph", "kmph", "ft", "m sqaure", " m east", "sq m", "deg", "mile",
+     "q .", "monkey", "prime", "ratio", "profit of rs", "rd", "o", "gm",
+     "p . m", "lb", "tile", "per", "dm", "lt", "gain", "ab", "way", "west",
+     "a .", "b .", "c .", "d .", "e .", "f .", "g .", "h .", "t", "a", "h",
+     "no change", "men", "soldier", "pie", "bc", "excess", "st",
+     "inches", "noon", "percent", "by", "gal", "kmh", "c", "acre", "rise",
+     "a . m", "th", "π r 2", "sq", "mark", "l", "toy", "coin",
+     "sq . m", "gallon", "° f", "profit", "minw", "yr", "women",
+     "feet", "am", "pm", "hr", "cu cm", "square", "v â € ™", "are",
+     "rupee", "rounds", "cubic", "cc", "mtr", "s", "ohm", "number",
+     "kmph", "day", "hour", "minute", "min", "second", "man", "woman",
+     "sec", "cube", "mt", "sq inch", "mp", "∏ cm ³", "hectare", "more",
+     "sec", "unit", "cu . m", "cm 2", "rs .", "rs", "kg", "g", "month",
+     "km", "m", "cm", "mm", "apple", "liter", "loss", "yard",
+     "pure", "year", "increase", "decrease", "d", "less", "Surface",
+     "litre", "pi sq m", "s .", "metre", "meter", "inch",
+ ]
+
+ unit_texts.extend([t + "s" for t in unit_texts])
+
+
+ def strip_string(string):
+     string = str(string).strip()
+     # linebreaks
+     string = string.replace("\n", "")
+
+     # right "."
+     string = string.rstrip(".")
+
+     # remove inverse spaces
+     # replace \\ with \
+     string = string.replace("\\!", "")
+     # string = string.replace("\\ ", "")
+     # string = string.replace("\\\\", "\\")
+
+     # matrix
+     string = re.sub(r'\\begin\{array\}\{.*?\}', r'\\begin{pmatrix}', string)
+     string = re.sub(r'\\end\{array\}', r'\\end{pmatrix}', string)
+     string = string.replace("bmatrix", "pmatrix")
+
+     # replace tfrac and dfrac with frac
+     string = string.replace("tfrac", "frac")
+     string = string.replace("dfrac", "frac")
+
+     # remove \left and \right
+     string = string.replace("\\left", "")
+     string = string.replace("\\right", "")
+     string = string.replace("\\{", "{")
+     string = string.replace("\\}", "}")
+
+     # Remove trailing units like miles or dollars, if anything remains afterwards
+     _string = re.sub(r"\\text{.*?}$", "", string).strip()
+     if _string != "" and _string != string:
+         # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
+         string = _string
+
+     # Remove unit texts
+     for _ in range(2):
+         for unit_text in unit_texts:
+             # the prefix should be either the start of the string or a non-alphanumeric character,
+             # and the suffix either the end of the string or a non-alphanumeric character
+             _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
+             if _string != "":
+                 string = _string
+
+     # Remove circ (degrees)
+     string = string.replace("^{\\circ}", "")
+     string = string.replace("^\\circ", "")
+
+     # remove dollar signs
+     string = string.replace("\\$", "")
+     string = string.replace("$", "")
+
+     # convert word numbers to digits
+     string = convert_word_number(string)
+
+     # replace "\\text{...}" with "..."
+     string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
+     for key in ['x=', 'y=', 'z=', 'x\\in', 'y\\in', 'z\\in', 'x\\to', 'y\\to', 'z\\to']:
+         string = string.replace(key, "")
+     string = string.replace("\\emptyset", r"{}")
+     string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")
+
+     # remove percentage
+     string = string.replace("\\%", "")
+     string = string.replace(r"\%", "")
+     string = string.replace("%", "")
+
+     # " 0." is equivalent to " ." and "{0." to "{."; also prepend "0" if "." starts the string
+     string = string.replace(" .", " 0.")
+     string = string.replace("{.", "{0.")
+
+     # cdot
+     # string = string.replace("\\cdot", "")
+     if string.startswith("{") and string.endswith("}") and string.isalnum() or \
+         string.startswith("(") and string.endswith(")") and string.isalnum() or \
+         string.startswith("[") and string.endswith("]") and string.isalnum():
+         string = string[1:-1]
+
+     # inf
+     string = string.replace("infinity", "\\infty")
+     if "\\infty" not in string:
+         string = string.replace("inf", "\\infty")
+     string = string.replace("+\\inity", "\\infty")
+
+     # and
+     string = string.replace("and", "")
+     string = string.replace("\\mathbf", "")
+
+     # use regex to remove \mbox{...}
+     string = re.sub(r"\\mbox{.*?}", "", string)
+
+     # quotes
+     string = string.replace("'", "")
+     string = string.replace("\"", "")
+
+     # i, j
+     if "j" in string and "i" not in string:
+         string = string.replace("j", "i")
+
+     # replace a.000b (where b is not a digit or is the end of the string) with ab
+     string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
+     string = re.sub(r"(\d+)\.0*$", r"\1", string)
+
+     # if empty, return the empty string
+     if len(string) == 0:
+         return string
+     if string[0] == ".":
+         string = "0" + string
+
+     # to consider: get rid of e.g. "k = " or "q = " at the beginning
+     if len(string.split("=")) == 2:
+         if len(string.split("=")[0]) <= 2:
+             string = string.split("=")[1]
+
+     string = _fix_sqrt(string)
+     string = string.replace(" ", "")
+
+     # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc.
+     # Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \frac{a}{b}
+     string = _fix_fracs(string)
+
+     # NOTE: X/Y was changed to \frac{X}{Y} in the dataset, but fix simple cases
+     # in case the model output is X/Y
+     string = _fix_a_slash_b(string)
+
+     return string
+
+
+ def extract_multi_choice_answer(pred_str):
+     # TODO: SFT models
+     if 'Problem:' in pred_str:
+         pred_str = pred_str.split("Problem:", 1)[0]
+     pred_str = pred_str.replace("choice is", "answer is")
+     patt = regex.search(r"answer is \(?(?P<ans>[abcde])\)?", pred_str.lower())
+     if patt is not None:
+         return patt.group('ans').upper()
+     return 'placeholder'
+
+
+ def extract_answer(pred_str, data_name):
+     if data_name in ["mmlu_stem", "sat_math", "mathqa"]:
+         return extract_multi_choice_answer(pred_str)
+
+     if 'final answer is $' in pred_str and '$. I hope' in pred_str:
+         # minerva_math
+         tmp = pred_str.split('final answer is $', 1)[1]
+         pred = tmp.split('$. I hope', 1)[0].strip()
+     elif 'boxed' in pred_str:
+         ans = pred_str.split('boxed')[-1]
+         if len(ans) == 0:
+             return ""
+         elif ans[0] == '{':
+             stack = 1
+             a = ''
+             for c in ans[1:]:
+                 if c == '{':
+                     stack += 1
+                     a += c
+                 elif c == '}':
+                     stack -= 1
+                     if stack == 0:
+                         break
+                     a += c
+                 else:
+                     a += c
+         else:
+             a = ans.split('$')[0].strip()
+         pred = a
+     elif 'he answer is' in pred_str:
+         pred = pred_str.split('he answer is')[-1].strip()
+     elif 'final answer is' in pred_str:
+         pred = pred_str.split('final answer is')[-1].strip()
+     # elif extract_program_output(pred_str) != "":
+     #     # fall back to the program output
+     #     pred = extract_program_output(pred_str)
+     else:  # use the last number
+         pattern = r'-?\d*\.?\d+'
+         pred = re.findall(pattern, pred_str.replace(",", ""))
+         if len(pred) >= 1:
+             pred = pred[-1]
+         else:
+             pred = ''
+
+     # multiple lines
+     # pred = pred.split("\n")[0]
+     pred = re.sub(r"\n\s*", "", pred)
+     if pred != "" and pred[0] == ":":
+         pred = pred[1:]
+     if pred != "" and pred[-1] == ".":
+         pred = pred[:-1]
+     if pred != "" and pred[-1] == "/":
+         pred = pred[:-1]
+     pred = strip_string(pred)
+     return pred
+
+
+ def parse_ground_truth(example: Dict[str, Any], data_name):
+     # parse the ground truth
+     if data_name in ["MATH", "math", "math_oai", "minerva_math", "ocw", "amps", "hungarian_exam"]:
+         gt_ans = example['answer']
+     elif data_name == "gsm8k":
+         gt_ans = example['answer'].split("####")[-1]
+     elif data_name == "mmlu_stem":
+         abcd = 'ABCD'
+         gt_ans = abcd[example['answer']]
+     elif data_name == "gpqa":
+         gt_ans = example['correct answer']
+     else:
+         raise NotImplementedError(f"`{data_name}`")
+     # post-process
+     gt_ans = strip_string(gt_ans)
+     return gt_ans
+
+
+ def parse_question(example, data_name):
+     question = ""
+     if data_name == "mmlu_stem":
+         options = example['choices']
+         assert len(options) == 4
+         for i, (label, option) in enumerate(zip('ABCD', options)):
+             options[i] = f"({label}) {str(option).strip()}"
+         options = ", ".join(options)
+         question = f"{example['question'].strip()}\nWhich of the following is the right choice? Explain your answer.\n{options}"
+     else:
+         for key in ['question', 'problem', 'Question', 'input']:
+             if key in example:
+                 question = example[key]
+                 break
+     assert question != ""
+     # Yes-or-No question
+     gt_ans = parse_ground_truth(example, data_name)
+     gt_lower = gt_ans.lower()
+     if gt_lower in ["true", "false"]:
+         question += " (True or False)"
+     if gt_lower in ["yes", "no"]:
+         question += " (Yes or No)"
+     return question.strip()
+
+
+ def _test_extract_answer():
+     text = """
+ The answer is $\\boxed{\left(
+ \\begin{array}{ccc}
+  -13 & 4 & -2 \\\\
+  7 & 8 & -3 \\\\
+  0 & 18 & -7 \\\\
+  6 & 12 & 5 \\\\
+ \\end{array}
+ \\right)}$.
+ """
+     print(extract_answer(text, "math"))
+     # should print the extracted matrix string
+
+
+ if __name__ == "__main__":
+     _test_extract_answer()
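To make the normalization behavior concrete, here are a few expected input/output pairs for the helpers above (illustrative, derived from the code as written):

```python
# Expected behavior of the answer-normalization helpers.
from data_utils import strip_string, extract_answer

print(strip_string("\\frac12"))   # -> \frac{1}{2}  (via _fix_fracs)
print(strip_string("50\\%"))      # -> 50           (percent signs stripped)
print(extract_answer("So the final answer is $\\boxed{42}$.", "math"))  # -> 42
```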
evaluate.py ADDED
@@ -0,0 +1,71 @@
+ """
+ Source and credits: https://github.com/ZubinGou/math-evaluation-harness/blob/main/python_executor.py
+ """
+ import argparse
+ import json
+ from concurrent.futures import TimeoutError
+
+ import numpy as np
+ from pebble import ProcessPool
+ from tqdm import tqdm
+
+ from grader import math_equal_process
+
+
+ def evaluate(samples: list = None, file_path: str = None):
+     assert samples or file_path, "samples or file_path must be provided"
+     if not samples:
+         with open(file_path, 'r') as f:
+             samples = [json.loads(line) for line in f]
+
+     # dedup by idx
+     if 'idx' in samples[0]:
+         samples = {sample['idx']: sample for sample in samples}.values()
+         samples = sorted(samples, key=lambda x: x['idx'])
+     else:
+         samples = [dict(idx=idx, **sample) for idx, sample in enumerate(samples)]
+
+     params = [(idx, sample['pred'], sample['gt']) for idx, sample in enumerate(samples)]
+
+     scores = []
+     timeout_cnt = 0
+
+     with ProcessPool() as pool:
+         future = pool.map(math_equal_process, params, timeout=3)
+         iterator = future.result()
+         with tqdm(total=len(samples), desc="Evaluate") as progress_bar:
+             while True:
+                 try:
+                     result = next(iterator)
+                     scores.append(result)
+                 except StopIteration:
+                     break
+                 except TimeoutError as error:
+                     print(error)
+                     scores.append(False)
+                     timeout_cnt += 1
+                 except Exception as error:
+                     print(error.traceback)
+                     exit()
+                 progress_bar.update(1)
+
+     assert len(samples) == len(scores)
+
+     for i in range(len(samples)):
+         samples[i]['score'] = scores[i]
+
+     # accuracy = fraction of samples graded correct
+     mean_score = np.round(np.mean(scores), decimals=2)
+
+     result_json = {
+         "num_samples": len(samples),
+         "num_scores": len(scores),
+         "timeout_samples": timeout_cnt,
+         "acc": mean_score
+     }
+
+     return samples, result_json
+
+
+ if __name__ == "__main__":
+     samples, results_json = evaluate(file_path="output/MATH.jsonl")
+     print(results_json)
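The record format `evaluate()` expects follows from the code: each sample needs a `"pred"` and a `"gt"` field (already normalized, e.g. by `data_utils.strip_string`). A minimal end-to-end sketch:

```python
# Minimal illustration of evaluate() on in-memory records.
from evaluate import evaluate

samples = [
    {"pred": "42", "gt": "42"},
    {"pred": "\\frac{1}{2}", "gt": "0.5"},
]
graded, metrics = evaluate(samples=samples)
print(metrics)  # e.g. {'num_samples': 2, 'num_scores': 2, 'timeout_samples': 0, 'acc': 1.0}
```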
grader.py ADDED
@@ -0,0 +1,305 @@
+ """
+ Source and credits: https://github.com/ZubinGou/math-evaluation-harness/blob/main/python_executor.py
+ This logic is largely copied from Hendrycks' MATH release (math_equivalence), and borrowed from:
+ - https://github.com/microsoft/ProphetNet/tree/master/CRITIC
+ - https://github.com/openai/prm800k
+ - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
+ - https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
+ """
+ import re
+ import regex
+ import multiprocessing
+ from math import isclose
+ from typing import Union
+
+ from sympy import simplify, N
+ from sympy.parsing.sympy_parser import parse_expr
+ from sympy.parsing.latex import parse_latex
+ from latex2sympy2 import latex2sympy
+
+
+ def parse_digits(num):
+     num = regex.sub(',', '', str(num))
+     try:
+         return float(num)
+     except ValueError:
+         if num.endswith('%'):
+             num = num[:-1]
+             if num.endswith('\\'):
+                 num = num[:-1]
+             try:
+                 return float(num) / 100
+             except ValueError:
+                 pass
+     return None
+
+
+ def is_digit(num):
+     # paired with parse_digits
+     return parse_digits(num) is not None
+
+
+ def str_to_pmatrix(input_str):
+     input_str = input_str.strip()
+     matrix_str = re.findall(r'\{.*,.*\}', input_str)
+     pmatrix_list = []
+
+     for m in matrix_str:
+         m = m.strip('{}')
+         pmatrix = r'\begin{pmatrix}' + m.replace(',', '\\') + r'\end{pmatrix}'
+         pmatrix_list.append(pmatrix)
+
+     return ', '.join(pmatrix_list)
+
+
+ def math_equal(prediction: Union[bool, float, str],
+                reference: Union[float, str],
+                include_percentage: bool = True,
+                is_close: bool = True,
+                timeout: bool = False,
+                ) -> bool:
+     """
+     Exact match of math if and only if:
+     1. numerical equal: both can convert to float and are equal
+     2. symbolic equal: both can convert to sympy expressions and are equal
+     """
+     # print("Judge:", prediction, reference)
+     if str(prediction) == str(reference):
+         return True
+
+     try:  # 1. numerical equal
+         if is_digit(prediction) and is_digit(reference):
+             prediction = parse_digits(prediction)
+             reference = parse_digits(reference)
+             # number questions
+             if include_percentage:
+                 gt_result = [reference / 100, reference, reference * 100]
+             else:
+                 gt_result = [reference]
+             for item in gt_result:
+                 try:
+                     if is_close:
+                         if numeric_equal(prediction, item):
+                             return True
+                     else:
+                         if item == prediction:
+                             return True
+                 except Exception:
+                     continue
+             return False
+     except Exception:
+         pass
+
+     if not prediction and prediction not in [0, False]:
+         return False
+     # print("try math_eval")
+
+     # 2. symbolic equal
+     reference = str(reference).strip()
+     prediction = str(prediction).strip()
+
+     ## pmatrix (amps)
+     if "pmatrix" in prediction and "pmatrix" not in reference:
+         reference = str_to_pmatrix(reference)
+
+     ## deal with [], (), {}
+     pred_str, ref_str = prediction, reference
+     if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \
+         (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")):
+         pred_str = pred_str.strip("[]()")
+         ref_str = ref_str.strip("[]()")
+     for s in ['{', "}", "(", ")"]:
+         ref_str = ref_str.replace(s, "")
+         pred_str = pred_str.replace(s, "")
+     if pred_str.lower() == ref_str.lower():
+         return True
+
+     ## [a, b] vs. [c, d]: return a==c and b==d
+     if regex.match(r'(\(|\[).+(\)|\])', prediction) is not None and regex.match(r'(\(|\[).+(\)|\])', reference) is not None:
+         pred_parts = prediction[1:-1].split(",")
+         ref_parts = reference[1:-1].split(",")
+         if len(pred_parts) == len(ref_parts):
+             if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
+                 return True
+     if (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) and \
+         (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")):
+         pred_lines = [line.strip() for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
+         ref_lines = [line.strip() for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
+         matched = True
+         if len(pred_lines) == len(ref_lines):
+             for pred_line, ref_line in zip(pred_lines, ref_lines):
+                 pred_parts = pred_line.split("&")
+                 ref_parts = ref_line.split("&")
+                 if len(pred_parts) == len(ref_parts):
+                     if not all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
+                         matched = False
+                         break
+                 else:
+                     matched = False
+                 if not matched:
+                     break
+         else:
+             matched = False
+         if matched:
+             return True
+
+     if prediction.count('=') == 1 and reference.count('=') == 1:
+         pred = prediction.split('=')
+         pred = f"{pred[0].strip()} - ({pred[1].strip()})"
+         ref = reference.split('=')
+         ref = f"{ref[0].strip()} - ({ref[1].strip()})"
+         if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
+             return True
+     elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference:
+         if math_equal(prediction.split('=')[1], reference, include_percentage, is_close):
+             return True
+     elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction:
+         if math_equal(prediction, reference.split('=')[1], include_percentage, is_close):
+             return True
+
+     # print("try final")
+     # symbolic equality with sympy
+     if timeout:
+         if call_with_timeout(symbolic_equal_process, prediction, reference):
+             return True
+     else:
+         if symbolic_equal(prediction, reference):
+             return True
+
+     return False
+
+
+ def math_equal_process(param):
+     return math_equal(param[-2], param[-1])
+
+
+ def numeric_equal(prediction: float, reference: float):
+     # Note that the relative tolerance has a significant impact
+     # on the results for the synthesized gsm_hard dataset
+     # if reference.is_integer():
+     #     return isclose(reference, round(prediction), abs_tol=1e-4)
+     # else:
+     #     prediction = round(prediction, len(str(reference).split(".")[-1]))
+     return isclose(reference, prediction, rel_tol=1e-4)
+
+
+ def symbolic_equal(a, b):
+     def _parse(s):
+         for f in [parse_latex, parse_expr, latex2sympy]:
+             try:
+                 return f(s.replace("\\\\", "\\"))
+             except Exception:
+                 try:
+                     return f(s)
+                 except Exception:
+                     pass
+         return s
+     a = _parse(a)
+     b = _parse(b)
+
+     # direct equality
+     try:
+         if str(a) == str(b) or a == b:
+             return True
+     except Exception:
+         pass
+
+     # print("try simplify")
+     # simplify equality
+     try:
+         if a.equals(b) or simplify(a - b) == 0:
+             return True
+     except Exception:
+         pass
+
+     # print("try equation")
+     # equation equality
+     try:
+         if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
+             return True
+     except Exception:
+         pass
+
+     try:
+         if numeric_equal(float(N(a)), float(N(b))):
+             return True
+     except Exception:
+         pass
+
+     # matrix
+     try:
+         # if a and b are matrices
+         if a.shape == b.shape:
+             _a = a.applyfunc(lambda x: round(x, 3))
+             _b = b.applyfunc(lambda x: round(x, 3))
+             if _a.equals(_b):
+                 return True
+     except Exception:
+         pass
+
+     return False
+
+
+ def symbolic_equal_process(a, b, output_queue):
+     result = symbolic_equal(a, b)
+     output_queue.put(result)
+
+
+ def call_with_timeout(func, *args, timeout=1, **kwargs):
+     output_queue = multiprocessing.Queue()
+     process_args = args + (output_queue,)
+     process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
+     process.start()
+     process.join(timeout)
+
+     if process.is_alive():
+         process.terminate()
+         process.join()
+         return False
+
+     return output_queue.get()
+
+
+ def _test_math_equal():
+     # print(math_equal("0.0833333333333333", "\\frac{1}{12}"))
+     # print(math_equal("(1,4.5)", "(1,\\frac{9}{2})"))
+     # print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True))
+     # print(math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True))
+     # print(math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}", "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})", timeout=True))
+
+     # pred = '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}'
+     # gt = '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})'
+
+     # pred = '-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}'
+     # gt = '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}'
+
+     # pred = '-34x-45y+20z-100=0'
+     # gt = '34x+45y-20z+100=0'
+
+     # pred = '\\frac{100}{3}'
+     # gt = '33.3'
+
+     # pred = '\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}'
+     # gt = '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})'
+
+     # pred = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}'
+     # gt = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}'
+
+     # pred = '(+5)(b+2)'
+     # gt = '(a+5)(b+2)'
+
+     # pred = '\\frac{1+\\sqrt{5}}{2}'
+     # gt = '2'
+
+     # pred = '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}', gt = '4'
+     # pred = '1', gt = '1\\\\sqrt{19}'
+
+     pred = '(0.6,2.6667]'
+     gt = '(\\frac{3}{5},\\frac{8}{3}]'
+
+     print(math_equal(pred, gt, timeout=True))
+
+
+ if __name__ == "__main__":
+     _test_math_equal()
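A few representative checks of `math_equal`, drawn from the commented test cases above (illustrative; requires sympy and latex2sympy2 installed):

```python
# Representative checks of the grader's equivalence logic.
from grader import math_equal

print(math_equal("0.0833333333333333", "\\frac{1}{12}"))  # True: float vs. LaTeX fraction
print(math_equal("(1,4.5)", "(1,\\frac{9}{2})"))          # True: element-wise tuple match
print(math_equal("5", "6"))                               # False
```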
main.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import logging
5
+ import multiprocessing
6
+ import os
7
+ import re
8
+ from abc import ABC, abstractmethod
9
+
10
+ import hjson
11
+ import numpy as np
12
+ import openai
13
+ from tqdm import tqdm
14
+ from sklearn.metrics.pairwise import cosine_similarity
15
+
16
+ from data_loader import load_data
17
+ from code_executor import PythonExecutor
18
+ from utils import (Agent, LLMClient, PromptTemplate, api_configs,
19
+ extract_and_parse_markup, setup_logging)
20
+ from data_utils import parse_question, parse_ground_truth
21
+ from evaluate import evaluate
22
+
23
+
24
+ logger = setup_logging()
25
+
26
+ class RetrievalAugmentation:
27
+ # TODO: implement the retrieval augmentation later
28
+ def __init__(self, dataset, embeddings):
29
+ self.dataset = dataset
30
+ self.embeddings = embeddings
31
+
32
+ def get_similar_examples(self, query_embedding, n=3):
33
+ similarities = cosine_similarity([query_embedding], self.embeddings)[0]
34
+ top_indices = similarities.argsort()[-n:][::-1]
35
+ return [self.dataset[i] for i in top_indices]
36
+
37
+ class SwiftAgent(Agent):
38
+ def __init__(self, prompt_template, llm_client, retrieval_augmentation=None):
39
+ super().__init__(prompt_template, llm_client)
40
+ self.retrieval_augmentation = retrieval_augmentation
41
+ self.plans = {}
42
+ self.codes = {}
43
+
44
+ def generate_response(self, prompt, reasoning, current_solution, plan, critical_feedback, prefill=True):
45
+ logger.info("SwiftAgent generating response")
46
+ if self.retrieval_augmentation:
47
+ query_embedding = self.get_query_embedding(prompt)
48
+ similar_examples = self.retrieval_augmentation.get_similar_examples(query_embedding)
49
+ examples_text = "\n".join(similar_examples) # TODO: add more context to the prompt
50
+ else:
51
+ examples_text = "No similar examples available."
52
+
53
+ swift_prompt = self.prompt_template.format(
54
+ "swift",
55
+ prompt=prompt,
56
+ current_reasoning=reasoning, # TODO: check if this is needed
57
+ examples=examples_text,
58
+ current_solution=current_solution,
59
+ critical_feedback=critical_feedback,
60
+ revised_plan=plan
61
+ )
62
+ # logger.info(f"SwiftAgent prompt:\n{swift_prompt}")
63
+
64
+ messages = [
65
+ {"role": "system", "content": ''},
66
+ {"role": "user", "content": swift_prompt}
67
+ ]
68
+ if prefill:
69
+ messages.append({"role": "assistant", "content": "<plan>"}) # prefix-filling
70
+
71
+ response = self.llm_client.generate_response(messages)
72
+ if prefill:
73
+ response = "<plan>" + response
74
+
75
+ try:
76
+ parsed_response = extract_and_parse_markup(response)
77
+ return parsed_response
78
+ except json.JSONDecodeError:
79
+ logger.error("Error: Swift's response was not in valid JSON format. Returning raw response.")
80
+ return response
81
+
82
+ def get_query_embedding(self, query):
83
+ # Implement query embedding generation
84
+ return np.random.rand(768) # Placeholder, replace with actual embedding
85
+
86
+ class SageAgent(Agent):
87
+ def __init__(self, prompt_template, llm_client):
88
+ super().__init__(prompt_template, llm_client)
89
+ self.feedbacks = {}
90
+ self.plans = {}
91
+
92
+
93
+ def generate_response(self, prompt, reasoning, current_solution, prefill=True):
94
+ logger.info("SageAgent generating response")
95
+ sage_prompt = self.prompt_template.format(
96
+ "sage",
97
+ prompt=prompt,
98
+ reasoning=reasoning,
99
+ current_solution=current_solution
100
+ )
101
+ # logger.info(f"SageAgent prompt:\n{sage_prompt}")
102
+
103
+ messages = [
104
+ {"role": "system", "content": ""},
105
+ {"role": "user", "content": sage_prompt}
106
+ ]
107
+ if prefill:
108
+ messages.append({"role": "assistant", "content": "<solved>"}) # prefix-filling
109
+
110
+ response = self.llm_client.generate_response(messages)
111
+ # logger.info(f"SageAgent raw response:\n{response}")
112
+ if prefill:
113
+ response = "<solved>" + response
114
+ try:
115
+ parsed_response = extract_and_parse_markup(response)
116
+ return parsed_response
117
+ except json.JSONDecodeError:
118
+ logger.error("Error: Sage's response was not in valid JSON format. Returning raw response.")
119
+ return response
120
+
121
+ class RewardModel:
122
+ def __init__(self, prompt_template, llm_client):
123
+ self.prompt_template = prompt_template
124
+ self.llm_client = llm_client
125
+ self.scores = []
126
+ self.feedbacks = []
127
+ self.stagnant_count = 0
128
+
129
+ def calculate_reward(self, problem, reasoning, current_solution, prefill=True):
130
+ reward_prompt = self.prompt_template.format(
131
+ "reward",
132
+ problem=problem,
133
+ reasoning= reasoning,
134
+ current_solution=current_solution
135
+ )
136
+ # logger.info(f"RewardModel prompt:\n{reward_prompt}")
137
+
138
+ messages = [
139
+ {"role": "system", "content": ""},
140
+ {"role": "user", "content": reward_prompt}
141
+ ]
142
+ if prefill:
143
+ messages.append({"role": "assistant", "content": "<feedback>"}) # prefix-filling
144
+
145
+ reward_response = self.llm_client.generate_response(messages)
146
+ if prefill:
147
+ reward_response = "<feedback>" + reward_response
148
+
149
+ try:
150
+ parsed_response = extract_and_parse_markup(reward_response)
151
+ score = int(parsed_response["score"])
152
+
153
+ # Update stagnant_count based on score comparison
154
+ if len(self.scores) > 0 and score <= self.scores[-1]:
155
+ self.stagnant_count += 1
156
+ else:
157
+ self.stagnant_count = 0
158
+
159
+ return parsed_response
160
+ except json.JSONDecodeError:
161
+ logger.error("Error: Reward model's response was not in valid JSON format. Returning raw response.")
162
+ return reward_response
163
+
164
+ def should_consult_sage(self):
165
+ # This method remains unchanged
166
+ return self.stagnant_count >= 1 or (len(self.scores) > 0 and self.scores[-1] < 5)
167
+
168
+ class SwiftSage:
169
+ def __init__(self, dataset, embeddings, prompt_template_dir, swift_config, sage_config, reward_config, use_retrieval=True, start_with_sage=False):
170
+ prompt_template = PromptTemplate(prompt_template_dir)
171
+ retrieval_augmentation = RetrievalAugmentation(dataset, embeddings) if use_retrieval else None
172
+
173
+ # add logger to the following LLMClient
174
+ swift_llm = LLMClient(**swift_config, logger=logger)
175
+ sage_llm = LLMClient(**sage_config, logger=logger)
176
+ reward_llm = LLMClient(**reward_config, logger=logger)
177
+
178
+ self.swift = SwiftAgent(prompt_template, swift_llm, retrieval_augmentation)
179
+ self.sage = SageAgent(prompt_template, sage_llm)
180
+ self.reward_model = RewardModel(prompt_template, reward_llm)
181
+ self.start_with_sage = start_with_sage
182
+ # self.executor = PythonExecutor(get_answer_from_stdout=True)
183
+
184
+ def solve(self, problem, max_iterations=10, reward_threshold=8):
185
+ logger.info(f"Starting to solve problem: {problem}")
186
+ current_solution = "No current solution yet." # final answer
187
+ current_reasoning = "No reasoning steps yet." # reasoning steps
188
+ plan = "Initial plan: Take a deep breath and think step by step."
189
+ critical_feedback = "No critical feedback yet." # Initialize critical_feedback
190
+ solved = False
191
+ for i in range(max_iterations):
192
+ logger.info(f"Iteration {i+1}")
193
+
194
+
195
+ # Use the Sage Agent
196
+ if (i == 0 and self.start_with_sage) or self.reward_model.should_consult_sage():
197
+ sage_parsed = self.sage.generate_response(problem, current_reasoning, current_solution)
198
+ critical_feedback = sage_parsed["critical_feedback"]
199
+ # plan = "\n - " + "\n - ".join(sage_parsed["revised_plan"])
200
+ current_reasoning = sage_parsed["reasoning_steps"]
201
+ current_code = sage_parsed["code"]
202
+
203
+ solved = sage_parsed["solved"].lower() == "true" if i != 0 else sage_parsed["solved"]
204
+ if solved:
205
+ return current_reasoning, current_solution
206
+ logger.info(f"Sage's feedback (iteration {i+1}):\n{critical_feedback}")
207
+ # logger.info(f"Sage's reasoning steps:\n{current_reasoning}")
208
+ self.sage.feedbacks[i] = critical_feedback
209
+
210
+ # run the code
211
+ executor = PythonExecutor(get_answer_from_stdout=True)
212
+ code_result, code_report = executor.apply(current_code)
213
+ logger.info(f"Sage Code execution report: {code_report}")
214
+ logger.info(f"Sage Code execution result: {code_result}")
215
+ current_reasoning = current_reasoning + f"\n\nThe generated code is:\n\n```python\n{current_code}\n```"
216
+ current_solution = "Answer (from running the code):\n " + code_result
217
+
218
+ # current_solution = sage_parsed["final_answer"]
219
+ logger.info("Activated Sage, so we should return the reasoning and solution from Sage.")
220
+ return current_reasoning, current_solution
221
+
222
+ if not solved:
223
+ # Use the Swift Agent
224
+ swift_parsed = self.swift.generate_response(problem, current_reasoning, current_solution, plan, critical_feedback)
225
+
226
+ if "code" not in swift_parsed and "final_answer" not in swift_parsed:
227
+ logger.info("Swift's response does not contain the 'final_answer' or 'code' field. Returning raw response.")
228
+ self.reward_model.scores.append(0)
229
+ self.reward_model.feedbacks.append("No feedback")
230
+ self.reward_model.stagnant_count += max_iterations # force to use Sage Agent
231
+ continue
232
+
233
+ current_plan = swift_parsed["plan"]
234
+ current_code = swift_parsed["code"]
235
+ current_answer = swift_parsed.get("final_answer", None)
236
+
237
+ self.swift.plans[i] = current_plan
238
+ self.swift.codes[i] = current_code
239
+
240
+ logger.info(f"Swift's plan:\n{current_plan}")
241
+ logger.info(f"Swift's code:\n{current_code}")
242
+
243
+ # Call sandbox to run the code and get the result
244
+ executor = PythonExecutor(get_answer_from_stdout=True)
245
+ code_result, code_report = executor.apply(current_code)
246
+ logger.info(f"Code execution report: {code_report}")
247
+ logger.info(f"Code execution result: {code_result}")
248
+
249
+ current_reasoning = current_plan + f"\nThe generated code is:\n```python\n{current_code}\n```"
250
+ current_solution = "Answer (from running the code):\n " + code_result
251
+
252
+ # Calling the reward model to provide feedback and score
253
+ reward_parsed = self.reward_model.calculate_reward(problem, current_reasoning, current_solution)
254
+ score = int(reward_parsed["score"])
255
+ feedback = reward_parsed["feedback"]
256
+ prev_score = self.reward_model.scores[-1] if len(self.reward_model.scores) > 0 else 0
257
+ self.reward_model.scores.append(score)
258
+ self.reward_model.feedbacks.append(feedback)
259
+
260
+ # detect if the score is lower than the previous score
261
+ logger.info(f"Reward for iteration {i+1}: {score}/10")
262
+ logger.info(f"Feedback: {feedback}")
263
+
264
+ if False and score < prev_score:
265
+ logger.info("Score is lower than the previous score. Stopping the iteration. Reverting to the previous solution and reasoning.")
266
+ # revert to the previous solution and reasoning
267
+ current_solution = self.swift.codes[i-1]
268
+ current_reasoning = self.swift.plans[i-1]
269
+ continue
270
+
271
+
272
+ critical_feedback = feedback
273
+
274
+
275
+ if score >= reward_threshold or solved:
276
+ logger.info("Perfect solution found!")
277
+ return current_reasoning, current_solution
278
+
279
+
280
+ if self.reward_model.should_consult_sage():
281
+ logger.info("Reward model: The solution quality hasn't improved recently. Consulting Sage for the next iteration.")
282
+
283
+ logger.info("Max iterations reached without finding a perfect solution.")
284
+ logger.info("Problem solving completed")
285
+ return current_reasoning, current_solution
286
+
287
+
288
+ def run_test(swiftsage, problem, max_iterations=5, reward_threshold=8):
289
+ logger.info(f"Testing problem: {problem}")
290
+ reasoning, solution = swiftsage.solve(problem, max_iterations, reward_threshold)
291
+ logger.info(f"Final reasoning:\n{reasoning}")
292
+ logger.info(f"Final solution:\n{solution}")
293
+ logger.info("=" * 50)
294
+
295
+
296
+ def run_benchmark(swiftsage, args, max_iterations=5, reward_threshold=8):
297
+ examples = load_data(args.dataset_name, args.split, args.data_dir, args.num_test_sample)
298
+
299
+ res = []
300
+ skip_ids = []
301
+
302
+ output_path = os.path.join(args.output_path, f"{args.dataset_name}.jsonl")
303
+ if os.path.exists(output_path):
304
+ with open(output_path) as fr:
305
+ model_responses = fr.readlines()
306
+
307
+ for item in model_responses:
308
+ item = json.loads(item)
309
+ res.append(item)
310
+ skip_ids.append(item["idx"])
311
+
312
+ for example in tqdm(examples, desc=args.dataset_name):
313
+ if example["idx"] in skip_ids:
314
+ continue
315
+ question = parse_question(example, args.dataset_name)
316
+ gt_ans = parse_ground_truth(example, args.dataset_name)
317
+ reasoning, solution = swiftsage.solve(question, max_iterations, reward_threshold)
318
+
319
+ # TODO: extract answer from solution
320
+
321
+ cur_res = {
322
+ "idx": example["idx"],
323
+ "question": question,
324
+ "gt": gt_ans,
325
+ "pred": solution,
326
+ "reasoning": reasoning,
327
+ }
328
+ res.append(cur_res)
329
+
330
+ with open(output_path, "a") as fw:
331
+ fw.write(json.dumps(res[-1]) + "\n")
332
+
333
+ # Evaluate the results
334
+ res, result_metric = evaluate(res)
335
+ with open(args.output_path, f"{args.dataset_name}_score.jsonl", "w") as fw:
336
+ for item in res:
337
+ fw.write(json.dumps(item) + "\n")
338
+ with open(args.output_path, f"{args.dataset_name}_metric.jsonl", "w") as fw:
339
+ fw.write(json.dumps(result_metric) + "\n")
340
+
341
+
342
+ def main(args):
343
+
344
+ # TODO: for retrieval augmentation (not implemented yet now)
345
+ # dataset = ["Example problem 1: ...", "Example problem 2: ...", "Example problem 3: ..."]
346
+ # embeddings = np.random.rand(len(dataset), 768) # Placeholder, replace with actual embeddings
347
+
348
+
349
+ # Configuration for each LLM
350
+ # swift_config = {
351
+ # "model_id": "Meta-Llama-3.1-8B-Instruct",
352
+ # "api_config": api_configs['SambaNova']
353
+ # }
354
+
355
+ # reward_config = {
356
+ # "model_id": "Meta-Llama-3.1-70B-Instruct",
357
+ # "api_config": api_configs['SambaNova']
358
+ # }
359
+
360
+ # sage_config = {
361
+ # "model_id": "Meta-Llama-3.1-405B-Instruct",
362
+ # "api_config": api_configs['SambaNova']
363
+ # }
364
+
365
+ swift_config = {
366
+ "model_id": args.swift_model_id,
367
+ "api_config": api_configs[args.api_provider]
368
+ }
369
+
370
+ reward_config = {
371
+ "model_id": args.reward_model_id,
372
+ "api_config": api_configs[args.api_provider]
373
+ }
374
+
375
+ sage_config = {
376
+ "model_id": args.sage_model_id,
377
+ "api_config": api_configs[args.api_provider]
378
+ }
379
+
380
+ # specify the path to the prompt templates
381
+ prompt_template_dir = args.prompt_template_dir
382
+ dataset = []
383
+ embeddings = [] # TODO: for retrieval augmentation (not implemented yet now)
384
+ s2 = SwiftSage(
385
+ dataset,
386
+ embeddings,
387
+ prompt_template_dir,
388
+ swift_config,
389
+ sage_config,
390
+ reward_config,
391
+ use_retrieval=args.use_retrieval,
392
+ start_with_sage=args.start_with_sage,
393
+ )
394
+
395
+ if args.eval_mode == "test":
396
+ test_problems = [
397
+ "Solve the equation: 2x + 5 = 13", # 0
398
+ "If h(x)=x-4 and g(h(x))=x^2-8x+10, find g(x)? show the formula for g(x)", # 1
399
+ "Solve the equation: 6y + 5 = 29", # 2
400
+ "Who lives longer, Lowell Sherman or Jonathan Kaplan?", # 3
401
+ "9.9 or 9.11 -- which is bigger?", # 4
402
+ "How can you solve the quadratic equation 3x^2 + 7x + 4 = 0 using the quadratic formula?", # 5
403
+ "Explain why sound waves cannot travel in a vacuum?", # 6
404
+ "How many grams of hydrogen (H) are present in 23.5 grams of water (H2O)?", # 7
405
+ "What is the distance between the points (2, 3) and (5, 8)?", # 8
406
+ "Why can the Hubble telescope capture clear images of distant stars and galaxies, but not a detailed image of Pluto?", # 9
407
+ """A rectangular band formation is a formation with $m$ band members in each of $r$ rows, where $m$ and $r$ are integers. A particular band has less than 100 band members. The director arranges them in a rectangular formation and finds that he has two members left over. If he increases the number of members in each row by 1 and reduces the number of rows by 2, there are exactly enough places in the new formation for each band member. What is the largest number of members the band could have?""",
408
+ """Tim wants to invest some money in a bank which compounds quarterly with an annual interest rate of $7\%$. To the nearest dollar, how much money should he invest if he wants a total of $\$60,\!000$ at the end of $5$ years?""",
409
+ """In an SR latch built from NOR gates, which condition is not allowed
410
+
411
+ Options:
412
+ [ "S=0, R=2", "S=2, R=2", "S=1, R=1", "S=1, R=-1", "S=1, R=2", "S=0, R=0", "S=2, R=0", "S=1, R=0", "S=2, R=1", "S=0, R=1" ]
413
+
414
+ Which one is the correct answer?""",
415
+ # ... add other problems here ...
416
+ """How many letter r are there in the word "strawberry"?"""
417
+ ]
418
+
419
+ # for problem in test_problems:
420
+ pid = 7
421
+ print(f"Problem {pid}: {test_problems[pid]}")
422
+ run_test(s2, test_problems[pid], args.max_iterations, args.reward_threshold)
423
+ elif args.eval_mode == "benchmark":
424
+ run_benchmark(s2, args, args.max_iterations, args.reward_threshold)
425
+
426
+
427
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--eval_mode", default="test", choices=["test", "benchmark"], type=str)
+ 
+     parser.add_argument("--dataset_name", default="MATH", type=str)
+     parser.add_argument("--data_dir", default="./data", type=str)
+     parser.add_argument("--split", default="test", type=str)
+     parser.add_argument("--num_test_sample", default=-1, type=int)  # -1 for full data
+ 
+     parser.add_argument("--api_provider", default="Together", choices=["Together", "SambaNova"], type=str)
+     parser.add_argument("--swift_model_id", default="meta-llama/Meta-Llama-3-8B-Instruct-Turbo", type=str)
+     parser.add_argument("--reward_model_id", default="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", type=str)
+     parser.add_argument("--sage_model_id", default="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", type=str)
+ 
+     parser.add_argument("--prompt_template_dir", default='./prompt_templates', type=str)
+     parser.add_argument("--use_retrieval", action="store_true")
+     parser.add_argument("--start_with_sage", action="store_true")
+ 
+     parser.add_argument("--max_iterations", default=5, type=int)
+     parser.add_argument("--reward_threshold", default=8, type=int)
+ 
+     parser.add_argument("--save_outputs", action="store_true")
+     parser.add_argument("--output_path", default="./output", type=str)
+     parser.add_argument("--overwrite", action="store_true")
+ 
+     args = parser.parse_args()
+ 
+     # remove console output for benchmark evaluation
+     if args.eval_mode != "test":
+         root_logger = logging.getLogger("")
+         for handler in root_logger.handlers:
+             if isinstance(handler, logging.StreamHandler):
+                 root_logger.removeHandler(handler)
+                 break
+ 
+     if args.api_provider == "SambaNova":
+         # SambaNova expects bare model names such as "Meta-Llama-3.1-8B-Instruct",
+         # so drop the org prefix and the Together-only "-Turbo" suffix (the
+         # previous slice left a trailing "-" behind).
+         args.swift_model_id = args.swift_model_id.split("/")[-1].removesuffix("-Turbo")
+         args.reward_model_id = args.reward_model_id.split("/")[-1].removesuffix("-Turbo")
+         args.sage_model_id = args.sage_model_id.split("/")[-1].removesuffix("-Turbo")
+ 
+     multiprocessing.set_start_method('spawn')
+     main(args)
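The provider-specific id mapping above is mechanical, so it is easy to sanity-check in isolation. A minimal sketch (`to_sambanova_id` is a hypothetical helper written for illustration, not part of the repo):

```python
def to_sambanova_id(together_id: str) -> str:
    # Drop the "meta-llama/" org prefix, then the Together-only "-Turbo" suffix.
    return together_id.split("/")[-1].removesuffix("-Turbo")

# Together-style id -> bare SambaNova-style name
assert to_sambanova_id("meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo") == "Meta-Llama-3.1-70B-Instruct"
```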
prompt_templates/reward_template.md ADDED
@@ -0,0 +1,44 @@
+ # Instruction
+ 
+ You are a reward model. You will be given a problem and a solution, and you will evaluate the solution based on the criteria provided.
+ 
+ ## Problem
+ <problem>
+ 
+ ## Current Solution
+ 
+ ### Reasoning Steps
+ <reasoning>
+ 
+ ### Final Answer
+ <current_solution>
+ 
+ 
+ ## Your Evaluation
+ 
+ We are not sure whether the current solution is correct. Please evaluate it based on the following criteria:
+ 
+ 1. Correctness
+ 2. Completeness
+ 
+ Provide a score from 1 to 10 and a brief explanation.
+ If you are not sure about the final answer, give a score between 1 and 7 and explain why you are unsure.
+ Be careful not to state false information in your critical feedback.
+ 
+ 
+ ## Output Format
+ 
+ Remember to present your output in the following format:
+ 
+ <feedback>
+ Your critical feedback here.
+ </feedback>
+ 
+ 
+ <score>
+ Your score here.
+ </score>
+ 
+ # Important Notes
+ 
+ You must follow the format strictly and not miss any field. Start your output with "<feedback>" and end it with "</score>".
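The two tags above are what the controller parses back out of the reward model's reply. A minimal sketch of that parsing, mirroring `extract_and_parse_markup` in `utils.py` below (the sample reply is made up):

```python
import re

# Hypothetical reward-model reply in the template's format.
sample = """<feedback>
The arithmetic is right, but the unit conversion in step 2 is skipped.
</feedback>
<score>
6
</score>"""

# Pull each tagged field out with a non-greedy regex, as utils.py does.
parsed = {
    key: re.search(f"<{key}>(.*?)</{key}>", sample, re.DOTALL).group(1).strip()
    for key in ("feedback", "score")
}
print(parsed["score"])  # -> 6
```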
prompt_templates/sage_template.md ADDED
@@ -0,0 +1,45 @@
+ # Instruction
+ 
+ You are a high-level problem-solving agent. You will be given a problem and a current solution. You will then provide critical feedback on the current solution and suggest a revised plan if needed.
+ If the current solution is correct and complete, you will state that the problem is solved and that no further action is needed.
+ 
+ ## Problem
+ <prompt>
+ 
+ ## Current Solution
+ 
+ ### Reasoning Steps
+ <reasoning>
+ 
+ ### Final Answer
+ <current_solution>
+ 
+ 
+ ## Critical Feedback
+ 
+ We are not sure whether the current solution is correct. Please provide critical feedback on it and suggest a revised plan for the next steps, considering any challenges or improvements needed.
+ 
+ If the solution and answer are correct, set `solved` to `"True"` and leave `critical_feedback` and `reasoning_steps` empty.
+ Otherwise, point out any errors in the current solution in the `critical_feedback` field, provide the revised plan in the `reasoning_steps` field, and finally provide the updated code in the `code` field.
+ 
+ 
+ Format your response as follows:
+ 
+ 
+ <solved>
+ [True or False]
+ </solved>
+ 
+ <critical_feedback>
+ [Your critical feedback here.]
+ </critical_feedback>
+ 
+ <reasoning_steps>
+ [Put your reasoning steps here to revise the previous solution. Use additional knowledge if needed; the code to solve the problem goes in the next field.]
+ </reasoning_steps>
+ 
+ <code>
+ [Put your updated code here to solve the problem.]
+ </code>
+ 
+ 
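A controller consuming this template's output would branch on the `solved` flag. An illustrative sketch (field names follow `extract_and_parse_markup` in `utils.py`, assumed importable; the control flow and sample reply are made up, not the exact SwiftSage loop):

```python
from utils import extract_and_parse_markup

# Hypothetical Sage reply in the template's format.
sage_output = """<solved>False</solved>
<critical_feedback>Step 2 drops a factor of 2.</critical_feedback>
<reasoning_steps>Redo step 2, keeping both roots.</reasoning_steps>
<code>print("revised solution")</code>"""

parsed = extract_and_parse_markup(sage_output)
if parsed.get("solved", "").strip().lower() == "true":
    print("Solved; stop iterating.")
else:
    # Hand the feedback and revised plan back to the Swift agent.
    print(parsed["critical_feedback"])
```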
prompt_templates/swift_template.md ADDED
@@ -0,0 +1,152 @@
+ # Instruction
+ 
+ ## Similar Examples with Solutions
+ 
+ ### Example Task 1
+ 
+ <task>
+ Convert the point $(0, -3 \sqrt{3}, 3)$ in rectangular coordinates to spherical coordinates. Enter your answer in the form $(\rho,\theta,\phi),$ where $\rho > 0,$ $0 \le \theta < 2 \pi,$ and $0 \le \phi \le \pi.$
+ </task>
+ 
+ <plan>
+ Step 1. Recall the formulas for converting from rectangular coordinates $(x, y, z)$ to spherical coordinates $(\rho, \theta, \phi)$:
+ - $\rho = \sqrt{x^2 + y^2 + z^2}$
+ - $\theta = \operatorname{atan2}(y, x)$
+ - $\phi = \arccos\left(\frac{z}{\rho}\right)$
+ 
+ Step 2. Given point: $(0, -3\sqrt{3}, 3)$
+ $x = 0$
+ $y = -3\sqrt{3}$
+ $z = 3$
+ 
+ Step 3. Calculate $\rho$ using the formula.
+ 
+ Step 4. Calculate $\theta$:
+ - Since $x = 0$, we need to handle this special case.
+ - When $x = 0$ and $y < 0$, $\theta = \frac{3\pi}{2}$
+ 
+ Step 5. Calculate $\phi$ using the formula.
+ 
+ Step 6. Ensure $\theta$ is in the range $[0, 2\pi)$ and $\phi$ is in the range $[0, \pi]$.
+ </plan>
+ 
+ <code>
+ from sympy import sqrt, atan2, acos, pi
+ 
+ def rectangular_to_spherical():
+     x, y, z = 0, -3*sqrt(3), 3
+     rho = sqrt(x**2 + y**2 + z**2)
+     theta = atan2(y, x)
+     if theta < 0:
+         theta += 2*pi  # keep theta in [0, 2*pi), per Step 6 of the plan
+     phi = acos(z/rho)
+     return rho, theta, phi
+ 
+ spherical_coordinates = rectangular_to_spherical()
+ print(spherical_coordinates)
+ </code>
+ 
+ 
+ <final_answer>
+ (6, 3*pi/2, pi/3)
+ </final_answer>
+ 
+ ### Example Task 2
+ 
+ <task>
+ Determine who lived longer between Lowell Sherman and Jonathan Kaplan.
+ </task>
+ 
+ <plan>
+ Step 1: Research the birth and death dates of Lowell Sherman.
+ Step 2: Research the birth and death dates of Jonathan Kaplan.
+ Step 3: Calculate the lifespan of each person in years.
+ Step 4: Compare the lifespans to determine who lived longer.
+ </plan>
+ 
+ <code>
+ from datetime import datetime
+ 
+ def calculate_lifespan(birth_date, death_date):
+     birth = datetime.strptime(birth_date, "%Y-%m-%d")
+     death = datetime.strptime(death_date, "%Y-%m-%d")
+     return (death - birth).days / 365.25
+ 
+ def compare_lifespans():
+     lowell_sherman = calculate_lifespan("1885-10-11", "1934-12-28")
+     jonathan_kaplan = calculate_lifespan("1947-11-25", "2021-01-03")
+ 
+     if lowell_sherman > jonathan_kaplan:
+         return "Lowell Sherman"
+     elif jonathan_kaplan > lowell_sherman:
+         return "Jonathan Kaplan"
+     else:
+         return "They lived equally long"
+ 
+ result = compare_lifespans()
+ print(f"{result} lived longer.")
+ </code>
+ 
+ <final_answer>
+ Jonathan Kaplan lived longer.
+ </final_answer>
+
92
+
93
+ ---
94
+
95
+ ## Important Notes
96
+
97
+ Note that the above are some example tasks and output formats. You need to solve the current problem below.
98
+
99
+ ---
100
+
101
+ ## Current problem that we want to solve
102
+ <task>
103
+ <prompt>
104
+ </task>
105
+
106
+ ## Previous Solution
107
+
108
+ ### Previous Reasoning Steps
109
+ <plan>
110
+ <current_reasoning>
111
+ </plan>
112
+
113
+ ### Previous Answer
114
+ <final_answer>
115
+ <current_solution>
116
+ </final_answer>
117
+
118
+
119
+
120
+ ---
121
+
122
+ ## Critical Feedback
123
+ <critical_feedback>
124
+
125
+ ### Suggested Plan
126
+ <revised_plan>
127
+
128
+ ---
129
+
130
+ ## Your Final Solution
131
+
132
+ Read the current problem in <task>...</task> again.
133
+
134
+ <task>
135
+ <prompt>
136
+ </task>
137
+
138
+ To solve the current problem, first write your overall plan in <plan>...</plan>, then write Python code in <code>...</code> tags to solve the problem. If there is critical feedback and a suggested plan, revise your previous solution (if any) and provide a new plan and solution based on them.
+ 
+ ## Remember to present your output in the following format:
+ 
+ <plan>
+ [Your general plan to solve the problem by using code. You can recall the required knowledge that you can use in the code, such as facts, formulas, etc.]
+ </plan>
+ 
+ <code>
+ [Your Python code to solve the current problem (instead of the example problems). Please print the final answer at the end of the code.]
+ </code>
+ 
+ You must follow the format strictly and not miss any field.
+ Start your output with "<plan>...</plan>" and end it with "<code> ... </code>".
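Placeholders such as `<prompt>` and `<current_reasoning>` in this template are filled by plain string replacement (see `PromptTemplate.format` in `utils.py` below). A minimal sketch with made-up values, assuming it is run from the repo root:

```python
from utils import PromptTemplate

pt = PromptTemplate('./prompt_templates')
swift_prompt = pt.format(
    "swift",                                   # loads swift_template.md
    prompt="Solve the equation: 6y + 5 = 29",  # fills <prompt>
    current_reasoning="(none yet)",            # fills <current_reasoning>
    current_solution="(none yet)",             # fills <current_solution>
    critical_feedback="",                      # fills <critical_feedback>
    revised_plan="",                           # fills <revised_plan>
)
print(swift_prompt[:200])
```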
run_eval.sh ADDED
@@ -0,0 +1,6 @@
+ # Leave DEBUG_MODE empty for a normal run; the commented line below makes
+ # the script wait for a debugpy client to attach on port 5679.
+ DEBUG_MODE=""
+ # DEBUG_MODE="-m debugpy --listen 127.0.0.1:5679 --wait-for-client"
+ 
+ python $DEBUG_MODE main.py \
+     --eval_mode benchmark \
+     --dataset_name MATH \
+     --num_test_sample 4
test.py ADDED
@@ -0,0 +1,28 @@
+ from code_executor import PythonExecutor
+ import multiprocess
+ 
+ if __name__ == '__main__':
+     multiprocess.set_start_method('spawn')
+ 
+     current_code = """
+ ```python
+ def calculate_hydrogen_mass(mass_of_water_grams):
+     mass_of_hydrogen = 1.00794  # g/mol
+     mass_of_water = 18.01528  # g/mol
+     ratio = (2 * mass_of_hydrogen) / mass_of_water
+     return ratio * mass_of_water_grams
+ 
+ mass_of_water = 23.5  # grams
+ hydrogen_mass = calculate_hydrogen_mass(mass_of_water)
+ 
+ print(hydrogen_mass)
+ ```
+ """
+     executor = PythonExecutor(get_answer_from_stdout=True)
+     result, report = executor.apply(current_code)
+     print("Result:", result)
+     print("Report:", report)
+ 
+     # Make sure to close the pool when done
+     executor.pool.close()
+     executor.pool.join()
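For reference, 23.5 g of water contains about 2 × 1.008 / 18.015 × 23.5 ≈ 2.63 g of hydrogen, so a correct run of this snippet should report a value close to 2.63.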
utils.py ADDED
@@ -0,0 +1,260 @@
+ import datetime
+ import json
+ import logging
+ import os
+ import re
+ from abc import ABC, abstractmethod
+ 
+ import dirtyjson
+ import hjson
+ import numpy as np
+ import openai
+ from fuzzywuzzy import process
+ from sklearn.metrics.pairwise import cosine_similarity
+ 
+ api_configs = {
+     "SambaNova": {
+         "api_key": os.environ.get("SAMBANOVA_API_KEY"),
+         "url_base": "https://api.sambanova.ai/v1"
+     },
+     "Together": {
+         "api_key": os.environ.get("TOGETHER_API_KEY"),
+         "url_base": "https://api.together.xyz/v1"
+     }
+     # You can add more API configurations here for other providers
+ }
+ 
+ class Agent(ABC):
+     def __init__(self, prompt_template, llm_client):
+         self.prompt_template = prompt_template
+         self.llm_client = llm_client
+ 
+     @abstractmethod
+     def generate_response(self, prompt):
+         pass
+ 
+ 
+ def setup_logging():
+     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+     log_filename = f"logs/swiftsage_log_{timestamp}.txt"
+ 
+     logging.basicConfig(
+         level=logging.INFO,
+         format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+         filename=log_filename,
+         filemode='w'
+     )
+ 
+     # Also print to console
+     console = logging.StreamHandler()
+     console.setLevel(logging.INFO)
+     formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
+     console.setFormatter(formatter)
+     logging.getLogger('').addHandler(console)
+ 
+     return logging.getLogger('SwiftSage')
+ 
+ 
+ 
+ def extract_and_parse_markup(text):
+     keys = ["reasoning_steps", "final_answer", "feedback", "score", "critical_feedback", "revised_plan", "solved", "plan", "code"]
+     result = {}
+     if "<final_answer>" in text and "</final_answer>" not in text:
+         text = text + "</final_answer>"
+ 
+     for key in keys:
+         # Create a pattern for each key
+         pattern = f'<{key}>(.*?)</{key}>'
+ 
+         # Search for the pattern in the text
+         match = re.search(pattern, text, re.DOTALL)
+ 
+         if match:
+             # Extract the content, strip whitespace, and add to the result
+             content = match.group(1).strip()
+             result[key] = content
+ 
+     if "code" in result.keys():
+         result["code"] = result["code"].replace("```python", "").replace("```", "").strip()
+ 
+     return result
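+ 
+ # Example (hypothetical input):
+ #   extract_and_parse_markup("<plan>Divide 24 by 6.</plan><code>print(24/6)</code>")
+ #   returns {'plan': 'Divide 24 by 6.', 'code': 'print(24/6)'}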
+ 
+ 
+ class PromptTemplate:
+     def __init__(self, template_dir):
+         self.template_dir = template_dir
+         self.templates = {}
+         self.load_templates()
+ 
+     def load_templates(self):
+         for filename in ['swift_template.md', 'sage_template.md', 'reward_template.md']:
+             with open(os.path.join(self.template_dir, filename), 'r') as f:
+                 key = filename.split('_')[0]
+                 self.templates[key] = f.read()
+ 
+     def format(self, key, **kwargs):
+         template = self.templates.get(key, "")
+         for k, v in kwargs.items():
+             template = template.replace("<" + k + ">", str(v))
+         return template
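+ 
+ # Example (hypothetical values): format("reward", problem="1 + 1 = ?", reasoning="Add.",
+ # current_solution="2") fills the <problem>, <reasoning>, and <current_solution>
+ # markers in reward_template.md via plain string replacement.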
+ 
+ 
+ class LLMClient:
+     def __init__(self, model_id, api_config, temperature=0.3, top_p=1.0, max_tokens=3000, logger=None):
+         self.client = openai.OpenAI(
+             api_key=api_config['api_key'],
+             base_url=api_config['url_base']
+         )
+         self.model_id = model_id
+         self.temperature = temperature
+         self.top_p = top_p
+         self.max_tokens = max_tokens
+         self.logger = logger
+ 
+     def generate_response(self, messages):
+         self.logger.info(f"Sending request to {self.model_id}")
+         self.logger.info(f"Messages: {messages}")
+         response = self.client.chat.completions.create(
+             model=self.model_id,
+             messages=messages,
+             temperature=self.temperature,
+             top_p=self.top_p,
+             max_tokens=self.max_tokens
+         )
+         content = response.choices[0].message.content
+         self.logger.info(f"Response from {self.model_id}:\n{content}")
+         return content
+ 
+ 
+ 
+ 
+ if __name__ == "__main__":
+     # Smoke test: parse a small tagged response.
+     test_text = "<plan>Compute 2 + 2.</plan><code>print(2 + 2)</code>"
+ 
+     print(extract_and_parse_markup(test_text))
+ 
+ 
+ """
+ 
+ def extract_and_parse_json(text):
+ 
+     keys_and_types = [
+         ("reasoning_steps", list),
+         ("final_answer", str),
+         ("feedback", str),
+         ("score", str),
+         ("score", int),
+         ("feedback", str),
+         ("solved", str),
+         ("critical_feedback", str),
+         ("revised_plan", list),
+     ]
+ 
+     # Try to parse the JSON first
+     try:
+         # find the first and last curly braces and parse the json
+         first_brace = text.find("{")
+         last_brace = text.rfind("}")
+         if last_brace == -1:
+             text = text + "\"}"
+         if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
+             data = json.loads(text[first_brace:last_brace+1])
+             return data
+     except Exception as e:
+         data = {}
+         try:
+             data = dirtyjson.loads(text)
+         except Exception as e:
+             pass
+     # If JSON parsing fails, use regex to extract key-value pairs
+ 
+     for key, _ in keys_and_types:
+         # pattern = rf'"{key}"\s*:\s*([\[{{].*?[\]}}]|".*?")'
+         pattern = rf'"{key}"\s*:\s*([\[{{].*?[\]}}]|".*?"|[-+]?\d+)'
+         match = re.search(pattern, text, re.DOTALL)
+         if match:
+             try:
+                 value = json.loads(match.group(1))
+             except Exception as e:
+                 value = match.group(1).strip('"')
+             data[key] = value
+ 
+     result = {}
+     for key, expected_type in keys_and_types:
+         if key in result.keys() and result[key] is not None:
+             continue
+         # Use fuzzy matching to find the closest key
+         try:
+             closest_key, score = process.extractOne(key, data.keys())
+         except Exception as e:
+             continue
+         if score > 80:  # You can adjust this threshold
+             value = data[closest_key]
+ 
+             # Type checking and conversion
+             if expected_type == list and isinstance(value, str):
+                 value = [item.strip() for item in value.strip('[]').split(',')]
+             elif expected_type == str and isinstance(value, list):
+                 value = ', '.join(value)
+             elif expected_type == int and value is not None:
+                 try:
+                     value = int(value)
+                 except ValueError:
+                     value = None
+ 
+             result[key] = value
+         else:
+             result[key] = None
+ 
+     for key in list(result.keys()):
+         if result[key] is None:
+             del result[key]
+     return result
+ 
+ def extract_and_parse_json_v1(text):
+     def find_json_objects(s):
+         # Find all substrings that look like JSON objects
+         json_like_strs = re.findall(r'\{(?:[^{}]|\{[^{}]*\})*\}', s)
+         return json_like_strs
+ 
+     def try_parse_json(s):
+         try:
+             return json.loads(s)
+         except json.JSONDecodeError:
+             try:
+                 s = s.replace("\n", "")
+                 return hjson.loads(s)
+             except json.JSONDecodeError:
+                 return None
+         return None
+ 
+     # First, try to find JSON within code blocks
+     code_block_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```'
+     code_blocks = re.findall(code_block_pattern, text, re.IGNORECASE)
+ 
+     all_json_candidates = []
+ 
+     # Add JSON candidates from code blocks
+     for block in code_blocks:
+         all_json_candidates.extend(find_json_objects(block))
+ 
+     # Add JSON candidates from the entire text
+     all_json_candidates.extend(find_json_objects(text))
+ 
+     # Sort candidates by length, descending
+     all_json_candidates.sort(key=len, reverse=True)
+ 
+     # Try to parse each candidate
+     for candidate in all_json_candidates:
+         parsed_json = try_parse_json(candidate)
+         if parsed_json is not None:
+             return parsed_json
+ 
+     raise ValueError("No valid JSON object found in the text")
+ 
+ 
+ 
+ """