"""
Source and credits: https://github.com/ZubinGou/math-evaluation-harness/blob/main/python_executor.py

This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
- https://github.com/openai/prm800k
- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
- https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
"""

import multiprocessing
import queue
import re
from math import isclose
from typing import Union

import regex
from sympy import simplify, N
from sympy.parsing.sympy_parser import parse_expr
from sympy.parsing.latex import parse_latex
from latex2sympy2 import latex2sympy


def parse_digits(num):
    """Convert *num* to a float, or return None when it is not numeric.

    Commas are removed before parsing. When plain parsing fails and the
    text ends with a percent sign, the sign (and a backslash directly
    before it, if any) is dropped and the remaining number is divided
    by 100.
    """
    num = regex.sub(',', '', str(num))
    try:
        return float(num)
    except ValueError:
        if num.endswith('%'):
            num = num[:-1]
            if num.endswith('\\'):
                num = num[:-1]
            try:
                return float(num) / 100
            except ValueError:
                pass
    return None


def is_digit(num):
    """Return True when ``parse_digits(num)`` yields a number."""
    # paired with parse_digits
    return parse_digits(num) is not None


def str_to_pmatrix(input_str):
    """Rewrite every ``{a,b,...}`` group in *input_str* as a LaTeX pmatrix.

    Returns the converted groups joined by ``', '``; the result is an
    empty string when *input_str* contains no such group.
    """
    input_str = input_str.strip()
    matrix_str = re.findall(r'\{.*,.*\}', input_str)
    pmatrix_list = []

    for m in matrix_str:
        m = m.strip('{}')
        # NOTE(review): '\\' is ONE backslash at runtime, while math_equal
        # splits matrix rows on TWO backslashes -- confirm this separator
        # is what the data expects before changing it.
        pmatrix = r'\begin{pmatrix}' + m.replace(',', '\\') + r'\end{pmatrix}'
        pmatrix_list.append(pmatrix)

    return ', '.join(pmatrix_list)


def _symbolic_check(a, b, timeout):
    """Run ``symbolic_equal(a, b)``, in a time-limited subprocess if *timeout*."""
    if timeout:
        return call_with_timeout(symbolic_equal_process, a, b)
    return symbolic_equal(a, b)


def _cells_equal(pred_parts, ref_parts, include_percentage, is_close, timeout):
    """Return True when both lists have equal length and match pairwise."""
    if len(pred_parts) != len(ref_parts):
        return False
    return all(
        math_equal(pred, ref, include_percentage, is_close, timeout)
        for pred, ref in zip(pred_parts, ref_parts)
    )


def math_equal(prediction: Union[bool, float, str],
               reference: Union[float, str],
               include_percentage: bool = True,
               is_close: bool = True,
               timeout: bool = False,
               ) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    Args:
        prediction: Candidate answer.
        reference: Ground-truth answer.
        include_percentage: Also accept ``reference / 100`` and
            ``reference * 100`` in the numeric comparison.
        is_close: Compare numbers with ``numeric_equal`` (relative
            tolerance) instead of ``==``.
        timeout: Run every sympy comparison, including those made for
            nested tuple elements, matrix cells and equation sides, in a
            subprocess guarded by ``call_with_timeout``.

    Returns:
        True when the two answers are judged equivalent, else False.
    """
    if str(prediction) == str(reference):
        return True

    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(prediction, item):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            # Both sides are plain numbers and none of the candidates
            # matched: no point trying the symbolic path.
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## pmatrix (amps)
    if "pmatrix" in prediction and 'pmatrix' not in reference:
        reference = str_to_pmatrix(reference)

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \
            (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ['{', "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str.lower() == ref_str.lower():
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    bracketed = r'(\(|\[).+(\)|\])'
    if regex.match(bracketed, prediction) is not None and regex.match(bracketed, reference) is not None:
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if _cells_equal(pred_parts, ref_parts, include_percentage, is_close, timeout):
            return True

    ## matrix vs. matrix, compared cell by cell
    matrix_begin = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_end = ("\\end{pmatrix}", "\\end{bmatrix}")
    if prediction.startswith(matrix_begin) and prediction.endswith(matrix_end) and \
            reference.startswith(matrix_begin) and reference.endswith(matrix_end):
        # The pmatrix and bmatrix delimiters have the same length, so one
        # slice strips either environment.
        body = slice(len("\\begin{pmatrix}"), -len("\\end{pmatrix}"))
        pred_lines = [line.strip() for line in prediction[body].split("\\\\") if line.strip()]
        ref_lines = [line.strip() for line in reference[body].split("\\\\") if line.strip()]
        if len(pred_lines) == len(ref_lines) and all(
            _cells_equal(pred_line.split("&"), ref_line.split("&"),
                         include_percentage, is_close, timeout)
            for pred_line, ref_line in zip(pred_lines, ref_lines)
        ):
            return True

    if prediction.count('=') == 1 and reference.count('=') == 1:
        # Equation vs. equation: compare "lhs - (rhs)" up to sign.
        pred = prediction.split('=')
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split('=')
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if _symbolic_check(pred, ref, timeout) or _symbolic_check(f"-({pred})", ref, timeout):
            return True
    elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference:
        # "x = value" vs. "value": compare the right-hand side only.
        if math_equal(prediction.split('=')[1], reference, include_percentage, is_close, timeout):
            return True
    elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction:
        if math_equal(prediction, reference.split('=')[1], include_percentage, is_close, timeout):
            return True

    # symbolic equal with sympy
    if _symbolic_check(prediction, reference, timeout):
        return True

    return False


def math_equal_process(param):
    """Call ``math_equal`` on the last two items of *param* (prediction, reference)."""
    return math_equal(param[-2], param[-1])


def numeric_equal(prediction: float, reference: float):
    """Return True when the two numbers agree within a relative tolerance of 1e-4."""
    # Note that relative tolerance has significant impact
    # on the result of the synthesized gsm_hard dataset
    # if reference.is_integer():
    #     return isclose(reference, round(prediction), abs_tol=1e-4)
    # else:
    # prediction = round(prediction, len(str(reference).split(".")[-1]))
    return isclose(reference, prediction, rel_tol=1e-4)


def symbolic_equal(a, b):
    """Return True when *a* and *b* are judged equal after sympy parsing.

    Each input is parsed with the first of ``parse_latex``, ``parse_expr``
    and ``latex2sympy`` that succeeds; an input no parser accepts is kept
    as the raw string. The parsed values are then compared, in order, by
    string/object equality, ``equals``/``simplify``, equation sides,
    numeric evaluation, and matrix entries rounded to 3 places. A failing
    check is skipped rather than raised.
    """
    def _parse(s):
        # NOTE: parse_expr evaluates its input; the strings handled here
        # may come from untrusted model output.
        for f in [parse_latex, parse_expr, latex2sympy]:
            try:
                # First try with doubled backslashes collapsed to one.
                return f(s.replace("\\\\", "\\"))
            except Exception:
                try:
                    return f(s)
                except Exception:
                    pass
        return s

    a = _parse(a)
    b = _parse(b)

    # direct equal
    try:
        if str(a) == str(b) or a == b:
            return True
    except Exception:
        pass

    # simplify equal
    try:
        if a.equals(b) or simplify(a - b) == 0:
            return True
    except Exception:
        pass

    # equation equal
    try:
        if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
            return True
    except Exception:
        pass

    # numeric equal
    try:
        if numeric_equal(float(N(a)), float(N(b))):
            return True
    except Exception:
        pass

    # matrix
    try:
        # if a and b are matrix
        if a.shape == b.shape:
            _a = a.applyfunc(lambda x: round(x, 3))
            _b = b.applyfunc(lambda x: round(x, 3))
            if _a.equals(_b):
                return True
    except Exception:
        pass

    return False


def symbolic_equal_process(a, b, output_queue):
    """Subprocess entry point: put ``symbolic_equal(a, b)`` on *output_queue*."""
    result = symbolic_equal(a, b)
    output_queue.put(result)


def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run *func* in a subprocess and return the value it puts on a queue.

    *func* is called as ``func(*args, output_queue, **kwargs)`` and is
    expected to put exactly one result on ``output_queue``.

    Returns:
        The queued result, or False when the subprocess is still running
        after *timeout* seconds (it is then terminated), exits with a
        non-zero code, or exits without queuing a result.
    """
    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    # A child that crashed or was killed never queued a result; an
    # unbounded get() would block the caller forever.
    if process.exitcode != 0:
        return False
    try:
        return output_queue.get(timeout=timeout)
    except queue.Empty:
        return False


def _test_math_equal():
    """Print the result of one hand-picked comparison (manual smoke test)."""
    # print(math_equal("0.0833333333333333", "\\frac{1}{12}"))
    # print(math_equal("(1,4.5)", "(1,\\frac{9}{2})"))
    # print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True))
    # print(math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True))
    # print(math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}", "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})", timeout=True))

    # pred = '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}'
    # gt = '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})'

    # pred= '-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}'
    # gt= '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}'

    # pred = '-34x-45y+20z-100=0'
    # gt = '34x+45y-20z+100=0'

    # pred = '\\frac{100}{3}'
    # gt = '33.3'

    # pred = '\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}'
    # gt = '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})'

    # pred = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}'
    # gt = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}'

    # pred = '(+5)(b+2)'
    # gt = '(a+5)(b+2)'

    # pred = '\\frac{1+\\sqrt{5}}{2}'
    # gt = '2'

    # pred = '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}', gt = '4'
    # pred = '1', gt = '1\\\\sqrt{19}'

    pred = '(0.6,2.6667]'
    gt = '(\\frac{3}{5},\\frac{8}{3}]'

    print(math_equal(pred, gt, timeout=True))


if __name__ == "__main__":
    _test_math_equal()