|
|
|
import ast |
|
import json |
|
import random |
|
import re |
|
import signal |
|
import threading |
|
from contextlib import contextmanager |
|
from typing import Any, Tuple |
|
|
|
import pandas as pd |
|
from langchain_experimental.tools.python.tool import PythonAstREPLTool |
|
|
|
from evaluate_code_correction.pytool import ( |
|
extract_last_df, format_result) |
|
|
|
|
|
def read_jsonl(file_path):
    """Load a JSON Lines file into a list of decoded records.

    Blank or whitespace-only lines are skipped so that a stray empty line
    does not abort the whole load with a decode error.

    :param file_path: path of the UTF-8 encoded JSON Lines file
    :return: list holding one decoded JSON value per non-empty line
    :raises json.JSONDecodeError: if a non-empty line is not valid JSON
    """
    with open(file_path, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        # Empty strings are falsy, so blank lines never reach json.loads.
        return [json.loads(line) for line in stripped_lines if line]
|
|
|
def load_json(data_path):
    """Read a UTF-8 encoded JSON file and return its decoded content.

    :param data_path: path of the JSON file to read
    :return: the Python object decoded from the file
    """
    with open(data_path, mode="r", encoding="utf-8") as json_file:
        return json.load(json_file)
|
|
|
|
|
def save_json(data_path, data_list):
    """Write ``data_list`` as JSON to a UTF-8 encoded file.

    Non-ASCII characters are stored as-is instead of being ASCII-escaped.
    An existing file at ``data_path`` is overwritten.

    :param data_path: path of the JSON file to write
    :param data_list: JSON-serializable object to store
    """
    with open(data_path, mode="w", encoding="utf-8") as json_file:
        json.dump(data_list, json_file, ensure_ascii=False)
|
|
|
|
|
def get_dfs_info(table_paths):
    """Assemble a text preview of the first five rows of every CSV file.

    Each file is read with pandas and rendered with ``to_string`` (index
    hidden) inside a ``/* ... */`` block. A single file is labelled ``df``;
    with any other number of files they are labelled ``df1``, ``df2``, ...
    in input order.

    :param table_paths: sequence of CSV file paths
    :return: the preview blocks joined by newlines ("" for no paths)
    """
    only_one = len(table_paths) == 1
    previews = []
    for position, csv_path in enumerate(table_paths, start=1):
        frame_name = "df" if only_one else f"df{position}"
        head_text = str(
            pd.read_csv(csv_path, encoding="utf-8").head(5).to_string(index=False)
        )
        previews.append(f"""/*\n"{frame_name}.head()" as follows:\n{head_text}\n*/""")
    return "\n".join(previews)
|
|
|
|
|
def sample_from_two_lists(list1, list2):
    """Pick one random element from either of two lists.

    One of the two lists is selected with equal probability, then a single
    element is drawn uniformly from it.

    :param list1: first candidate list
    :param list2: second candidate list
    :return: the sampled element, or None when either list is empty
    """
    if not (list1 and list2):
        return None

    # One random.random() draw decides the list, one random.choice the item.
    chosen_list = list1 if random.random() < 0.5 else list2
    return random.choice(chosen_list)
|
|
|
|
|
def recraft_query(query, locals):
    """Append a result-printing epilogue to a code snippet.

    The epilogue is a newline, the imported ``format_result`` text, and a
    ``print(format_result(<name>))`` call, where ``<name>`` is whatever
    ``extract_last_df(query, locals)`` returns.

    NOTE(review): ``format_result`` is concatenated as a string, so it
    presumably holds the source of a helper of the same name; confirm
    against ``evaluate_code_correction.pytool``.

    :param query: code snippet to extend
    :param locals: namespace forwarded to ``extract_last_df``
    :return: ``query`` followed by the epilogue
    """
    target_df = extract_last_df(query, locals)
    print_call = "print(format_result({}))".format(target_df)
    return "".join((query, "\n", format_result, print_call))
|
|
|
|
|
def extract_code_without_comments(code):
    """Return Python code with docstring blocks and ``#`` comments removed.

    Triple-quoted blocks are deleted first. Comments are then located with
    the standard tokenizer, so a ``#`` inside a string literal is preserved
    and a comment directly after a closing quote is still removed. If the
    text cannot be tokenized (for example incomplete code), the function
    falls back to a per-line regular expression. Trailing whitespace is
    stripped from every line and the line count is unchanged.

    :param code: str, input Python code
    :return: str, the code without comments
    """
    import io
    import tokenize

    code = re.sub(r'"""[\s\S]*?"""', "", code)
    code = re.sub(r"'''[\s\S]*?'''", "", code)
    lines = code.split("\n")

    try:
        # Map 1-based line number -> column where that line's comment starts.
        comment_columns = {
            token.start[0]: token.start[1]
            for token in tokenize.generate_tokens(io.StringIO(code).readline)
            if token.type == tokenize.COMMENT
        }
    except (tokenize.TokenError, SyntaxError):
        # Not tokenizable as Python: keep the regex heuristic, which skips a
        # '#' only when it directly follows a quote character.
        return "\n".join(
            re.sub(r'(?<!["\'"])#.*$', "", line).rstrip() for line in lines
        )

    cleaned_lines = []
    for line_number, line in enumerate(lines, start=1):
        column = comment_columns.get(line_number)
        if column is not None:
            line = line[:column]
        cleaned_lines.append(line.rstrip())
    return "\n".join(cleaned_lines)
|
|
|
|
|
def is_python_code(line: str) -> bool: |
|
"""Tool function to check if a line of text is Python code""" |
|
try: |
|
tree = ast.parse(line) |
|
|
|
for node in ast.walk(tree): |
|
if isinstance( |
|
node, |
|
( |
|
ast.Expr, |
|
ast.Assign, |
|
ast.FunctionDef, |
|
ast.ClassDef, |
|
ast.Import, |
|
ast.ImportFrom, |
|
ast.For, |
|
ast.While, |
|
ast.If, |
|
ast.With, |
|
ast.Raise, |
|
ast.Try, |
|
), |
|
): |
|
return True |
|
return False |
|
except SyntaxError: |
|
return False |
|
|
|
|
|
def extract_text_before_code(text: str) -> str:
    """Tool function that returns the text preceding the first code line.

    Lines are checked in order with ``is_python_code``; everything before
    the first line recognised as code is kept.

    :param text: multi-line text to scan
    :return: the leading non-code lines joined by newlines (the whole text
        when no line is recognised as code)
    """
    lines = text.split("\n")
    # Index of the first code line, or len(lines) when there is none.
    first_code_index = next(
        (index for index, line in enumerate(lines) if is_python_code(line)),
        len(lines),
    )
    return "\n".join(lines[:first_code_index])
|
|
|
|
|
def extract_python_code(text: str) -> str:
    """Tool function that keeps only the lines recognised as Python code.

    Each line is classified independently with ``is_python_code``; the
    matching lines are returned in their original order.

    :param text: multi-line text to filter
    :return: the code lines joined by newlines
    """
    return "\n".join(
        candidate for candidate in text.split("\n") if is_python_code(candidate)
    )
|
|
|
|
|
def fix_indents(text: str) -> str:
    """Replace every tab character in ``text`` with a space.

    NOTE(review): the replacement is a single space as written here; confirm
    that a wider indent (e.g. four spaces) was not intended.

    :param text: text that may contain tab characters
    :return: the text with each tab replaced
    """
    return " ".join(text.split("\t"))
|
|
|
|
|
def filter_cot(completion: str):
    """
    Filter the COT steps before python code

    The text between ``Thought:`` and the following ``Python Code:`` marker
    is returned when both are present; otherwise the lines preceding the
    first code line are used. This is best-effort: any ordinary error yields
    an empty string, while ``KeyboardInterrupt`` / ``SystemExit`` propagate.

    :param completion: llm output contents
    :return filtered COT content ("" on failure)
    """
    try:
        # DOTALL lets the thought span several lines; the lookahead keeps
        # the "Python Code:" marker out of the captured text.
        pattern = r"Thought:\s*(.*?)\s*(?=Python Code:)"
        match = re.search(pattern, completion, re.DOTALL)
        if match:
            return match.group(1)
        return extract_text_before_code(completion)
    except Exception:
        return ""
|
|
|
|
|
def filter_code(completion: str) -> Tuple[str, str]:
    """
    Filter python code from the llm output completion

    The first fenced ``python`` block is used when present; otherwise the
    lines recognised as code by ``extract_python_code`` are used. This is
    best-effort: any ordinary error yields two empty strings, while
    ``KeyboardInterrupt`` / ``SystemExit`` propagate.

    :param completion: llm output contents
    :return filtered python code and the same code with comments removed
    """
    try:
        regex = r"```python\s(.*?)```"
        fenced_block = re.search(regex, completion, re.DOTALL)
        if fenced_block:
            # Trim surrounding spaces, then wrapping double quotes, then
            # any spaces the quotes were hiding.
            code = fenced_block.group(1).strip(" ").strip('"').strip(" ")
        else:
            code = extract_python_code(completion).strip(" ")
        pure_code = extract_code_without_comments(code)
        return code, pure_code
    except Exception:
        return "", ""
|
|
|
|
|
def get_tool(df: Any, df_names=None):
    """
    Define python code execute tool

    Builds a ``PythonAstREPLTool`` whose locals and globals are one shared
    namespace holding the given dataframes.

    :param df: List[pd.DataFrame] or pd.DataFrame
    :param df_names: optional variable names, matched to the items of ``df``
        by position. Without it a single DataFrame is bound to ``df`` and an
        iterable of frames to ``df1``, ``df2``, ...
    :return Runnable
    :raises IndexError: if ``df_names`` has fewer entries than ``df``
    """
    tool = PythonAstREPLTool()
    # Identity check: "== None" would compare elementwise on array-likes.
    if df_names is None:
        if isinstance(df, pd.DataFrame):
            namespace = {"df": df}
        else:
            namespace = {
                f"df{i + 1}": dataframe for i, dataframe in enumerate(df)
            }
    else:
        namespace = {df_names[i]: dataframe for i, dataframe in enumerate(df)}
    tool.locals = namespace
    tool.globals = tool.locals
    return tool
|
|
|
|
|
def get_table_infos(table_paths):
    """Assemble a markdown preview of the first three rows of every CSV file.

    Each file is read with pandas and rendered with ``to_markdown`` (index
    hidden) inside a ``/* ... */`` block. A single file is labelled ``df``;
    with any other number of files they are labelled ``df1``, ``df2``, ...
    in input order.

    :param table_paths: sequence of CSV file paths
    :return: the preview blocks joined by newlines ("" for no paths)
    """
    only_one = len(table_paths) == 1
    previews = []
    for position, csv_path in enumerate(table_paths, start=1):
        frame_name = "df" if only_one else f"df{position}"
        head_markdown = str(
            pd.read_csv(csv_path, encoding="utf-8").head(3).to_markdown(index=False)
        )
        previews.append(
            f"""/*\n"{frame_name}.head()" as follows:\n{head_markdown}\n*/"""
        )
    return "\n".join(previews)
|
|
|
|
|
|
|
class TimeoutException(Exception):
    """Raised when code execution exceeds its allotted running time."""
|
|
|
|
|
|
|
@contextmanager
def timeout(time):
    """Context manager that aborts the enclosed block after ``time`` seconds.

    A SIGALRM handler raising ``TimeoutException`` is installed and an alarm
    is armed for the duration of the block. On exit the alarm is cancelled
    and the previously installed SIGALRM handler is put back, so the
    timeout handler does not stay active after the block.

    Relies on ``signal.SIGALRM`` / ``signal.alarm``, so it only works where
    those are available and where signal handlers may be set (the main
    thread).

    :param time: number of seconds passed to ``signal.alarm``
    :raises TimeoutException: inside the block when the alarm fires
    """

    def raise_timeout(signum, frame):
        raise TimeoutException(f"Timeout error, running time exceed {time}")

    previous_handler = signal.signal(signal.SIGALRM, raise_timeout)
    signal.alarm(time)
    try:
        yield
    finally:
        # Cancel the pending alarm first so it cannot fire during cleanup.
        signal.alarm(0)
        # A None handler was not installed from Python and cannot be
        # re-installed through signal.signal.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
|
|
|
|
|
def run_code(code, result, tool):
    """Run ``code`` through ``tool`` and record the outcome in ``result``.

    Exactly one item is appended to ``result``: the value returned by
    ``tool.run(code)``, or the exception instance if that call raised an
    ``Exception``. Nothing is returned, so the outcome can be collected
    from another thread through the shared list.

    :param code: code passed to ``tool.run``
    :param result: list that receives the return value or the exception
    :param tool: object exposing a ``run(code)`` method
    """
    try:
        outcome = tool.run(code)
    except Exception as error:
        outcome = error
    result.append(outcome)
|
|
|
|
|
def execute_with_timeout(code, timeout_seconds, tool):
    """Run ``code`` with ``tool`` in a worker thread, bounded by a timeout.

    The worker calls ``run_code`` and the caller waits at most
    ``timeout_seconds`` for it. Python offers no supported way to kill a
    thread, so on timeout the worker is left running in the background. It
    is created as a daemon thread so that it cannot keep the interpreter
    from exiting.

    :param code: code passed on to ``run_code``
    :param timeout_seconds: maximum number of seconds to wait for the worker
    :param tool: execution tool passed on to ``run_code``
    :return: the value recorded by ``run_code``
    :raises TimeoutException: if the worker is still running after the wait
    :raises Exception: re-raises the exception recorded by ``run_code``
    """
    result = []
    worker = threading.Thread(
        target=run_code, args=(code, result, tool), daemon=True
    )
    worker.start()
    worker.join(timeout_seconds)

    if worker.is_alive():
        # Do not call the private Thread._stop(): it does not stop the
        # thread and fails on a live one, masking the timeout error.
        raise TimeoutException(
            f"Timeout error, running time exceed {timeout_seconds} seconds"
        )

    outcome = result[0]
    if isinstance(outcome, Exception):
        raise outcome
    return outcome
|
|