File size: 3,866 Bytes
44001f9 753cabf 44001f9 753cabf 44001f9 753cabf 44001f9 753cabf 44001f9 753cabf 44001f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import os
import re
import subprocess
import tempfile
import multiprocessing
from collections import Counter
from contextlib import contextmanager
from dataclasses import dataclass
class PythonREPL:
    """Run a snippet of Python source in a separate ``python3`` process.

    Calling an instance writes the snippet to a temporary file, executes it
    with ``python3`` and returns a ``(success, output)`` pair.
    """

    def __init__(self, timeout=5):
        """Store the time limit applied to each executed snippet.

        Args:
            timeout: Seconds a snippet may run before it is aborted.
        """
        self.timeout = timeout

    @staticmethod
    def _clean_traceback(error_msg, temp_file_path):
        """Condense the stderr text of a failed run.

        Keeps lines containing ``"Traceback"``, lines equal to the final
        line, each line that mentions ``temp_file_path`` (truncated, see the
        note below) and the single line that follows such a line.

        Args:
            error_msg: Stderr text of the failed run.
            temp_file_path: Path of the script that was executed.

        Returns:
            The kept lines joined with newlines and stripped.
        """
        msgs = error_msg.split("\n")
        kept = []
        want_next = False
        for m in msgs:
            if "Traceback" in m:
                kept.append(m)
            elif m == msgs[-1]:
                kept.append(m)
            elif temp_file_path in m:
                # NOTE(review): everything from the first '"/' onward (or the
                # whole line when there is no '"/') is removed, so a frame
                # line collapses to '  File "' and loses its line number.
                # This looks unintended, but the exact text is preserved here
                # because consumers of the output may depend on it - confirm
                # before changing.
                st = m.index('"/') + 1 if '"/' in m else 0
                kept.append(m.replace(m[st:], ""))
                want_next = True
            elif want_next:
                kept.append(m)
                want_next = False
        return "\n".join(kept).strip()

    @staticmethod
    def _prepare_source(query):
        """Prepend the standard imports and make the last line print its value.

        If the last line does not contain ``print(``, anything after the
        first ``#`` on it is dropped and the remainder is wrapped in
        ``print(...)``.
        """
        lines = ("import math\nimport numpy as np\nimport sympy as sp\n" + query).strip().split("\n")
        if "print(" not in lines[-1]:
            if "#" in lines[-1]:
                lines[-1] = lines[-1].split("#")[0]
            lines[-1] = "print(" + lines[-1] + ")"
        return "\n".join(lines)

    @staticmethod
    def _run_code(temp_file_path, timeout=None):
        """Execute the script at ``temp_file_path`` with ``python3``.

        Args:
            temp_file_path: Path of the script to run.
            timeout: Optional limit in seconds, forwarded to
                ``subprocess.run``; ``None`` means no limit.

        Returns:
            ``(True, stripped stdout)`` when the exit code is 0, otherwise
            ``(False, condensed stderr)``.

        Raises:
            subprocess.TimeoutExpired: If ``timeout`` elapses. The child
                process is killed by ``subprocess.run`` before this is raised.
        """
        result = subprocess.run(
            ["python3", temp_file_path],
            capture_output=True,
            check=False,
            text=True,
            timeout=timeout,
        )
        if result.returncode == 0:
            return True, result.stdout.strip()
        return False, PythonREPL._clean_traceback(result.stderr.strip(), temp_file_path)

    def __call__(self, query):
        """Run ``query`` and return ``(success, output)``.

        Args:
            query: Python source text to execute.

        Returns:
            The result of ``_run_code``, or
            ``(False, "Timed out after N seconds.")`` when the snippet
            exceeds ``self.timeout``.
        """
        source = self._prepare_source(query)
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(source)
            # The timeout is enforced on the child process itself so that a
            # runaway snippet is killed before the temporary directory is
            # removed, instead of being left running in the background.
            try:
                return self._run_code(temp_file_path, timeout=self.timeout)
            except subprocess.TimeoutExpired:
                return False, f"Timed out after {self.timeout} seconds."
def execute_completion(executor, completion, return_status, last_code_block):
    """Run the ```python fenced code found in ``completion``.

    Args:
        executor: Callable that takes source text and returns a
            ``(success, output)`` pair.
        completion: Text that may contain ```python ... ``` fenced blocks.
        return_status: When true, return ``(output, success)``; otherwise
            return only the output.
        last_code_block: When true, only the final fenced block is run.

    Returns:
        The stripped string form of the last block's output, paired with its
        success flag when ``return_status`` is true. If ``completion`` has no
        fenced block, ``completion`` itself is returned (paired with ``False``
        when ``return_status`` is true). When ``return_status`` is false, the
        output of a failed execution is replaced by ``""``.
    """
    executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
    if not executions:
        # Parenthesised so the conditional picks between a tuple and a bare
        # string, instead of only choosing the tuple's second element.
        return (completion, False) if return_status else completion
    if last_code_block:
        executions = executions[-1:]

    outputs = []
    successes = []
    for code in executions:
        # Refuse code that mentions a disallowed library and skip to the next
        # block without handing it to the executor.
        banned = next((lib for lib in ("subprocess", "venv") if lib in code), None)
        if banned is not None:
            outputs.append(f"{banned} is not allowed")
            successes.append(False)
            continue

        success = False
        try:
            success, output = executor(code)
        except TimeoutError as e:
            print("Code timed out")
            output = e
        if not success and not return_status:
            output = ""
        outputs.append(output)
        successes.append(success)

    output = str(outputs[-1]).strip()
    success = successes[-1]
    return (output, success) if return_status else output
def postprocess_completion(text, return_status, last_code_block):
    """Execute the code blocks in ``text`` using a fresh ``PythonREPL``.

    Args:
        text: Completion text handed to ``execute_completion``.
        return_status: Forwarded to ``execute_completion``.
        last_code_block: Forwarded to ``execute_completion``.

    Returns:
        Whatever ``execute_completion`` returns for these arguments.
    """
    # The executor is created only for this call and is released as soon as
    # the call returns.
    return execute_completion(
        PythonREPL(),
        text,
        return_status=return_status,
        last_code_block=last_code_block,
    )
def get_majority_vote(answers):
    """Return the most frequent element of ``answers``.

    Args:
        answers: Sized collection of hashable values.

    Returns:
        The value with the highest count (the earliest-seen one on a tie),
        or ``0`` when ``answers`` is empty.
    """
    if len(answers) == 0:
        return 0
    tally = Counter(answers)
    # Only the top entry is needed; take its value and ignore its count.
    winner = tally.most_common(1)[0][0]
    return winner
|