import gradio as gr from dataclasses import dataclass from concurrent.futures import ThreadPoolExecutor, TimeoutError from huggingface_hub import InferenceClient import os import re import subprocess import tempfile import json import datasets from datasets import load_dataset from datasets import Value, Features import random import time from typing import Tuple, Dict, Any, List from sympy import N, simplify from sympy.parsing.latex import parse_latex #from openai import OpenAI import base64 from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformers import AutoTokenizer, AutoModelForPreTraining from langchain_community.llms.manifest import ManifestWrapper #client = OpenAI( # base_url=os.environ.get("SERVER_URL"), # api_key=os.environ.get("HF_TOKEN"), #) client = InferenceClient("HuggingFaceH4/zephyr-7b-beta") @dataclass class Config: debug: bool = False push_to_hub: bool = False model_id: str = None revision: str = None system_prompt: str = None validation_set: str = None is_quantized: bool = False restart_on_fail: bool = False is_submission: bool = False num_samples: int = 1 num_generations: int = 1 do_sample: bool = True temperature: float = 1.0 top_p: float = 0.9 top_k: int = 50 max_new_tokens: int = 100 # Load pre-trained Wit Transformer model and tokenizer tokenizer = AutoTokenizer.from_pretrained("AnReu/math_pretrained_bert") model = AutoModelForPreTraining.from_pretrained("AnReu/math_pretrained_bert") class PythonREPL: def __init__(self, timeout=5): self.timeout = timeout def execute(self, query: str) -> Tuple[bool, str]: query = "import math\nimport numpy as np\nimport sympy as sp\n" + query query = query.strip().split("\n") if "print(" not in query[-1]: if "#" in query[-1]: query[-1] = query[-1].split("#")[0] query[-1] = "print(" + query[-1] + ")" query = "\n".join(query) with tempfile.TemporaryDirectory() as temp_dir: temp_file_path = os.path.join(temp_dir, "tmp.py") with open(temp_file_path, "w") as f: f.write(query) result = subprocess.run( ["python3", temp_file_path], capture_output=True, check=False, text=True, timeout=self.timeout, ) if result.returncode == 0: output = result.stdout return True, output.strip() else: error_msg = result.stderr.strip() msgs = error_msg.split("\n") new_msgs = [] want_next = False for m in msgs: if "Traceback" in m: new_msgs.append(m) elif m == msgs[-1]: new_msgs.append(m) elif temp_file_path in m: st = m.index('"/') + 1 if '"/' in m else 0 ed = m.index(temp_file_path) + 1 if temp_file_path in m else None clr = m[st:ed] if not ed else m[st:] m = m.replace(clr, "") new_msgs.append(m) want_next = True elif want_next: new_msgs.append(m) want_next = False error_msg = "\n".join(new_msgs) return False, error_msg.strip() def __call__(self, query: str) -> Tuple[bool, str]: with ThreadPoolExecutor() as executor: future = executor.submit(self.execute, query) try: return future.result(timeout=self.timeout) except TimeoutError: return False, f"Timed out after {self.timeout} seconds." def execute_completion( executor: PythonREPL, completion: str, return_status: bool = False, last_code_block: bool = False, ) -> str | Tuple[str, bool]: # executions = ["!" + code for code in re.findall(r"```bash(.*?)```", completion, re.DOTALL) if "!" not in code] executions = re.findall(r"```python(.*?)```", completion, re.DOTALL) if len(executions) == 0: # directly return cot result return completion, False if return_status else completion else: if last_code_block: executions = [executions[-1]] # Python execution_outputs = [] successes = [] for code in executions: success = False if "subprocess" in code: output = "subprocess is not allowed" execution_outputs.append(output) successes.append(success) continue if "venv" in code: output = "venv is not allowed" execution_outputs.append(output) successes.append(success) continue try: success, output = executor(code) except TimeoutError as e: print("time out") output = e if not success and not return_status: output = "" execution_outputs.append(output) successes.append(success) output = str(execution_outputs[-1]).strip() success = successes[-1] if return_status: return output, success else: return output def postprocess_completion( text: str, return_status: bool = False, last_code_block=False, timeout=5 ) -> str | Tuple[str, bool]: executor = PythonREPL(timeout=timeout) result = execute_completion(executor, text, return_status=return_status, last_code_block=last_code_block) del executor return result def apply_template(example: Dict[str, Any], prompt: str) -> Dict[str, Any]: return prompt.format(example["prompt"], "{}") def last_boxed_only_string(string): """ Extracts the last LaTeX boxed or framed expression from a string. Args: string (str): The input string containing LaTeX expressions. Returns: str or None: The last boxed or framed expression, if found; otherwise, None. """ idx = string.rfind("\\boxed") if idx < 0: idx = string.rfind("\\fbox") if idx < 0: return None i = idx right_brace_idx = None num_left_braces_open = 0 while i < len(string): if string[i] == "{": num_left_braces_open += 1 if string[i] == "}": num_left_braces_open -= 1 if num_left_braces_open == 0: right_brace_idx = i break i += 1 if right_brace_idx is None: retval = None else: retval = string[idx : right_brace_idx + 1] return retval def remove_boxed(s): """ Removes the LaTeX boxed command, returning the content inside the braces. Args: s (str): The string containing a LaTeX boxed expression. Returns: str or None: The content inside the boxed command, if valid; otherwise, None. """ left = "\\boxed{" try: assert s[: len(left)] == left assert s[-1] == "}" length = len(left) return s[length:-1] except Exception: return None def extract_boxed_answer(pred_str, strip_double_curly_brace=False): """ Extracts the answer from a LaTeX boxed expression within a prediction string. Args: pred_str (str): The string containing one or more LaTeX boxed expressions. strip_double_curly_brace (bool): If True, removes an additional layer of braces. Returns: str or None: The extracted answer, if any; otherwise, None. """ boxed_str = last_boxed_only_string(pred_str) if boxed_str is None: return None answer = remove_boxed(boxed_str) if answer is None: return None if strip_double_curly_brace: match = re.match("^\{(.*)\}$", answer) # noqa: W605 if match: answer = match.group(1) return answer def normalize_final_answer(final_answer: str) -> str: """ Normalizes a final answer string by removing or replacing various LaTeX and text elements. Args: final_answer (str): The answer string to normalize. Returns: str: The normalized answer string. """ match = re.search(r"(.*?)Problem:", final_answer, flags=re.S) if match: final_answer = match.group(1) # 返回匹配的第一部分,即"Problem"之前的所有文本 """Normalize a final answer to a quantitative reasoning question.""" # final_answer = final_answer.split('=')[-1] SUBSTITUTIONS = [ ("an ", ""), ("a ", ""), (".$", "$"), ("\\$", ""), (r"\ ", ""), (" ", ""), ("mbox", "text"), (",\\text{and}", ","), ("\\text{and}", ","), ("\\text{m}", "\\text{}"), ("\\le", "<"), ] REMOVED_EXPRESSIONS = [ "square", "ways", "integers", "dollars", "mph", "inches", "ft", "hours", "km", "units", "\\ldots", "sue", "points", "feet", "minutes", "digits", "cents", "degrees", "cm", "gm", "pounds", "meters", "meals", "edges", "students", "childrentickets", "multiples", "\\text{s}", "\\text{.}", "\\text{\ns}", "\\text{}^2", "\\text{}^3", "\\text{\n}", "\\text{}", r"\mathrm{th}", r"^\circ", r"^{\circ}", r"\;", r",\!", "{,}", '"', "\\dots", "\n", "\r", "\f", "\%", ] for before, after in SUBSTITUTIONS: final_answer = final_answer.replace(before, after) for expr in REMOVED_EXPRESSIONS: final_answer = final_answer.replace(expr, "") # Extract answer that is in LaTeX math, is bold, # is surrounded by a box, etc. final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) assert "\n" not in final_answer assert "\r" not in final_answer assert "\f" not in final_answer if len(re.findall(r"finalansweris(.*)", final_answer)) > 0: final_answer = re.findall(r"finalansweris(.*)", final_answer)[-1] if len(re.findall(r"answer?is:?(.*)", final_answer)) > 0: final_answer = re.findall(r"answer?is:?(.*)", final_answer)[-1] if len(re.findall(r"oxed\{(.*?)\}", final_answer)) > 0: final_answer = re.findall(r"oxed\{(.*?)\}", final_answer)[-1] if len(re.findall(r"\$(.*?)\$", final_answer)) > 0: final_answer = re.findall(r"\$(.*?)\$", final_answer)[-1] final_answer = final_answer.strip() if "rac" in final_answer and "\\frac" not in final_answer: final_answer = final_answer.replace("rac", "\\frac") final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) final_answer = final_answer.replace("$", "") if final_answer.replace(",", "").isdigit(): final_answer = final_answer.replace(",", "") return final_answer def naive_parse(answer: str) -> str: """ Extracts and returns the numeric digits from the input string, processing them in reverse order until a non-numeric character is encountered after encountering the first numeric character. Args: answer (str): The input string to parse. Returns: str: A string consisting of the numeric digits extracted from the input, in their original order. Example: >>> naive_parse("abc123def") '123' >>> naive_parse("def456ghi") '456' >>> naive_parse("no numbers here") '' """ out = [] start = False end = False for l in reversed(list(answer)): if l in "0123456789" and not end: start = True out.append(l) else: if start: end = True out = reversed(out) return "".join(out) def validate_answer_is_numeric(x: str | int | float) -> int: FLOAT_TOLERANCE = 0.2 try: x = round(float(x)) f = float(x) if abs(x - f) > FLOAT_TOLERANCE: x = -1 except Exception: x = -1 return x def filter_answers(answers: List[str]) -> List[int]: formatted_answers = [validate_answer_is_numeric(a) for a in answers] # Filter for non-negative answers formatted_answers = [a for a in formatted_answers if a >= 0] # Compute modulo formatted_answers = [a % 1_000 for a in formatted_answers] # less than 2.1 billion or cannot convert to C int (32-bit) formatted_answers = [a for a in formatted_answers if a <= 999] return formatted_answers def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool: def do_answers_match(ref_answer: str, model_answer: str) -> bool: ref_sympy = parse_latex(ref_answer) model_sympy = parse_latex(model_answer) diff = simplify(ref_sympy - model_sympy) return True if -1e-12 < N(diff) < 1e-12 or diff.is_zero else False try: result = do_answers_match(ref_answer, model_answer) return result except Exception as e: print(e) return False def check_string_match(ref_answer: str, model_answer: str) -> bool: try: return ref_answer == model_answer except Exception as e: print(e) return False def check_answer(ref_answer: str, model_answer: str) -> bool: # check if strings are the same correct = check_string_match(ref_answer, model_answer) if correct: return True # use the sympy library to check if the expressions are the same correct = check_sympy_equivalence(ref_answer, model_answer) if correct: return True return False debug = False model_id = "athstral-7B-v0.m1" revision = "main" system_prompt = "{}" validation_set = "kaggle-validation-set-medium" is_submission = True num_samples = 4 num_generations = 4 temperature = 0.8 is_quantized = False restart_on_fail = False top_p = 1.0 top_k = 0 max_new_tokens = 2048 # Papermill related variables push_to_hub = False notebook_name = "" config = Config( debug=False, push_to_hub=False, model_id=model_id, revision=revision, system_prompt=system_prompt, validation_set=validation_set, is_quantized=is_quantized, restart_on_fail=restart_on_fail, is_submission=is_submission, num_samples=num_samples, num_generations=num_generations, do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k, max_new_tokens=max_new_tokens ) print(f"=== Running submission with config ===\n\n{config}") def parse_data_chunk(data_chunk): """ Parse a given data chunk string into a list of individual data entries. The function splits the input string by the delimiter "data:" and removes any leading or trailing whitespace from each resulting chunk. Empty chunks are filtered out from the final list. Parameters: data_chunk (str): The input string containing data chunks separated by "data:". Returns: list: A list of individual data entries with whitespace stripped. """ if isinstance(data_chunk, client.ChatCompletionStreamOutput): data_chunk = data_chunk.text chunks = data_chunk.split("data:") def parse_data_chunk(data_chunk): """ Parse a given data chunk string into a list of individual data entries. The function splits the input string by the delimiter "data:" and removes any leading or trailing whitespace from each resulting chunk. Empty chunks are filtered out from the final list. Parameters: data_chunk (str): The input string containing data chunks separated by "data:". Returns: list: A list of individual data entries with whitespace stripped. """ if isinstance(data_chunk, InferenceClient.ChatCompletionStreamOutput): # Update this line if you're using a different client class data_chunk = data_chunk.text chunks = data_chunk.split("data:") for chunk in response: chunk = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk data_chunks = parse_data_chunk(chunk) try: for data_chunk in data_chunks: chunk_json = json.loads(data_chunk) if "error" in chunk_json and chunk_json["error"]: yield chunk_json["error"], True break delta = chunk_json["choices"][0]["delta"] content = delta["content"] if "content" in delta else "" if content != "": yield content, False except json.JSONDecodeError as e: print(f"func: generate error occurred\nchunk:{chunk}\nerror:{e}") raise e except KeyError as e: print(f"func: generate error occurred\nchunk:{chunk}\nerror:{e}") raise e def get_majority_text(data): from collections import Counter # Count the frequency of each answer in model_answers answer_counts = Counter(data["model_answers"]) # Find the majority response majority_response = answer_counts.most_common(1)[0][0] # Find the index of the first occurrence of the majority response majority_index = data["model_answers"].index(majority_response) # Return the corresponding text in gen_texts return data["gen_texts"][majority_index] def extract_solution(text): # Split the text at "### Solution:" parts = text.split("### Solution:", 1) if len(parts) > 1: # Return everything after "### Solution:" return parts[1].strip() else: # Return an empty string if "### Solution:" is not found return "" def process_code( example: Dict[str, Any], config: Config, restart_on_fail: bool = False, last_step: bool = False, ) -> Dict[str, Any]: gen_text = example["gen_texts"] num_python_blocks = len(re.findall(r"```python(.*?)```", gen_text, re.DOTALL)) if num_python_blocks == 0: if restart_on_fail: print("no code has ever been generated, RESTARTING") # reset the text to the original example["gen_texts"] = example["text"] else: print("no code has ever been generated, STOP") example["should_prune"] = True example["has_code"] = False return example if gen_text[-10:] != "```output\n" and ("answer is" in gen_text[-100:] or "\\boxed" in gen_text[-100:]): num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL)) if num_output_blocks == 0: print("the model hallucinated the code answer") example["should_prune"] = True return example if "boxed" in gen_text[-100:]: try: answer = normalize_final_answer(extract_boxed_answer(gen_text[-100:])) except Exception: answer = "-1" else: answer = normalize_final_answer(gen_text[-100:]) example["model_answers"] = answer if not config.is_submission: example["corrects"] = check_answer(example["ground_truth"], answer) example["should_prune"] = True print("Answer is: ", answer, example["ground_truth"], example["corrects"]) return example if last_step: # no point in continuing if we are at the last step return example if gen_text[-10:] != "```output\n": # something else has gone wrong with the generation print("warning: output block not found: ", gen_text[-40:]) if restart_on_fail: example["gen_texts"] = example["text"] else: example["should_prune"] = True return example code_result, status = postprocess_completion(gen_text, return_status=True, last_code_block=True) # add the code result for the next round of generation TRUNCATION_LIMIT = 200 if len(code_result) > TRUNCATION_LIMIT: code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)" example["gen_texts"] = gen_text + f"{code_result}\n```" return example def solve_problem(problem, temperature, progress=gr.Progress()): """ yield token: string, stop: bool """ problem = apply_template({"prompt": problem}, prompt=config.system_prompt) print(f"Problem: {problem}") sample = { "problem": problem, # not used for the submission TODO Remove "ground_truth": "unknown", # not used for the submission TODO Remove "text": "## Solution:\n", "gen_texts": "", # used to store all the generated text "should_prune": False, "problem_index": -1, # not used for the submission TODO Remove "model_answers": "-1", "has_code": True, "corrects": False, # not used for the submission TODO Remove } for step in progress.tqdm( range(config.num_generations), desc="Generating candidates" ): # Depth of the tree (e.g. 6 steps = 5 code blocks) step_reponse = sample["gen_texts"] messages = [ {"role": "user", "content": sample["problem"]}, {"role": "assistant", "content": sample["gen_texts"]}, ] stop = False for reponse_message, error in generate(messages, temperature): if reponse_message is not None: step_reponse += reponse_message yield step_reponse, False if error: stop = True sample["gen_texts"] = step_reponse # TODO: Maybe it should just return the result of running the code sample = process_code( sample, config=config, restart_on_fail=config.restart_on_fail, last_step=(step == (config.num_generations - 1)), ) sample["gen_texts"] = sample["gen_texts"] + "\n" run_code_reponse = sample["gen_texts"].replace(step_reponse, "") for output_mseeage in run_code_reponse: if output_mseeage is not None: step_reponse += output_mseeage yield step_reponse, False if sample["should_prune"] or stop: break yield sample["gen_texts"], True features = Features({ 'id': Value('int64'), 'problem': Value('string'), 'answer': Value('string'), #'prompt': Value('string'), # Ensure this matches the actual data type of 'prompt' in your dataset #'level': Value('string') }) # Now load the dataset using the defined schema example_data = datasets.load_dataset( "AI-MO/aimo-validation-math-level-5", split="train", use_auth_token=os.environ.get("HF_DATASET_TOKEN", None), features=features # Pass the schema definition here ) with open( "app.css", "r") as f: css = f.read() latex_delimiters = [ {"left": "[", "right": "]", "display": True}, ] def get_random_problem(): example = random.choice(list(example_data)) problem = example["problem"] return problem def update_example_problem(): problem_example_text = get_random_problem() return problem_example_text, problem_example_text def clear(): problem_example_text = get_random_problem() return "", 0.1, "", problem_example_text, problem_example_text def preprocess_output(text): return text.replace(r"\(", r"\\(").replace(r"\)", r"\\)") with gr.Blocks(css=css, title="Math Olympiad Solver") as demo: btn_list = [] problem_input_ele_list = [] problem_example_text = get_random_problem() with gr.Row(elem_classes="title"): gr.HTML("Math Olympiad Solver", elem_classes="title-content") with gr.Row(elem_classes="sub-title"): gr.HTML( "
Demo of the maths solving with AI Models. Example data are drawn randomly generated.
", elem_classes="sub-title-content", ) with gr.Row(elem_classes="main-area"): with gr.Column(scale=1, elem_classes="left"): with gr.Row(elem_classes="probelm-example-container"): with gr.Blocks(elem_classes="probelm-example-title"): gr.HTML("Problem example", elem_classes="probelm-example-title-content") with gr.Blocks(elem_classes="action-container"): another_btn = gr.Button( "", elem_classes="probelm-example-another", icon="./static/images/reset.png", ) copy_btn = gr.Button("Copy", elem_classes="probelm-example-copy") problem_example = gr.HTML( problem_example_text, elem_classes="probelm-example-content", ) with gr.Row(elem_classes="probelm-input-container"): inp = gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=True) problem_markdown = gr.Markdown( visible=False, latex_delimiters=[ {"left": "[", "right": "]", "display": True}, {"left": "$", "right": "$", "display": False}, {"left": r"\(", "right": r"\)", "display": False}, ], ) inp.change(fn=lambda text: text, inputs=[inp], outputs=[problem_markdown]) problem_input_ele_list.append(inp) problem_input_ele_list.append(problem_markdown) with gr.Accordion("Advanced Options", open=False): temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.1, label="Temperature") with gr.Row() as btn_area: btn_clear = gr.Button("Clear", elem_classes="clear-btn") btn_run = gr.Button("Run", elem_classes="run-btn") btn_list.append(btn_clear) btn_list.append(btn_run) with gr.Column(scale=1, elem_classes="right"): gr.HTML("Solution", elem_classes="solution-title-content") out = gr.Markdown( elem_classes="solution-content", latex_delimiters=[ {"left": "[", "right": "]", "display": True}, {"left": "$", "right": "$", "display": False}, {"left": r"\(", "right": r"\)", "display": False}, ], ) problem_example_text_hidden = gr.Markdown(value=problem_example_text, visible=False) def solve_problem_wrapper(inp_text, temperature): new_running_btn = gr.Button("", elem_classes="run-btn running-btn") try: for after_tokens, stop in solve_problem(inp_text, temperature): yield preprocess_output(after_tokens), new_running_btn if stop: btn_run = gr.Button("Run", elem_classes="run-btn") yield preprocess_output(after_tokens), btn_run except Exception as e: raise e def mount_run_btn(btn): btn.click(fn=solve_problem_wrapper, inputs=[inp, temperature], outputs=[out, btn_list[1]]) btn.click(get_run_after_problem_input, None, outputs=problem_input_ele_list) def get_run_after_problem_input(): return gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=False), gr.Markdown( visible=True, latex_delimiters=[ {"left": "[", "right": "]", "display": True}, {"left": "$", "right": "$", "display": False}, ], elem_classes="problem-input-markdown", ) def get_init_problem_input(): return gr.Textbox(placeholder="Problem", label="Problem input", lines=5, visible=True), gr.Markdown( visible=False, latex_delimiters=[ {"left": "[", "right": "]", "display": True}, {"left": "$", "right": "$", "display": False}, ], ) copy_btn.click(fn=lambda example: example, inputs=[problem_example_text_hidden], outputs=[inp]) btn_clear.click( fn=clear, inputs=[], outputs=[ inp, temperature, out, problem_example, problem_example_text_hidden, ], ) btn_clear.click(get_init_problem_input, None, outputs=problem_input_ele_list) mount_run_btn(btn_run) demo.load( update_example_problem, inputs=None, outputs=[ problem_example, problem_example_text_hidden, ], ) another_btn.click( fn=update_example_problem, inputs=[], outputs=[ problem_example, problem_example_text_hidden, ], ) if __name__ == "__main__": demo.queue(default_concurrency_limit=5).launch(share=True)