deepseekmath / codeexecutor.py
Pra-tham's picture
added some more code
753cb33
raw
history blame
4.01 kB
import os
import re
import signal
import subprocess
import tempfile
from collections import Counter
from contextlib import contextmanager
from dataclasses import dataclass
class PythonREPL:
    """Run Python snippets in a fresh ``python3`` subprocess under a time limit.

    Calling an instance writes the snippet to a temporary file, executes it
    with ``python3`` and returns ``(success, text)``: the stripped stdout when
    the process exits with status 0, otherwise a condensed traceback.
    """

    def __init__(self, timeout=5):
        """Store the per-run time limit.

        Args:
            timeout: Seconds allowed for a single run. The value is handed to
                both ``signal.alarm`` and ``subprocess.run``.
        """
        self.timeout = timeout

    @contextmanager
    def time_limit(self, seconds):
        """Raise ``TimeoutError`` inside the ``with`` body after *seconds*.

        A ``SIGALRM`` handler is installed for the duration of the body. On
        exit the pending alarm is cancelled and the handler that was active
        before entry is put back, so using this context manager does not
        permanently replace a handler installed by the surrounding program.

        Args:
            seconds: Alarm delay passed to ``signal.alarm``.
        """

        def signal_handler(*_):
            raise TimeoutError(f"Timed out after {seconds} seconds.")

        previous_handler = signal.signal(signal.SIGALRM, signal_handler)
        signal.alarm(seconds)
        try:
            yield
        finally:
            signal.alarm(0)
            # ``signal.signal`` returns None when the earlier handler was not
            # installed from Python; None cannot be passed back in.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

    @staticmethod
    def _condense_traceback(error_msg, temp_file_path):
        """Reduce a stderr dump to the lines that matter for the snippet.

        Kept lines: any line containing ``Traceback``, the final line, each
        line that mentions *temp_file_path* (truncated, see below) and the
        single line that follows such a line.
        """
        msgs = error_msg.split("\n")
        kept = []
        want_next = False
        for m in msgs:
            if "Traceback" in m:
                kept.append(m)
            elif m == msgs[-1]:
                kept.append(m)
            elif temp_file_path in m:
                # Drop everything from the start of the quoted path onwards so
                # the temporary location never shows up in the message.
                st = m.index('"/') + 1 if '"/' in m else 0
                m = m.replace(m[st:], "")
                kept.append(m)
                want_next = True
            elif want_next:
                kept.append(m)
                want_next = False
        return "\n".join(kept).strip()

    def __call__(self, query):
        """Execute *query* and report the outcome.

        ``math``, ``numpy as np`` and ``sympy as sp`` imports are prepended.
        If the last line contains no ``print(`` call, anything after a ``#``
        on that line is dropped and the remainder is wrapped in ``print(...)``.

        Args:
            query: Python source code to run.

        Returns:
            ``(True, stdout)`` when the process exits with status 0, else
            ``(False, condensed_stderr)``; both strings are stripped.

        Raises:
            TimeoutError: If the run exceeds ``self.timeout`` seconds, whether
                the alarm or the subprocess timeout fires first.
        """
        query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
        lines = query.strip().split("\n")
        if "print(" not in lines[-1]:
            if "#" in lines[-1]:
                lines[-1] = lines[-1].split("#")[0]
            lines[-1] = "print(" + lines[-1] + ")"
        source = "\n".join(lines)

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(source)
            try:
                with self.time_limit(self.timeout):
                    result = subprocess.run(
                        ["python3", temp_file_path],
                        capture_output=True,
                        check=False,
                        text=True,
                        timeout=self.timeout,
                    )
            except subprocess.TimeoutExpired as exc:
                # The subprocess timeout and the alarm use the same duration
                # and race each other; surface both as TimeoutError so that
                # callers only have one exception type to handle.
                raise TimeoutError(
                    f"Timed out after {self.timeout} seconds."
                ) from exc

        if result.returncode == 0:
            return True, result.stdout.strip()
        return False, self._condense_traceback(result.stderr.strip(), temp_file_path)
def execute_completion(executor, completion, return_status, last_code_block):
    """Run the ```python fenced blocks found in *completion*.

    Args:
        executor: Callable taking a code string and returning
            ``(success, output)``; it may raise ``TimeoutError``.
        completion: Text that may contain ```python ... ``` fenced blocks.
        return_status: When true, return ``(output, success)``; otherwise
            return only the output string.
        last_code_block: When true, only the final fenced block is run.

    Returns:
        The stripped output of the last block that was processed, paired with
        its success flag when *return_status* is true. If *completion* holds
        no fenced block, *completion* itself is returned (with ``False`` as
        the status when *return_status* is true). When *return_status* is
        false, the output of a failed run is replaced by an empty string.
    """
    executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
    if not executions:
        # Parenthesised on purpose: without the parentheses the conditional
        # binds to the second tuple element only and a tuple is always returned.
        return (completion, False) if return_status else completion
    if last_code_block:
        executions = executions[-1:]

    outputs = []
    successes = []
    for code in executions:
        success = False
        # Refuse to run code that mentions a banned library at all.
        banned = next((lib for lib in ("subprocess", "venv") if lib in code), None)
        if banned is not None:
            outputs.append(f"{banned} is not allowed")
            successes.append(success)
            continue
        try:
            success, output = executor(code)
        except TimeoutError as e:
            print("Code timed out")
            output = e
        if not success and not return_status:
            output = ""
        outputs.append(output)
        successes.append(success)

    output = str(outputs[-1]).strip()
    success = successes[-1]
    if return_status:
        return output, success
    return output
def postprocess_completion(text, return_status, last_code_block):
    """Execute the code blocks in *text* with a newly created ``PythonREPL``.

    Args:
        text: Completion text forwarded to ``execute_completion``.
        return_status: Forwarded unchanged to ``execute_completion``.
        last_code_block: Forwarded unchanged to ``execute_completion``.

    Returns:
        Whatever ``execute_completion`` returns for these arguments.
    """
    # A fresh executor is built per call and discarded once the call returns.
    return execute_completion(
        PythonREPL(),
        text,
        return_status=return_status,
        last_code_block=last_code_block,
    )
def get_majority_vote(answers):
    """Return the most frequent element of *answers*.

    Args:
        answers: Sized collection of hashable items.

    Returns:
        The item with the highest count (the first one encountered wins a
        tie), or ``0`` when *answers* is empty.
    """
    if len(answers) == 0:
        return 0
    top_item, _count = Counter(answers).most_common(1)[0]
    return top_item