File size: 5,185 Bytes
753cb33
 
 
 
6e340e9
753cb33
 
 
e8c0a63
 
 
753cb33
 
 
e8c0a63
753cb33
 
6e340e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
753cb33
 
 
 
 
 
 
 
 
6e340e9
753cb33
 
 
 
6e340e9
 
 
 
 
 
 
 
 
 
753cb33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35deaeb
753cb33
 
 
 
 
 
 
 
6e340e9
753cb33
 
 
 
 
6e340e9
35deaeb
 
e8c0a63
35deaeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42bb7ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8c0a63
 
42bb7ad
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import os
import re
import subprocess
import tempfile
import multiprocessing
from collections import Counter
from contextlib import contextmanager
from dataclasses import dataclass
import sympy as sp
import numpy as np
import matplotlib.pyplot as plt


class PythonREPL:
    """Run short Python snippets in a fresh ``python3`` subprocess.

    Each call writes the snippet to a throw-away ``tmp.py`` inside a temporary
    directory, executes it with a wall-clock limit of ``timeout`` seconds and
    returns a ``(success, output)`` pair.
    """

    def __init__(self, timeout=15):
        # Wall-clock limit, in seconds, applied to every snippet run.
        self.timeout = timeout

    @staticmethod
    def _clean_traceback(stderr, temp_file_path):
        """Condense the stderr of a failed run into a short traceback.

        Keeps the ``Traceback`` header, each frame line that points into the
        snippet file (with the random temporary directory stripped, so it
        reads ``File "tmp.py", line N, ...``), the source line printed right
        after such a frame, and the final line (the exception message).
        """
        lines = stderr.strip().split("\n")
        last_index = len(lines) - 1
        # Only the directory part is random noise; keep file name + line number.
        temp_dir_prefix = os.path.dirname(temp_file_path) + os.sep
        kept = []
        keep_next = False
        for index, line in enumerate(lines):
            if "Traceback" in line or index == last_index:
                kept.append(line)
            elif temp_file_path in line:
                kept.append(line.replace(temp_dir_prefix, ""))
                # The line after a frame header is the offending source line.
                keep_next = True
            elif keep_next:
                kept.append(line)
                keep_next = False
        return "\n".join(kept).strip()

    @staticmethod
    def _run_code(temp_file_path, timeout=None):
        """Execute *temp_file_path* with ``python3``; return ``(success, output)``.

        On a zero exit status ``output`` is the stripped stdout; otherwise it
        is a condensed traceback.  If *timeout* seconds elapse,
        ``subprocess.run`` kills the child and a "Timed out" message is
        returned.  ``timeout=None`` means no limit.
        """
        try:
            result = subprocess.run(
                ["python3", temp_file_path],
                capture_output=True,
                check=False,
                text=True,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            return False, f"Timed out after {timeout} seconds."
        if result.returncode == 0:
            return True, result.stdout.strip()
        return False, PythonREPL._clean_traceback(result.stderr, temp_file_path)

    def __call__(self, query):
        """Run *query* and return ``(success, output)``.

        ``math``, ``numpy as np`` and ``sympy as sp`` are imported first.  If
        the last line does not already contain ``print(``, it is wrapped in
        ``print(...)`` (after dropping anything from the first ``#``) so its
        value is echoed, REPL-style.
        """
        source = "import math\nimport numpy as np\nimport sympy as sp\n" + query
        lines = source.strip().split("\n")
        if "print(" not in lines[-1]:
            if "#" in lines[-1]:
                lines[-1] = lines[-1].split("#")[0]
            lines[-1] = "print(" + lines[-1] + ")"
        source = "\n".join(lines)

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(source)
            # subprocess.run enforces the limit itself and kills the child on
            # expiry, so a runaway snippet cannot outlive this call.
            return self._run_code(temp_file_path, timeout=self.timeout)


def execute_completion(executor, completion, return_status, last_code_block):
    """Run the fenced ```python blocks found in *completion*.

    Args:
        executor: Callable taking a code string and returning
            ``(success, output)`` (e.g. a ``PythonREPL`` instance).  It may
            raise ``TimeoutError``; that is reported as a failed run.
        completion: Text possibly containing ```python ... ``` blocks.
        return_status: When true, the second element of the result is the
            success flag of the last block.  When false it is always
            ``False`` and the output of a failed run is replaced by ``""``.
        last_code_block: When true, only the last code block is executed;
            otherwise every block runs in order.

    Returns:
        ``(output, status)`` where ``output`` is the stripped string form of
        the last executed block's output.  With no code block at all, the
        completion itself is returned as ``(completion, False)``.
    """
    executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
    if not executions:
        # Nothing to run: echo the completion with a failed status, matching
        # the (output, bool) shape of the other return paths.
        return completion, False
    if last_code_block:
        executions = [executions[-1]]
    outputs = []
    successes = []
    for code in executions:
        # NOTE(review): a substring block-list is not a sandbox (e.g.
        # os.system or __import__ bypass it); it only rejects obvious cases.
        banned = next((lib for lib in ("subprocess", "venv") if lib in code), None)
        if banned is not None:
            # Skip execution entirely for this block.
            outputs.append(f"{banned} is not allowed")
            successes.append(False)
            continue
        success = False
        try:
            success, output = executor(code)
        except TimeoutError as e:
            print("Code timed out")
            output = e
        if not success and not return_status:
            output = ""
        outputs.append(output)
        successes.append(success)
    output = str(outputs[-1]).strip()
    success = successes[-1]
    if return_status:
        return output, success
    return output, False


def postprocess_completion(text, return_status, last_code_block):
    """Execute the fenced Python code in *text* with a fresh PythonREPL.

    A new ``PythonREPL`` with its default timeout is created for every call
    and passed to ``execute_completion`` together with the two flags; the
    result of ``execute_completion`` is returned unchanged.
    """
    return execute_completion(
        PythonREPL(),
        text,
        return_status=return_status,
        last_code_block=last_code_block,
    )


def get_majority_vote(answers):
    """Return the most frequent element of *answers*.

    Ties are won by the value encountered first (``Counter.most_common``
    ordering).  An empty collection yields ``0``.
    """
    if len(answers) == 0:
        return 0
    [(winner, _count)] = Counter(answers).most_common(1)
    return winner


def type_check(expr_str):
    """Classify a SymPy-parsable value as "Real", "Complex", "Polynomial" or "Other".

    Checks are applied in order and the first match wins:

    * ``"Real"`` if SymPy can establish ``expr.is_real``;
    * ``"Complex"`` if SymPy can establish ``expr.is_complex``;
    * ``"Polynomial"`` if ``expr.is_polynomial()`` holds (in all of its free
      symbols);
    * ``"Other"`` otherwise, including parsed objects that have no
      ``is_polynomial`` method at all.

    With default symbol assumptions a symbolic expression such as
    ``"x**2 + 1"`` has unknown realness, so it falls through to
    ``"Polynomial"``.

    Raises:
        Whatever ``sp.sympify`` raises for input it cannot parse.
    """
    # NOTE(review): sympify evaluates string input with eval(); do not pass
    # untrusted strings here.
    expr = sp.sympify(expr_str)

    if expr.is_real:
        return "Real"

    if expr.is_complex:
        return "Complex"

    # Only SymPy ``Expr`` objects provide is_polynomial(); for anything else
    # (booleans, relations, tuples, ...) fall through to "Other" instead of
    # raising AttributeError.
    is_polynomial = getattr(expr, "is_polynomial", None)
    if callable(is_polynomial) and is_polynomial():
        return "Polynomial"

    return "Other"


def draw_polynomial_plot(expression, x_min=-10, x_max=10, num_points=400,
                         plot_filename="polynomial_plot.png"):
    """Plot an expression in the symbol ``x`` and save it as an image file.

    Args:
        expression: Anything ``sp.sympify`` accepts; it is evaluated as a
            function of the single symbol ``x``.
        x_min, x_max: Plotted interval (default ``[-10, 10]``).
        num_points: Number of evenly spaced sample points (default 400).
        plot_filename: Output path (default ``"polynomial_plot.png"``,
            relative to the current working directory).

    Returns:
        ``plot_filename`` on success, or ``None`` if any step failed (the
        error is printed, not raised).
    """
    fig = None
    try:
        x = sp.symbols('x')
        # NOTE(review): sympify eval()s string input; do not pass untrusted data.
        poly_expr = sp.sympify(expression)
        poly_lambda = sp.lambdify(x, poly_expr, 'numpy')

        x_vals = np.linspace(x_min, x_max, num_points)
        # A constant expression (e.g. "5") lambdifies to a function returning
        # a scalar; broadcast it so x and y have matching lengths for plot().
        y_vals = np.broadcast_to(poly_lambda(x_vals), x_vals.shape)

        fig, ax = plt.subplots()
        ax.plot(x_vals, y_vals)
        ax.set_title('Polynomial Plot')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.grid(True)

        fig.savefig(plot_filename)
        return plot_filename
    except Exception as e:
        # Best-effort helper: report and signal failure with None.
        print(f"Error in draw_polynomial_plot: {e}")
        return None
    finally:
        # Release the figure on both the success and the error path.
        if fig is not None:
            plt.close(fig)