SwiftSage / data_utils.py
yuchenlin's picture
Upload 14 files
1a0cf07 verified
raw
history blame
No virus
12.1 kB
"""
Source and credits: https://github.com/ZubinGou/math-evaluation-harness/blob/main/python_executor.py
"""
import re
import regex
import sympy
from typing import TypeVar, Iterable, List, Union, Any, Dict
from word2number import w2n
from utils import *
def lower_keys(example):
    """Return a new dict with every key of ``example`` lower-cased.

    Values are carried over unchanged and ``example`` itself is not modified.
    When two keys differ only by case, the value of the one iterated later
    is the one kept.
    """
    return {name.lower(): payload for name, payload in example.items()}
def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if len(substr) > 0 and substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
if "sqrt" not in a:
a = int(a)
if "sqrt" not in b:
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except:
return string
def _fix_sqrt(string):
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
return _string
def convert_word_number(text: str) -> str:
    """Convert a spelled-out number to its digit string, if possible.

    The text is handed to ``w2n.word_to_num``; on success the numeric result
    is returned as a string. This is deliberately best-effort: if the
    conversion raises, the input is returned unchanged.

    Args:
        text: Candidate number words.

    Returns:
        ``str`` of the parsed number, or ``text`` itself when parsing fails.
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        # `Exception` instead of a bare `except:` so KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        return text
# units mainly from MathQA
# Unit words/phrases that strip_string() deletes from answers. Each entry is
# spliced UNESCAPED into a regular expression there, so a "." inside an entry
# matches any character, and entries are applied in list order.
# NOTE(review): odd spellings ("m sqaure", "minw", "v â € ™") presumably mirror
# the dataset text verbatim -- confirm before "correcting" any of them.
unit_texts = [
"east", "degree", "mph", "kmph", "ft", "m sqaure", " m east", "sq m", "deg", "mile",
"q .", "monkey", "prime", "ratio", "profit of rs", "rd", "o", "gm",
"p . m", "lb", "tile", "per", "dm", "lt", "gain", "ab", "way", "west",
"a .", "b .", "c .", "d .", "e .", "f .", "g .", "h .", "t", "a", "h",
"no change", "men", "soldier", "pie", "bc", "excess", "st",
"inches", "noon", "percent", "by", "gal", "kmh", "c", "acre", "rise",
"a . m", "th", "π r 2", "sq", "mark", "l", "toy", "coin",
"sq . m", "gallon", "° f", "profit", "minw", "yr", "women",
"feet", "am", "pm", "hr", "cu cm", "square", "v â € ™", "are",
"rupee", "rounds", "cubic", "cc", "mtr", "s", "ohm", "number",
"kmph", "day", "hour", "minute", "min", "second", "man", "woman",
"sec", "cube", "mt", "sq inch", "mp", "∏ cm ³", "hectare", "more",
"sec", "unit", "cu . m", "cm 2", "rs .", "rs", "kg", "g", "month",
"km", "m", "cm", "mm", "apple", "liter", "loss", "yard",
"pure", "year", "increase", "decrease", "d", "less", "Surface",
"litre", "pi sq m", "s .", "metre", "meter", "inch",
]
# Also accept a naive plural (entry + "s") of every entry listed above.
unit_texts.extend([t + "s" for t in unit_texts])
def strip_string(string):
    r"""Normalize an answer string so equivalent answers compare equal.

    The input is stringified and then rewritten by a fixed, order-dependent
    sequence of textual rules: LaTeX clean-up (``\!``, ``\left``/``\right``,
    ``array``/``bmatrix`` -> ``pmatrix``, ``tfrac``/``dfrac`` -> ``frac``),
    removal of units, degree marks, dollar and percent signs, word-number
    conversion, infinity spelling, quotes, trailing ``.000``, a short
    ``x =`` prefix, all spaces, and finally brace repair for ``\sqrt`` and
    ``\frac`` plus ``a/b`` -> ``\frac{a}{b}``.

    Compared with the earlier revision this fixes three rules that silently
    did nothing: quote removal (result of ``str.replace`` was discarded),
    stripping one pair of brackets around an alphanumeric token (``isalnum``
    was tested on the bracketed string, which is never alphanumeric), and
    ``+\infty`` -> ``\infty`` (pattern was misspelled ``+\inity``).

    Args:
        string: The raw answer; any object is accepted and passed to ``str``.

    Returns:
        The normalized string, possibly empty.
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")
    # trailing "."
    string = string.rstrip(".")
    # remove the LaTeX negative thin space "\!"
    string = string.replace("\\!", "")

    # matrix: treat array and bmatrix environments as pmatrix
    string = re.sub(r'\\begin\{array\}\{.*?\}', r'\\begin{pmatrix}', string)
    string = re.sub(r'\\end\{array\}', r'\\end{pmatrix}', string)
    string = string.replace("bmatrix", "pmatrix")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right, unescape literal braces
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("\\{", "{")
    string = string.replace("\\}", "}")

    # Remove a trailing \text{...} (e.g. a unit) unless nothing would remain
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    # Remove unit texts. Two passes so a unit exposed by an earlier removal
    # is also caught. A unit only matches when delimited on both sides by
    # the string boundary or a non-word character.
    # NOTE(review): unit_text is interpolated unescaped, so "." in entries
    # such as "q ." acts as a wildcard; left as is to keep existing matches.
    for _ in range(2):
        for unit_text in unit_texts:
            _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
            if _string != "":
                string = _string

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")

    # convert word number to digit
    string = convert_word_number(string)

    # replace "\\text{...}" to "..."
    string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
    for key in ['x=', 'y=', 'z=', 'x\\in', 'y\\in', 'z\\in', 'x\\to', 'y\\to', 'z\\to']:
        string = string.replace(key, "")
    string = string.replace("\\emptyset", r"{}")
    string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

    # remove percentage (escaped first, then bare)
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{."
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # Drop one pair of enclosing brackets around a purely alphanumeric token,
    # e.g. "(5)" -> "5". The emptiness of "{}" (from \emptyset) keeps it intact.
    if (
        len(string) >= 2
        and string[0] + string[-1] in ("{}", "()", "[]")
        and string[1:-1].isalnum()
    ):
        string = string[1:-1]

    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    string = string.replace("+\\infty", "\\infty")

    # and
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)

    # quotes (str.replace returns a new string; the result must be kept)
    string = string.replace("'", "")
    string = string.replace("\"", "")

    # i, j: treat a lone "j" as the imaginary unit "i"
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # replace a.000b where b is not a digit or b is the end, with ab
    string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0*$", r"\1", string)

    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # get rid of a short left-hand side such as "k = " or "q = "
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works
    # with \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix
    # in case the model output is X/Y
    string = _fix_a_slash_b(string)
    return string
def extract_multi_choice_answer(pred_str):
    """Pull the chosen option letter out of a multiple-choice response.

    Anything from the first ``Problem:`` onward is discarded, ``choice is``
    is treated like ``answer is``, and the first ``answer is X`` /
    ``answer is (X)`` with X in a-e (case-insensitive) decides the result.
    The letter must not be followed by another letter, so ordinary words
    ("answer is clearly ...", "answer is about ...") are not mistaken for an
    option.

    Args:
        pred_str: The model's full response text.

    Returns:
        The upper-case option letter, or the string ``'placeholder'`` when no
        answer statement is found.
    """
    # TODO: SFT models
    if 'Problem:' in pred_str:
        pred_str = pred_str.split("Problem:", 1)[0]
    pred_str = pred_str.replace("choice is", "answer is")
    # Plain stdlib `re` suffices for this pattern. The negative lookahead
    # rejects a letter that is merely the start of a longer word.
    patt = re.search(r"answer is \(?(?P<ans>[abcde])(?![a-z])\)?", pred_str.lower())
    if patt is not None:
        return patt.group('ans').upper()
    return 'placeholder'
def _read_boxed_content(tail):
    r"""Return the payload of ``\boxed`` given the non-empty text right after it.

    With a leading ``{`` the text up to the matching ``}`` is returned (nested
    braces are kept; an unclosed brace yields everything that follows).
    Otherwise the text up to the next ``$`` is returned, stripped.
    """
    if tail[0] != '{':
        return tail.split('$')[0].strip()
    depth = 1
    chars = []
    for ch in tail[1:]:
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                break
        chars.append(ch)
    return "".join(chars)


def extract_answer(pred_str, data_name):
    r"""Extract and normalize the final answer from a model response.

    Multiple-choice datasets are delegated to ``extract_multi_choice_answer``.
    Otherwise the first rule that applies wins:

    1. ``final answer is $...$. I hope`` -> text between the markers;
    2. the last ``boxed`` -> its braced payload (or text up to ``$``);
    3. text after the last ``he answer is`` (covers "The"/"the");
    4. text after the last ``final answer is``;
    5. the last number in the response (commas removed first).

    The candidate then loses line breaks, one leading ``:``, one trailing
    ``.`` and one trailing ``/`` before being passed to ``strip_string``.

    Args:
        pred_str: The model's full response text.
        data_name: Dataset identifier selecting the extraction strategy.

    Returns:
        The normalized answer string. ``""`` is returned directly when the
        response ends right after ``boxed``.
    """
    if data_name in ["mmlu_stem", "sat_math", "mathqa"]:
        return extract_multi_choice_answer(pred_str)

    if 'final answer is $' in pred_str and '$. I hope' in pred_str:
        # minerva_math
        tmp = pred_str.split('final answer is $', 1)[1]
        pred = tmp.split('$. I hope', 1)[0].strip()
    elif 'boxed' in pred_str:
        tail = pred_str.split('boxed')[-1]
        if len(tail) == 0:
            return ""
        pred = _read_boxed_content(tail)
    elif 'he answer is' in pred_str:
        pred = pred_str.split('he answer is')[-1].strip()
    elif 'final answer is' in pred_str:
        pred = pred_str.split('final answer is')[-1].strip()
    else:
        # use the last number; raw string so \d and \. are regex escapes
        # rather than (invalid) Python string escapes
        numbers = re.findall(r'-?\d*\.?\d+', pred_str.replace(",", ""))
        pred = numbers[-1] if len(numbers) >= 1 else ''

    # join multi-line answers
    pred = re.sub(r"\n\s*", "", pred)
    if pred != "" and pred[0] == ":":
        pred = pred[1:]
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]
    if pred != "" and pred[-1] == "/":
        pred = pred[:-1]
    pred = strip_string(pred)
    return pred
def parse_ground_truth(example: Dict[str, Any], data_name):
    """Read the reference answer of ``example`` and normalize it.

    Where the answer lives depends on ``data_name``:

    * MATH-style sets: ``example['answer']`` as is;
    * ``gsm8k``: the part of ``example['answer']`` after the last ``####``;
    * ``mmlu_stem``: ``example['answer']`` is an index into ``'ABCD'``;
    * ``gpqa``: ``example['correct answer']``.

    Returns:
        The answer after ``strip_string`` normalization.

    Raises:
        NotImplementedError: ``data_name`` is none of the datasets above.
    """
    verbatim_answer_sets = (
        "MATH", "math", "math_oai", "minerva_math", "ocw", "amps", "hungarian_exam",
    )
    if data_name in verbatim_answer_sets:
        raw_answer = example['answer']
    elif data_name == "gsm8k":
        raw_answer = example['answer'].split("####")[-1]
    elif data_name == "mmlu_stem":
        raw_answer = 'ABCD'[example['answer']]
    elif data_name == "gpqa":
        raw_answer = example['correct answer']
    else:
        raise NotImplementedError(f"`{data_name}`")
    # post process
    return strip_string(raw_answer)
def parse_question(example, data_name):
    """Build the question prompt for ``example``.

    For ``mmlu_stem`` the four entries of ``example['choices']`` are labelled
    (A)-(D) and appended to the question together with an instruction line.
    For every other dataset the first present key among ``question``,
    ``problem``, ``Question`` and ``input`` supplies the text. When the
    normalized ground truth is true/false or yes/no, a matching hint is
    appended.

    ``example`` is not modified (an earlier revision overwrote
    ``example['choices']`` with the labelled strings, so a second call
    produced doubled labels).

    Args:
        example: One dataset record.
        data_name: Dataset identifier, forwarded to ``parse_ground_truth``.

    Returns:
        The prompt text, stripped of surrounding whitespace.

    Raises:
        AssertionError: ``mmlu_stem`` record without exactly four choices, or
            no question text could be found.
    """
    question = ""
    if data_name == "mmlu_stem":
        choices = example['choices']
        assert len(choices) == 4
        # Build a fresh list instead of writing back into the caller's data.
        labelled = [
            f"({label}) {str(choice).strip()}"
            for label, choice in zip('ABCD', choices)
        ]
        options = ", ".join(labelled)
        question = f"{example['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}"
    else:
        for key in ['question', 'problem', 'Question', 'input']:
            if key in example:
                question = example[key]
                break
    assert question != ""

    # Yes or No question
    gt_lower = parse_ground_truth(example, data_name).lower()
    if gt_lower in ["true", "false"]:
        question += " (True or False)"
    if gt_lower in ["yes", "no"]:
        question += " (Yes or No)"
    return question.strip()
def _test_extract_answer():
    """Manual smoke test: run ``extract_answer`` on a boxed matrix and print it."""
    # Every backslash is written escaped (including "\\left") so the literal
    # contains no invalid escape sequences; the text itself is unchanged.
    text = """
The answer is $\\boxed{\\left(
\\begin{array}{ccc}
-13 & 4 & -2 \\\\
7 & 8 & -3 \\\\
0 & 18 & -7 \\\\
6 & 12 & 5 \\\\
\\end{array}
\\right)}$.
"""
    # extract_answer returns a string (the normalized boxed content), not a dict.
    print(extract_answer(text, "math"))
# Script entry point: running this module directly executes the manual
# extract_answer smoke test; importing it does not.
if __name__ == "__main__":
    _test_extract_answer()