# Spaces:
# Running
# Running
import logging | |
import re | |
# Regex locating numbers in English answer text.  Group 1 captures the
# number; the other parts only constrain what may surround it.
NUMERIC_IN_EN = (
    # One preceding character: whitespace, "=", "<", ">", "(", "$", ":",
    # ".", "*", "\" or anything in the range "+-/".
    # NOTE(review): inside a character class "+-/" is a RANGE
    # (U+002B..U+002F, i.e. "+ , - . /"), so a comma is accepted here and in
    # the suffix class below.  That looks unintended -- with the comma already
    # allowed, the explicit "(?=, )" alternative is redundant -- but changing
    # it alters which numbers are extracted, so confirm before touching it.
    r"(?:[\s=+-/<>($:\.\*\\])"
    # The next character must not be whitespace.
    r"(?=\S)"
    # Group 1: "0" | comma-grouped "1,234" | plain digits, then an optional
    # fractional part and an optional "%".
    r"((?:0|(?:\d{1,3}(?:,\d{3})+(?=\D|$))|(?:\d+))(?:\.\d+)?%?)"
    # Followed by end of text, a character from the terminator set, or ", ".
    r"(?:(?![^\s=+-/>)$:\.\*\\])|(?=, ))"
)
# Regex locating numbers in text without word spacing: any non-digit (or the
# start of the string) before, any non-digit (or the end) after.  Group 1 is
# the number, with the same shape as in NUMERIC_IN_EN.
NUMERIC_IN_ZH = (
    r"(?:\D|^)"
    r"((?:0|(?:\d{1,3}(?:,\d{3})+(?=\D|$))|(?:\d+))(?:\.\d+)?%?)"
    r"(?=\D|$)"
)
def extract_choice_ans(text):
    """Return the lowercased multiple-choice letter (a-d) picked in *text*.

    Standalone letters (``B``) are collected first and parenthesised ones
    (``(B)``) after them.  The last entry of that combined list wins, so any
    parenthesised choice takes priority over bare letters.  ``"_"`` is
    returned when no candidate is found.
    """
    candidates = re.findall(r"\b[ABCDabcd]\b", text)
    candidates += re.findall(r"\([ABCDabcd]\)", text)
    if not candidates:
        return "_"
    picked = candidates[-1]
    # "(B)" -> "B"; a bare match is already a single letter.
    if len(picked) != 1:
        picked = picked[1]
    return picked.lower()
def extract_numeric(string, pattern=NUMERIC_IN_EN) -> str:
    """Return the last number in *string* in a normalised text form.

    Args:
        string: text to search (e.g. a model's answer).
        pattern: regex whose single capture group is the number; defaults to
            ``NUMERIC_IN_EN`` (``NUMERIC_IN_ZH`` is the other pattern defined
            in this module).

    Returns:
        The last match, normalised:

        * thousands separators are removed (``"1,234"`` -> ``"1234"``);
        * trailing fractional zeros and a dangling ``"."`` are dropped
          (``"2.50"`` -> ``"2.5"``, ``"3.0"`` -> ``"3"``);
        * a leading ``"."`` is padded with ``"0"``;
        * a percentage becomes a fraction (``"50%"`` -> ``"0.5"``).

        If nothing matches, *string* is returned unchanged.

    Note: neither bundled pattern captures a sign, so ``"-1.5"`` gives ``"1.5"``.
    """

    def standardize(raw):
        value = raw.replace(",", "")
        if "." in value:
            value = value.rstrip("0")
            if value.endswith("."):
                value = value[:-1]
        if value.startswith("."):
            value = "0" + value
        if value.endswith("%"):
            digits = value[:-1]
            # Parse with int()/float() instead of eval(): eval("007") is a
            # SyntaxError (leading zeros are not valid Python literals), and
            # eval must not be run on model-generated text.
            try:
                number = int(digits)
            except ValueError:
                number = float(digits)
            value = str(number / 100)
        return value

    # Drop whitespace-only captures and a lone "%".
    matches = [m for m in re.findall(pattern, string) if m.strip() and m != "%"]
    if not matches:
        logging.debug("No numeric value found in string: %s", string)
        return string
    return standardize(matches[-1].strip())
def remove_boxed(s):
    r"""Strip a ``\boxed``/``\fbox`` wrapper from *s* and return its content.

    Accepted forms:

    * ``"\boxed ANSWER"`` -> ``"ANSWER"``
    * ``"\boxed{ANSWER}"`` -> ``"ANSWER"``
    * ``"\fbox{ANSWER}"`` -> ``"ANSWER"``

    Raises:
        TypeError: if *s* is not a string (e.g. ``None`` when no box was found).
        AssertionError: if *s* is in none of the accepted forms.  These checks
            used to be ``assert`` statements, which ``python -O`` strips; they
            are now explicit but keep the same exception type so existing
            handlers still work.
    """
    if not isinstance(s, str):
        raise TypeError(f"expected a string, got {type(s).__name__}")
    space_prefix = "\\boxed "
    # startswith (not "in"): a "\boxed " nested inside a braced answer must
    # not divert a valid "\boxed{...}" into the space-form branch.
    if s.startswith(space_prefix):
        return s[len(space_prefix):]
    # "\fbox{...}" is also accepted because last_boxed_only_string falls back
    # to it when no "\boxed" is present.
    for brace_prefix in ("\\boxed{", "\\fbox{"):
        if s.startswith(brace_prefix) and s.endswith("}"):
            return s[len(brace_prefix):-1]
    raise AssertionError(f"not a boxed expression: {s!r}")
def last_boxed_only_string(string):
    r"""Return the last boxed expression in *string*, or ``None``.

    * If the last ``\boxed`` is followed by a space (``\boxed 5``), the text
      after it up to the next ``$`` (or the end) is returned, prefixed with
      ``"\boxed "``.
    * Otherwise the span from the last ``\boxed`` -- or, if there is none, the
      last ``\fbox`` -- through its matching closing brace is returned, e.g.
      ``"\boxed{\frac{1}{2}}"``.
    * ``None`` is returned when no marker exists or its braces never close.
    """
    idx = string.rfind("\\boxed")
    space_marker = "\\boxed "
    # Take the space form only when the LAST \boxed is that form.  Previously
    # any earlier "\boxed " won over a later "\boxed{...}", so
    # "\boxed 1 ... \boxed{2}" returned the first answer instead of the last.
    if idx >= 0 and string.rfind(space_marker) == idx:
        return space_marker + string[idx + len(space_marker):].split("$")[0]

    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    # Walk forward from the marker tracking brace depth; stop at the brace
    # that brings the depth back to zero.
    depth = 0
    for pos in range(idx, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[idx : pos + 1]
    return None
def fix_sqrt(string):
    r"""Wrap single-character ``\sqrt`` arguments in braces.

    ``"\sqrt3"`` becomes ``"\sqrt{3}"`` so it compares equal to the braced
    form.  Already-braced arguments (``\sqrt{...}``), n-th roots
    (``\sqrt[n]{...}``) and a ``\sqrt`` with nothing after it are left as is;
    the last case previously raised ``IndexError``.
    """
    if "\\sqrt" not in string:
        return string
    head, *tails = string.split("\\sqrt")
    parts = [head]
    for tail in tails:
        if not tail or tail[0] in "{[":
            parts.append("\\sqrt" + tail)
        else:
            parts.append("\\sqrt{" + tail[0] + "}" + tail[1:])
    return "".join(parts)
def remove_right_units(string):
    r"""Drop a trailing unit written as ``\text{ ...}``.

    Everything from ``"\text{ "`` onwards is removed, e.g.
    ``"5\text{ cm}"`` -> ``"5"``.  This relies on the observation (at least in
    the validation set) that ``"\text{ "`` with a space only introduces units.

    Raises:
        AssertionError: if ``"\text{ "`` occurs more than once.  This used to
            be an ``assert`` statement, which ``python -O`` strips; it is now
            an explicit check with the same exception type.
    """
    marker = "\\text{ "
    if marker not in string:
        return string
    head, _, tail = string.partition(marker)
    if marker in tail:
        raise AssertionError(f"expected at most one {marker!r} in {string!r}")
    return head
def fix_fracs(string):
    r"""Brace the arguments of shorthand ``\frac`` expressions.

    ``\frac12`` -> ``\frac{1}{2}``, ``\frac1b`` -> ``\frac{1}{b}`` and
    ``\frac1{72}`` -> ``\frac{1}{72}``.  Only single-character arguments are
    handled, so ``\frac{72}1`` is left as is.  If any ``\frac`` that is not
    followed by ``{`` has fewer than two characters after it (including
    none), the input is returned unchanged.  The empty case previously raised
    ``IndexError``.
    """
    head, *tails = string.split("\\frac")
    parts = [head]
    for tail in tails:
        parts.append("\\frac")
        if tail.startswith("{"):
            parts.append(tail)
            continue
        if len(tail) < 2:
            # Not enough characters for numerator and denominator: give up.
            return string
        numerator, second = tail[0], tail[1]
        # A "{" in second position already opens a braced denominator.
        denominator = second if second == "{" else "{" + second + "}"
        parts.append("{" + numerator + "}" + denominator + tail[2:])
    return "".join(parts)
def fix_a_slash_b(string):
    r"""Rewrite a plain integer fraction ``"a/b"`` as ``"\frac{a}{b}"``.

    Only strings made of exactly two canonically written integers separated
    by a single ``/`` are rewritten (``"3/4"`` -> ``"\frac{3}{4}"``,
    ``"-1/2"`` -> ``"\frac{-1}{2}"``).  Anything else is returned unchanged,
    including non-integers, leading zeros, surrounding whitespace and more
    than one slash.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        numerator, denominator = int(parts[0]), int(parts[1])
    except ValueError:
        return string
    # Round-trip check rejects forms int() tolerates, e.g. "01", " 1", "+1",
    # "1_0".  An explicit test instead of assert, which python -O strips.
    if f"{numerator}/{denominator}" != string:
        return string
    return "\\frac{" + str(numerator) + "}{" + str(denominator) + "}"
def strip_string(string):
    r"""Normalise a LaTeX answer so that equivalent answers compare equal.

    Steps, in order:

    1. Remove newlines and ``\!``, and collapse ``\\`` to ``\``.
    2. Rename ``tfrac``/``dfrac`` to ``frac``.
    3. Remove ``\left``/``\right``, degree marks (``^{\circ}``, ``^\circ``),
       ``\$``, a trailing ``\text{ ...}`` unit, and ``\%``.
    4. Prefix bare decimals with ``0``.
    5. Drop a left-hand side of at most two characters when there is exactly
       one ``=`` (``"k=3"`` -> ``"3"``).
    6. Brace ``\sqrt`` arguments, remove spaces, then brace ``\frac``
       arguments.
    7. Map ``"0.5"`` to ``"\frac{1}{2}"``.

    An empty string is returned as soon as it is detected.
    """
    # Line breaks and LaTeX negative thin spaces carry no meaning.
    string = string.replace("\n", "")
    string = string.replace("\\!", "")
    # Collapse escaped backslashes: "\\\\" -> "\\".
    string = string.replace("\\\\", "\\")
    # \tfrac and \dfrac render like \frac.
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    # Delimiter sizing commands.
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    # Degree marks.
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")
    # Escaped dollar signs.
    string = string.replace("\\$", "")
    # Units written as a trailing "\text{ ...}".
    string = remove_right_units(string)
    # Escaped percent signs.  The second pass preserves existing behaviour:
    # removing one "\%" can expose another (e.g. "\\%%" -> "\%").  It used to
    # be spelled with the invalid escape "\%", which Python 3.12+ reports as a
    # SyntaxWarning; "\\%" is the same string value.
    string = string.replace("\\%", "")
    string = string.replace("\\%", "")
    # Bare decimals: " .5" -> " 0.5", "{.5" -> "{0.5", and ".5" -> "0.5" below.
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string
    # Drop a short left-hand side such as "k = " or "q = " when the string
    # contains exactly one "=".
    lhs, sep, rhs = string.partition("=")
    if sep and "=" not in rhs and len(lhs) <= 2:
        string = rhs
    # "\sqrt3" -> "\sqrt{3}".
    string = fix_sqrt(string)
    string = string.replace(" ", "")
    # "\frac12" -> "\frac{1}{2}", "\frac1{72}" -> "\frac{1}{72}"
    # (but not "\frac{72}1").
    string = fix_fracs(string)
    # Canonical form for one half.
    if string == "0.5":
        string = "\\frac{1}{2}"
    # NOTE: fix_a_slash_b is not applied here, so a plain "X/Y" is left as is
    # even though the dataset writes such answers as \frac{X}{Y}.
    return string
def get_answer(string):
    """Extract the answer from the last boxed expression in *string*.

    Uses ``last_boxed_only_string`` to locate the box and ``remove_boxed`` to
    unwrap it.  If either step raises for any reason (no box, unbalanced
    braces, unexpected format), *string* is returned unchanged.  The result
    is not normalised here.
    """
    try:
        boxed = last_boxed_only_string(string)
        return remove_boxed(boxed)
    except Exception:
        # Best effort: any extraction failure means "use the raw text".
        return string
def is_equiv(str1, str2, verbose=False):
    """Return True if two answers match after ``strip_string`` normalisation.

    ``None`` never matches anything; when both arguments are ``None`` a
    warning line is printed.  If normalisation raises for either input, the
    raw strings are compared instead.  With *verbose*, the two normalised
    strings are printed before comparing.
    """
    if str1 is None or str2 is None:
        if str1 is None and str2 is None:
            print("WARNING: Both None")
        return False
    try:
        left = strip_string(str1)
        right = strip_string(str2)
        if verbose:
            print(left, right)
        return left == right
    except Exception:
        # Normalisation is best effort: fall back to an exact comparison.
        return str1 == str2
if __name__ == "__main__":
    # Manual smoke test.  The bundled patterns do not capture a leading minus
    # sign, so this prints "1.5" rather than "-1.5".
    sample = "the answer is -1.5"
    print(extract_numeric(sample))