rootacess's picture
M-1: PAL, TA, MP
9b4edaf
raw
history blame
1.47 kB
import ast
import re

from prompt import TA_prompt
from utils import generate_response, run_code
def _to_number(text):
    """Convert a numeric string extracted from the question to int, else float."""
    try:
        return int(text)
    except ValueError:
        # The extraction regex also matches decimals such as "2.5" or ".5",
        # which int() rejects; the original crashed here.
        return float(text)


def _first_function_signature(code):
    """Return (name, keyword-passable parameter names) of the first top-level def.

    Returns None when *code* does not parse or defines no top-level function.
    Positional-only parameters and ``*args``/``**kwargs`` are skipped because
    they cannot be bound by keyword.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return None
    for node in tree.body:
        if isinstance(node, ast.FunctionDef):
            params = [arg.arg for arg in node.args.args + node.args.kwonlyargs]
            return node.name, params
    return None


def post_process_code(code, question):
    """Append a call that prints the result of the first function defined in *code*.

    Numbers are pulled from *question* in order of appearance and bound, one per
    parameter and in declaration order, as keyword arguments. If there are fewer
    numbers than parameters, only the leading parameters are bound (the call may
    then fail when the code is executed).

    Args:
        code: Python source, expected to define at least one top-level function.
        question: Text from which numeric argument values are extracted.

    Returns:
        *code* followed by ``\\nprint(<name>(<param>=<value>,...))``, or *code*
        unchanged if it does not parse or defines no top-level function.
    """
    signature = _first_function_signature(code)
    if signature is None:
        return code
    func_name, parameters = signature
    # NOTE: the leading sign only binds to the decimal alternative of this
    # pattern, so "-3" yields "3" while "-3.5" yields "-3.5" (kept as-is).
    numbers = re.findall(r"[-+]?\d*\.\d+|\d+", question)[: len(parameters)]
    arg_string = ",".join(
        f"{param}={_to_number(number)}" for param, number in zip(parameters, numbers)
    )
    return f"{code}\nprint({func_name}({arg_string}))"
def solve_ta(question, token):
    """Answer *question* via the TA prompt: generate code, post-process it, run it.

    Args:
        question: Natural-language question; prefixed with "Human: " before being
            appended to ``TA_prompt``.
        token: Passed through to ``generate_response`` (presumably an API
            token — confirm against utils).

    Returns:
        A ``(pred, code)`` tuple. ``pred`` is the result of ``run_code``, or None
        when no fenced code block was found, post-processing failed, the code
        calls ``input(``, or execution raised. ``code`` is the post-processed
        code, or the raw model response when no fenced block was found.
    """
    question = "Human: " + question.strip()
    query = (TA_prompt + question).strip() + "\n"
    # NOTE(review): 0.9 is passed positionally; presumably a sampling
    # temperature — confirm against generate_response's signature.
    response = generate_response(query, 0.9, token)

    fence = "```python" if "```python" in response else "```"
    segments = response.split(fence)
    if len(segments) < 2:
        # No fenced block at all: the original raised IndexError on [-2].
        return None, response
    # NOTE(review): [-2] selects the second-to-last segment, not the last;
    # presumably tied to generate_response's output format — confirm.
    code = segments[-2].split("```")[0].strip()
    print(code)

    # code preprocessing
    try:
        code = post_process_code(code, question)
    except ValueError:
        # Best-effort, like the execution step below: a malformed snippet
        # yields no prediction instead of crashing the caller.
        return None, code
    print(code)

    # code running; refuse snippets that would block waiting for stdin
    if "input(" in code:
        return None, code
    try:
        pred = run_code(code)
    except Exception:
        return None, code
    return pred, code
if __name__ == "__main__":
    import os

    # solve_ta requires a token; the original call omitted it and always raised
    # TypeError. NOTE(review): the environment variable name is an assumption —
    # confirm what generate_response expects.
    token = os.environ.get("HF_TOKEN")
    q = "What is the 7th Fibonacci number? Write the Python code"
    print(solve_ta(q, token))