Pra-tham commited on
Commit
6e340e9
1 Parent(s): b7139fd

added message

Browse files
Files changed (1) hide show
  1. codeexecutor.py +45 -46
codeexecutor.py CHANGED
@@ -1,8 +1,8 @@
1
  import os
2
  import re
3
- import signal
4
  import subprocess
5
  import tempfile
 
6
  from collections import Counter
7
  from contextlib import contextmanager
8
  from dataclasses import dataclass
@@ -12,17 +12,37 @@ class PythonREPL:
12
  def __init__(self, timeout=5):
13
  self.timeout = timeout
14
 
15
- @contextmanager
16
- def time_limit(self, seconds):
17
- def signal_handler(*_):
18
- raise TimeoutError(f"Timed out after {seconds} seconds.")
19
-
20
- signal.signal(signal.SIGALRM, signal_handler)
21
- signal.alarm(seconds)
22
- try:
23
- yield
24
- finally:
25
- signal.alarm(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def __call__(self, query):
28
  query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
@@ -32,43 +52,21 @@ class PythonREPL:
32
  query[-1] = query[-1].split("#")[0]
33
  query[-1] = "print(" + query[-1] + ")"
34
  query = "\n".join(query)
 
35
  with tempfile.TemporaryDirectory() as temp_dir:
36
  temp_file_path = os.path.join(temp_dir, "tmp.py")
37
  with open(temp_file_path, "w", encoding="utf-8") as f:
38
  f.write(query)
39
- with self.time_limit(self.timeout):
40
- result = subprocess.run(
41
- ["python3", temp_file_path],
42
- capture_output=True,
43
- check=False,
44
- text=True,
45
- timeout=self.timeout,
46
- )
47
- if result.returncode == 0:
48
- output = result.stdout
49
- return True, output.strip()
50
- error_msg = result.stderr.strip()
51
- msgs = error_msg.split("\n")
52
- new_msgs = []
53
- want_next = False
54
- for m in msgs:
55
- if "Traceback" in m:
56
- new_msgs.append(m)
57
- elif m == msgs[-1]:
58
- new_msgs.append(m)
59
- elif temp_file_path in m:
60
- st = m.index('"/') + 1 if '"/' in m else 0
61
- ed = m.index(temp_file_path) + 1 if temp_file_path in m else None
62
- clr = m[st:ed] if not ed else m[st:]
63
- m = m.replace(clr, "")
64
- new_msgs.append(m)
65
- want_next = True
66
- elif want_next:
67
- new_msgs.append(m)
68
- want_next = False
69
- error_msg = "\n".join(new_msgs)
70
- return False, error_msg.strip()
71
-
72
 
73
  def execute_completion(executor, completion, return_status, last_code_block):
74
  executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
@@ -108,9 +106,10 @@ def postprocess_completion(text, return_status, last_code_block):
108
  del executor
109
  return result
110
 
 
111
  def get_majority_vote(answers):
112
  if not len(answers):
113
  return 0
114
  c = Counter(answers)
115
  value, _ = c.most_common()[0]
116
- return value
 
1
  import os
2
  import re
 
3
  import subprocess
4
  import tempfile
5
+ import multiprocessing
6
  from collections import Counter
7
  from contextlib import contextmanager
8
  from dataclasses import dataclass
 
12
  def __init__(self, timeout=5):
13
  self.timeout = timeout
14
 
15
+ @staticmethod
16
+ def _run_code(temp_file_path):
17
+ result = subprocess.run(
18
+ ["python3", temp_file_path],
19
+ capture_output=True,
20
+ check=False,
21
+ text=True
22
+ )
23
+ if result.returncode == 0:
24
+ return True, result.stdout.strip()
25
+ else:
26
+ error_msg = result.stderr.strip()
27
+ msgs = error_msg.split("\n")
28
+ new_msgs = []
29
+ want_next = False
30
+ for m in msgs:
31
+ if "Traceback" in m:
32
+ new_msgs.append(m)
33
+ elif m == msgs[-1]:
34
+ new_msgs.append(m)
35
+ elif temp_file_path in m:
36
+ st = m.index('"/') + 1 if '"/' in m else 0
37
+ ed = m.index(temp_file_path) + 1 if temp_file_path in m else None
38
+ clr = m[st:ed] if not ed else m[st:]
39
+ m = m.replace(clr, "")
40
+ new_msgs.append(m)
41
+ want_next = True
42
+ elif want_next:
43
+ new_msgs.append(m)
44
+ want_next = False
45
+ return False, "\n".join(new_msgs).strip()
46
 
47
  def __call__(self, query):
48
  query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
 
52
  query[-1] = query[-1].split("#")[0]
53
  query[-1] = "print(" + query[-1] + ")"
54
  query = "\n".join(query)
55
+
56
  with tempfile.TemporaryDirectory() as temp_dir:
57
  temp_file_path = os.path.join(temp_dir, "tmp.py")
58
  with open(temp_file_path, "w", encoding="utf-8") as f:
59
  f.write(query)
60
+
61
+ with multiprocessing.Pool(1) as pool:
62
+ result = pool.apply_async(self._run_code, (temp_file_path,))
63
+ try:
64
+ success, output = result.get(self.timeout)
65
+ except multiprocessing.TimeoutError:
66
+ pool.terminate()
67
+ return False, f"Timed out after {self.timeout} seconds."
68
+ return success, output
69
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  def execute_completion(executor, completion, return_status, last_code_block):
72
  executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
 
106
  del executor
107
  return result
108
 
109
+
110
  def get_majority_vote(answers):
111
  if not len(answers):
112
  return 0
113
  c = Counter(answers)
114
  value, _ = c.most_common()[0]
115
+ return value