Spaces:
Runtime error
Runtime error
import gradio as gr | |
from dataclasses import dataclass | |
from concurrent.futures import ThreadPoolExecutor, TimeoutError | |
import os | |
import re | |
import subprocess | |
import tempfile | |
import json | |
import datasets | |
import random | |
from typing import Tuple, Dict, Any, List | |
from sympy import N, simplify | |
from sympy.parsing.latex import parse_latex | |
from openai import OpenAI | |
# client = OpenAI( | |
# base_url=os.environ.get("SERVER_URL"), | |
# api_key=os.environ.get("HF_TOKEN"), | |
# ) | |
@dataclass
class Config:
    """Run configuration for the solver.

    Fields without a default must be passed as keyword arguments when the
    object is created; the remaining fields fall back to the defaults below.
    The dataclass decorator generates the keyword ``__init__`` and ``__repr__``
    that the module-level ``Config(...)`` call and its ``print`` rely on.
    """

    model_id: str  # SELECT MODEL
    revision: str  # SELECT REVISION
    # Format template applied to each problem ("{}" leaves it unchanged)
    system_prompt: str
    # Number of samples to generate per problem
    num_samples: int
    num_generations: int
    # Generation parameters
    do_sample: bool
    temperature: float
    top_p: float
    top_k: int
    max_new_tokens: int
    restart_on_fail: bool
    # Enable 4-bit quantization
    is_quantized: bool
    # Run on train or test data?
    is_submission: bool = bool(os.getenv("KAGGLE_IS_COMPETITION_RERUN"))
    validation_set: str = "kaggle-validation-set-medium"
    notebook_time_limit: int = 9 * 60 * 60 - 15 * 60  # 9 hours - 15 minute buffer
    # Debug by solving only the first problem
    debug: bool = False
    # Push solutions to the Hub
    push_to_hub: bool = False
class PythonREPL:
    """Run model-generated Python snippets in a separate ``python3`` process.

    Each snippet is prefixed with ``math``/``numpy``/``sympy`` imports, written
    to a temporary file and executed with a wall-clock timeout.
    """

    def __init__(self, timeout=5):
        # Seconds allowed per snippet; used both for the subprocess and for the
        # thread-level wait in __call__.
        self.timeout = timeout

    def execute(self, query: str) -> Tuple[bool, str]:
        """Execute ``query`` and return ``(success, output)``.

        If the last line does not already contain ``print(``, it is wrapped in
        ``print(...)`` (after dropping any trailing ``#`` comment) so the value
        of a final expression is shown. On success ``output`` is the stripped
        stdout; on failure it is a condensed traceback; on timeout it is a
        "Timed out after N seconds." message.
        """
        query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
        lines = query.strip().split("\n")
        if "print(" not in lines[-1]:
            if "#" in lines[-1]:
                lines[-1] = lines[-1].split("#")[0]
            lines[-1] = "print(" + lines[-1] + ")"
        query = "\n".join(lines)

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            # python3 decodes source files as UTF-8 by default, so write UTF-8
            # regardless of the host locale.
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(query)
            try:
                result = subprocess.run(
                    ["python3", temp_file_path],
                    capture_output=True,
                    check=False,
                    text=True,
                    timeout=self.timeout,
                )
            except subprocess.TimeoutExpired:
                # TimeoutExpired is not a TimeoutError; without this it would
                # escape __call__ and the callers' TimeoutError handlers.
                return False, f"Timed out after {self.timeout} seconds."
            if result.returncode == 0:
                return True, result.stdout.strip()
            return False, self._condense_traceback(result.stderr, temp_file_path)

    @staticmethod
    def _condense_traceback(stderr: str, temp_file_path: str) -> str:
        """Keep the traceback header, frames in the temp script (plus the next line), and the final line."""
        msgs = stderr.strip().split("\n")
        new_msgs = []
        want_next = False
        for m in msgs:
            if "Traceback" in m:
                new_msgs.append(m)
            elif m == msgs[-1]:
                new_msgs.append(m)
            elif temp_file_path in m:
                st = m.index('"/') + 1 if '"/' in m else 0
                ed = m.index(temp_file_path) + 1 if temp_file_path in m else None
                # NOTE(review): `ed` is always truthy in this branch, so this
                # drops everything after the opening quote (line number
                # included) — confirm that is the intended redaction.
                clr = m[st:ed] if not ed else m[st:]
                m = m.replace(clr, "")
                new_msgs.append(m)
                # Also keep the source line that follows the frame header.
                want_next = True
            elif want_next:
                new_msgs.append(m)
                want_next = False
        return "\n".join(new_msgs).strip()

    def __call__(self, query: str) -> Tuple[bool, str]:
        """Run ``query`` via :meth:`execute` with a thread-level timeout as a backstop."""
        with ThreadPoolExecutor() as executor:
            future = executor.submit(self.execute, query)
            try:
                return future.result(timeout=self.timeout)
            except TimeoutError:
                return False, f"Timed out after {self.timeout} seconds."
def execute_completion( | |
executor: PythonREPL, | |
completion: str, | |
return_status: bool = False, | |
last_code_block: bool = False, | |
) -> str | Tuple[str, bool]: | |
# executions = ["!" + code for code in re.findall(r"```bash(.*?)```", completion, re.DOTALL) if "!" not in code] | |
executions = re.findall(r"```python(.*?)```", completion, re.DOTALL) | |
if len(executions) == 0: # directly return cot result | |
return completion, False if return_status else completion | |
else: | |
if last_code_block: | |
executions = [executions[-1]] | |
# Python | |
execution_outputs = [] | |
successes = [] | |
for code in executions: | |
success = False | |
if "subprocess" in code: | |
output = "subprocess is not allowed" | |
execution_outputs.append(output) | |
successes.append(success) | |
continue | |
if "venv" in code: | |
output = "venv is not allowed" | |
execution_outputs.append(output) | |
successes.append(success) | |
continue | |
try: | |
success, output = executor(code) | |
except TimeoutError as e: | |
print("time out") | |
output = e | |
if not success and not return_status: | |
output = "" | |
execution_outputs.append(output) | |
successes.append(success) | |
output = str(execution_outputs[-1]).strip() | |
success = successes[-1] | |
if return_status: | |
return output, success | |
else: | |
return output | |
def postprocess_completion(
    text: str, return_status: bool = False, last_code_block=False, timeout=5
) -> str | Tuple[str, bool]:
    """Execute the code blocks in ``text`` with a fresh :class:`PythonREPL`.

    Thin wrapper around :func:`execute_completion`; ``return_status`` and
    ``last_code_block`` are forwarded unchanged, and ``timeout`` is the
    per-snippet limit in seconds given to the REPL.
    """
    return execute_completion(
        PythonREPL(timeout=timeout),
        text,
        return_status=return_status,
        last_code_block=last_code_block,
    )
def apply_template(example: Dict[str, Any], prompt: str) -> str:
    """Render ``prompt`` with the example's problem text.

    The first ``{}`` placeholder in ``prompt`` receives ``example["prompt"]``.
    A second placeholder, if present, is filled with a literal ``"{}"`` so it
    survives for a later formatting pass.

    Returns:
        The formatted prompt string.
    """
    return prompt.format(example["prompt"], "{}")
def last_boxed_only_string(string):
    """
    Return the last ``\\boxed{...}`` expression in ``string``, braces included.

    Falls back to the last ``\\fbox{...}`` only when no ``\\boxed`` occurs.
    Nested braces are balanced; None is returned when neither command is
    present or its braces never close.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
        if start < 0:
            return None

    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            # Only a closing brace can bring the depth back to zero.
            if depth == 0:
                return string[start : pos + 1]
    return None
def remove_boxed(s):
    """
    Removes the LaTeX ``\\boxed{...}`` (or ``\\fbox{...}``) wrapper, returning
    the content inside the braces.
    Args:
        s (str): The string containing a LaTeX boxed expression.
    Returns:
        str or None: The content inside the wrapper if ``s`` starts with
        ``\\boxed{`` or ``\\fbox{`` and ends with ``}``; otherwise None
        (including for non-string input).
    """
    if not isinstance(s, str):
        return None
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left) and s.endswith("}"):
            return s[len(left) : -1]
    return None
def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
    """
    Extracts the answer from the last LaTeX boxed expression within
    a prediction string.
    Args:
        pred_str (str): The string containing one or more LaTeX
            boxed expressions.
        strip_double_curly_brace (bool): If True, removes an additional
            layer of braces when the whole answer is wrapped in ``{...}``.
    Returns:
        str or None: The extracted answer, if any; otherwise, None.
    """
    boxed_str = last_boxed_only_string(pred_str)
    if boxed_str is None:
        return None
    answer = remove_boxed(boxed_str)
    if answer is None:
        return None
    if strip_double_curly_brace:
        # Raw string: "\{" in a plain literal is an invalid escape sequence
        # (SyntaxWarning on Python 3.12+).
        match = re.match(r"^\{(.*)\}$", answer)
        if match:
            answer = match.group(1)
    return answer
def normalize_final_answer(final_answer: str) -> str:
    """
    Normalize a final answer to a quantitative reasoning question.

    The text is first cut at the first "Problem:" marker (if any). Filler
    words, units and LaTeX wrappers are then stripped. Finally the answer is
    pulled out of "final answer is ...", "answer is ...", "boxed{...}" or
    "$...$" patterns when present.

    Args:
        final_answer (str): The answer string to normalize.
    Returns:
        str: The normalized answer string.
    """
    match = re.search(r"(.*?)Problem:", final_answer, flags=re.S)
    if match:
        # Keep only the text before the first "Problem:" marker.
        final_answer = match.group(1)
    # final_answer = final_answer.split('=')[-1]
    # NOTE(review): "\\le" -> "<" also rewrites the prefix of "\\left" and
    # "\\leq" (e.g. "\\leq" becomes "<q"), and the later "ft" removal eats the
    # rest of "\\left". This looks unintended; confirm before changing, since it
    # affects answer matching.
    SUBSTITUTIONS = [
        ("an ", ""),
        ("a ", ""),
        (".$", "$"),
        ("\\$", ""),
        (r"\ ", ""),
        (" ", ""),
        ("mbox", "text"),
        (",\\text{and}", ","),
        ("\\text{and}", ","),
        ("\\text{m}", "\\text{}"),
        ("\\le", "<"),
    ]
    REMOVED_EXPRESSIONS = [
        "square",
        "ways",
        "integers",
        "dollars",
        "mph",
        "inches",
        "ft",
        "hours",
        "km",
        "units",
        "\\ldots",
        "sue",
        "points",
        "feet",
        "minutes",
        "digits",
        "cents",
        "degrees",
        "cm",
        "gm",
        "pounds",
        "meters",
        "meals",
        "edges",
        "students",
        "childrentickets",
        "multiples",
        "\\text{s}",
        "\\text{.}",
        "\\text{\ns}",
        "\\text{}^2",
        "\\text{}^3",
        "\\text{\n}",
        "\\text{}",
        r"\mathrm{th}",
        r"^\circ",
        r"^{\circ}",
        r"\;",
        r",\!",
        "{,}",
        '"',
        "\\dots",
        "\n",
        "\r",
        "\f",
        r"\%",  # raw string: "\%" in a plain literal is an invalid escape sequence
    ]
    for before, after in SUBSTITUTIONS:
        final_answer = final_answer.replace(before, after)
    for expr in REMOVED_EXPRESSIONS:
        final_answer = final_answer.replace(expr, "")
    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
    # Guaranteed by REMOVED_EXPRESSIONS above.
    assert "\n" not in final_answer
    assert "\r" not in final_answer
    assert "\f" not in final_answer
    # Apply each extractor in turn; each one keeps its last match, if any.
    for pattern in (r"finalansweris(.*)", r"answer?is:?(.*)", r"oxed\{(.*?)\}", r"\$(.*?)\$"):
        matches = re.findall(pattern, final_answer)
        if matches:
            final_answer = matches[-1]
    final_answer = final_answer.strip()
    # Repair "\frac" whose "\f" was consumed as a form feed (removed above).
    if "rac" in final_answer and "\\frac" not in final_answer:
        final_answer = final_answer.replace("rac", "\\frac")
    # Expand shorthand such as "frac12" -> "frac{1}{2}" and "sqrt2" -> "sqrt{2}".
    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
    final_answer = final_answer.replace("$", "")
    # Drop thousands separators from plain integers ("1,000" -> "1000").
    if final_answer.replace(",", "").isdigit():
        final_answer = final_answer.replace(",", "")
    return final_answer
def naive_parse(answer: str) -> str:
    """
    Return the last contiguous run of ASCII digits (0-9) found in ``answer``.

    The input is scanned from the end: trailing non-digits are skipped, then
    digits are collected until the next non-digit character.

    Args:
        answer (str): The input string to parse.
    Returns:
        str: The digit run in its original order, or "" if there are no digits.
    Example:
        >>> naive_parse("abc123def")
        '123'
        >>> naive_parse("def456ghi")
        '456'
        >>> naive_parse("no numbers here")
        ''
    """
    digits = "0123456789"
    chars = list(answer)
    stop = len(chars)
    # Skip trailing non-digit characters.
    while stop > 0 and chars[stop - 1] not in digits:
        stop -= 1
    # Walk back over the digit run that ends at `stop`.
    begin = stop
    while begin > 0 and chars[begin - 1] in digits:
        begin -= 1
    return "".join(chars[begin:stop])
def validate_answer_is_numeric(x: str | int | float) -> int: | |
FLOAT_TOLERANCE = 0.2 | |
try: | |
x = round(float(x)) | |
f = float(x) | |
if abs(x - f) > FLOAT_TOLERANCE: | |
x = -1 | |
except Exception: | |
x = -1 | |
return x | |
def filter_answers(answers: List[str]) -> List[int]:
    """Convert raw answers to integers in the range [0, 999].

    Each answer goes through ``validate_answer_is_numeric``. Negative results,
    which include its ``-1`` rejection marker, are dropped, and the rest are
    reduced modulo 1000.
    """
    kept: List[int] = []
    for raw in answers:
        value = validate_answer_is_numeric(raw)
        if value < 0:
            continue
        # A non-negative int modulo 1000 always lies in [0, 999].
        kept.append(value % 1_000)
    return kept
def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool:
    """Return True when two LaTeX answers are symbolically equal according to SymPy.

    Both strings are parsed with ``parse_latex``. They match when their
    simplified difference evaluates numerically to within 1e-12 of zero, or
    when SymPy reports it as zero. Any parsing or evaluation error is printed
    and treated as a mismatch.
    """
    try:
        difference = simplify(parse_latex(ref_answer) - parse_latex(model_answer))
        return bool(-1e-12 < N(difference) < 1e-12 or difference.is_zero)
    except Exception as err:
        print(err)
        return False
def check_string_match(ref_answer: str, model_answer: str) -> bool:
    """Return whether the two answers compare equal.

    Any exception raised by the comparison is printed and reported as a
    mismatch.
    """
    try:
        matched = ref_answer == model_answer
    except Exception as err:
        print(err)
        matched = False
    return matched
def check_answer(ref_answer: str, model_answer: str) -> bool:
    """Return True if the answers match exactly or are SymPy-equivalent.

    The cheap string comparison runs first. The SymPy check is attempted only
    when it fails.
    """
    if check_string_match(ref_answer, model_answer):
        return True
    return bool(check_sympy_equivalence(ref_answer, model_answer))
# Module-level run settings; they are bundled into the `config` object below.
debug = False
model_id = "Numina-Math-7B"
revision = "main"
# Template used by apply_template(); "{}" passes each problem through unchanged.
system_prompt = "{}"
validation_set = "kaggle-validation-set-medium"
# When True, process_code() does not compare answers with the ground truth.
is_submission = True
num_samples = 4
# Maximum number of generate/execute rounds per problem in solve_problem().
num_generations = 4
# NOTE(review): generate() receives the UI slider value rather than
# config.temperature, so this value appears unused by the app — confirm.
# The module-level name `temperature` is also rebound to a gr.Slider in the UI.
temperature = 0.8
is_quantized = False
# Passed to process_code(); when True, some failures reset gen_texts to the
# initial text instead of only pruning.
restart_on_fail = False
top_p = 1.0
top_k = 0
# NOTE(review): generate() hard-codes max_tokens=1024 instead of using this.
max_new_tokens = 2048
# Papermill related variables
push_to_hub = False
notebook_name = ""
# Bundle the settings above into a single Config instance.
config = Config(
    debug=debug,
    push_to_hub=push_to_hub,
    model_id=model_id,
    revision=revision,
    system_prompt=system_prompt,
    validation_set=validation_set,
    is_quantized=is_quantized,
    restart_on_fail=restart_on_fail,
    is_submission=is_submission,
    num_samples=num_samples,
    num_generations=num_generations,
    do_sample=True,
    temperature=temperature,
    top_p=top_p,
    top_k=top_k,
    max_new_tokens=max_new_tokens,
)
print(f"=== Running submission with config ===\n\n{config}")
def _get_client():
    """Return the OpenAI-compatible chat client used by :func:`generate`.

    Prefers a module-level ``client`` object when one is defined. Otherwise it
    builds a client from the ``SERVER_URL`` and ``HF_TOKEN`` environment
    variables, the same settings as the commented-out client setup in this file.
    """
    existing = globals().get("client")
    if existing is not None:
        return existing
    return OpenAI(
        base_url=os.environ.get("SERVER_URL"),
        api_key=os.environ.get("HF_TOKEN"),
    )


def generate(message, temperature):
    """
    Stream a chat completion for ``message`` and yield it piece by piece.

    Parameters:
        message (list): Chat messages to send to the model.
        temperature (float): Sampling temperature; higher values take more risks.

    Yields:
        tuple: ``(content, is_error)``. For normal output ``is_error`` is False
        and ``content`` is the next text fragment. If the server reports an
        error, ``content`` is that error and ``is_error`` is True. If a chunk
        cannot be interpreted, ``("", True)`` is yielded.
    """
    stream = _get_client().chat.completions.create(
        model="tgi",
        messages=message,
        stream=True,
        max_tokens=1024,
        # Stop where the model would start writing code output itself;
        # process_code() expects the text to end with this marker.
        stop=["```output\n"],
        temperature=temperature,
        timeout=30,
    )
    # Read the raw HTTP body instead of iterating `stream`: if an error occurs,
    # the returned data is not a proper stream and the library iterator would
    # raise. Reading line by line also avoids JSON split across byte chunks.
    response = stream.response
    try:
        for raw_line in response.iter_lines():
            line = raw_line.strip()
            if line.startswith("data:"):
                line = line[len("data:") :].strip()
            if not line:
                continue
            if line == "[DONE]":
                break
            try:
                chunk_json = json.loads(line)
            except json.JSONDecodeError:
                # e.g. SSE "event:" fields or ":" keep-alive comments.
                print(f"func: generate skipped non-JSON line: {line!r}")
                continue
            try:
                if "error" in chunk_json and chunk_json["error"]:
                    yield chunk_json["error"], True
                    break
                content = chunk_json["choices"][0]["delta"]["content"]
                if content is not None:
                    yield content, False
            except Exception as e:
                print(f"func: generate error occurred\njson:{chunk_json}\nerror:{e}")
                yield "", True
    finally:
        # Release the HTTP connection even if the consumer stops early.
        response.close()
def get_majority_text(data):
    """Return the generated text belonging to the most common model answer.

    ``data["model_answers"]`` and ``data["gen_texts"]`` are parallel lists.
    On ties, the answer seen first wins, and its first occurrence selects the
    text.
    """
    from collections import Counter

    answers = data["model_answers"]
    winner, _count = Counter(answers).most_common(1)[0]
    return data["gen_texts"][answers.index(winner)]
def extract_solution(text):
    """Return the stripped text after the first "### Solution:" marker, or "" if absent."""
    _before, marker, after = text.partition("### Solution:")
    if not marker:
        return ""
    return after.strip()
def process_code(
    example: Dict[str, Any],
    config: Config,
    restart_on_fail: bool = False,
    last_step: bool = False,
) -> Dict[str, Any]:
    """Post-process one generation step and decide what happens next.

    Reads ``example["gen_texts"]`` (the transcript so far) and, checked in order:

    * If no ```python block was ever generated: optionally resets
      ``gen_texts`` to ``example["text"]`` (``restart_on_fail``), then in both
      cases marks the example pruned with ``has_code=False``.
    * If the text does not end with an open output fence and its last 100
      characters mention "answer is" or ``\\boxed``: prunes without an answer
      when there is no ```output block; otherwise extracts and records
      ``model_answers`` (plus ``corrects`` outside submission mode) and prunes.
    * On the last step: returns unchanged.
    * If the text does not end with "```output\\n": resets (``restart_on_fail``)
      or prunes.
    * Otherwise runs the last python block and appends its output (truncated
      to 200 characters) plus a closing fence to ``gen_texts``.

    Mutates and returns ``example``.
    """
    gen_text = example["gen_texts"]
    num_python_blocks = len(re.findall(r"```python(.*?)```", gen_text, re.DOTALL))
    if num_python_blocks == 0:
        if restart_on_fail:
            print("no code has ever been generated, RESTARTING")
            # reset the text to the original
            example["gen_texts"] = example["text"]
        else:
            print("no code has ever been generated, STOP")
        # NOTE(review): both branches prune, so the restart above has no
        # further effect in solve_problem — confirm this is intended.
        example["should_prune"] = True
        example["has_code"] = False
        return example
    # A final answer appears near the end and the model is not waiting for code output.
    if gen_text[-10:] != "```output\n" and ("answer is" in gen_text[-100:] or "\\boxed" in gen_text[-100:]):
        num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL))
        if num_output_blocks == 0:
            print("the model hallucinated the code answer")
            example["should_prune"] = True
            return example
        if "boxed" in gen_text[-100:]:
            try:
                # extract_boxed_answer may return None (e.g. the box start was
                # cut off by the 100-char window); normalization then raises.
                answer = normalize_final_answer(extract_boxed_answer(gen_text[-100:]))
            except Exception:
                answer = "-1"
        else:
            answer = normalize_final_answer(gen_text[-100:])
        example["model_answers"] = answer
        if not config.is_submission:
            example["corrects"] = check_answer(example["ground_truth"], answer)
        example["should_prune"] = True
        print("Answer is: ", answer, example["ground_truth"], example["corrects"])
        return example
    if last_step:
        # no point in continuing if we are at the last step
        return example
    if gen_text[-10:] != "```output\n":
        # something else has gone wrong with the generation
        print("warning: output block not found: ", gen_text[-40:])
        if restart_on_fail:
            example["gen_texts"] = example["text"]
        else:
            example["should_prune"] = True
        return example
    # Execute only the most recent code block and capture (output, success).
    code_result, status = postprocess_completion(gen_text, return_status=True, last_code_block=True)
    # add the code result for the next round of generation
    TRUNCATION_LIMIT = 200
    if len(code_result) > TRUNCATION_LIMIT:
        code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)"
    example["gen_texts"] = gen_text + f"{code_result}\n```"
    return example
def solve_problem(problem, temperature, progress=gr.Progress()):
    """Stream a step-by-step solution for ``problem`` to the UI.

    Alternates between model generation (streamed via ``generate``) and
    execution of the latest ```python block (via ``process_code``) for up to
    ``config.num_generations`` rounds. Every yield is the full transcript so
    far, so the UI can replace its display wholesale. Stops early when
    ``generate`` reports an error or ``process_code`` prunes the sample.
    """
    problem = apply_template({"prompt": problem}, prompt=config.system_prompt)
    print(f"Problem: {problem}")

    sample = {
        "problem": problem,  # not used for the submission TODO Remove
        "ground_truth": "unknown",  # not used for the submission TODO Remove
        "text": "## Solution:\n",
        "gen_texts": "## Solution:\n",  # used to store all the generated text
        "should_prune": False,
        "problem_index": -1,  # not used for the submission TODO Remove
        "model_answers": "-1",
        "has_code": True,
        "corrects": False,  # not used for the submission TODO Remove
    }

    # Depth of the tree (e.g. 6 steps = 5 code blocks)
    rounds = progress.tqdm(range(config.num_generations), desc="Generating candidates")
    for step in rounds:
        transcript = sample["gen_texts"]
        messages = [
            {"role": "user", "content": sample["problem"]},
            {"role": "assistant", "content": sample["gen_texts"]},
        ]
        for piece, failed in generate(messages, temperature):
            if piece is not None:
                transcript += piece
            yield transcript
            if failed:
                return
        sample["gen_texts"] = transcript

        # TODO: Maybe it should just return the result of running the code
        is_last = step == (config.num_generations - 1)
        sample = process_code(
            sample,
            config=config,
            restart_on_fail=config.restart_on_fail,
            last_step=is_last,
        )
        sample["gen_texts"] = sample["gen_texts"] + "\n"

        # Stream whatever process_code appended, one character at a time.
        appended = sample["gen_texts"].replace(transcript, "")
        for char in appended:
            transcript += char
            yield transcript

        if sample["should_prune"]:
            break

    yield sample["gen_texts"]
# Pool of example problems shown in the UI (sampled by get_random_problems).
# `token=` replaces `use_auth_token=`, which `datasets` deprecated and later
# removed from load_dataset.
example_data = datasets.load_dataset(
    "AI-MO/kaggle-validation-set-medium-extended",
    split="train",
    token=os.environ.get("HF_DATASET_TOKEN", None),
)
def get_random_problems():
    """Return the ``problem`` text of two distinct, randomly chosen examples."""
    first, second = random.sample(list(example_data), 2)
    return first["problem"], second["problem"]
def copy_problem_to_input(problem):
    """Return ``problem`` unchanged (identity callback for copying an example into the input)."""
    copied = problem
    return copied
def update_problems():
    """Draw two new example problems.

    Returns a 4-tuple: a preview of each problem (cut to 100 characters plus
    "..." when longer), followed by the two full texts.
    """

    def _preview(text, limit=100):
        # Truncate long problems for display only.
        return text[:limit] + "..." if len(text) > limit else text

    first, second = get_random_problems()
    return (_preview(first), _preview(second), first, second)
def clear():
    """Reset the UI state and draw fresh example problems.

    Returns the problem input (""), default temperature (0.1), solution (""),
    then the two problem previews and the two full problem texts.
    """
    refreshed = update_problems()
    first_display, second_display, first_full, second_full = refreshed
    reset_values = ("", 0.1, "")
    return reset_values + (first_display, second_display, first_full, second_full)
# Stylesheet for the Blocks layout (path is relative to the working directory).
# Explicit UTF-8 so decoding does not depend on the host locale.
with open("app.css", "r", encoding="utf-8") as f:
    css = f.read()
# LaTeX delimiters for the Markdown components.
# NOTE(review): plain "[" / "]" will also treat any bracketed text as display
# math — confirm this is intended.
latex_delimiters = [
    {"left": "[", "right": "]", "display": True},
]
# UI layout: title rows, a left column (example, options, input, buttons) and a
# right column for the solution.
# NOTE(review): no event listeners (e.g. btn.click / btn_clear.click) are
# attached here, so Run/Clear do nothing as written; solve_problem and clear
# look like the intended handlers — confirm.
with gr.Blocks(css=css, title="Math Olympiad Solver") as demo:
    with gr.Row(elem_classes="title"):
        gr.HTML("Math Olympiad Solver", elem_classes="title-content")
    with gr.Row(elem_classes="sub-title"):
        # TODO: replace this placeholder with a real description.
        gr.HTML("Here may need to add some description string", elem_classes="sub-title-content")
    with gr.Row(elem_classes="main-area"):
        with gr.Column(scale=1, elem_classes="left"):
            # The "probelm-*" spellings are CSS class names (presumably matched
            # by app.css) and are kept as-is; only the visible label is fixed.
            with gr.Row(elem_classes="probelm-example-container"):
                with gr.Blocks(elem_classes="probelm-example-title"):
                    gr.HTML("Problem example", elem_classes="probelm-example-title-content")
                with gr.Blocks(elem_classes="action-container"):
                    gr.HTML("another", elem_classes="probelm-example-another")
                    gr.HTML("copy", elem_classes="probelm-example-copy")
            with gr.Row(elem_classes="copy-icon-container"):
                gr.Markdown(
                    value="Problem example 2",
                    latex_delimiters=latex_delimiters,
                    elem_classes="probelm-example-content",
                )
            with gr.Accordion("Advanced Options", open=False):
                # Rebinds the module-level name `temperature` to the slider component.
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.1, label="Temperature")
            inp = gr.Textbox(placeholder="Problem", label="Problem input", lines=5)
            with gr.Row():
                btn = gr.Button("Run")
                btn_clear = gr.Button("Clear")
        with gr.Column(scale=1, elem_classes="right"):
            gr.HTML("Solution", elem_classes="solution-title-content")
            out = gr.Markdown(latex_delimiters=latex_delimiters)
if __name__ == "__main__":
    # Allow up to 5 concurrent event runs by default, then start the server.
    queued_demo = demo.queue(default_concurrency_limit=5)
    queued_demo.launch()