Spaces:
Runtime error
Runtime error
from tableQA_single_table import * | |
import json | |
import os | |
import sys | |
# Exceptions ast.literal_eval documents for malformed input.
_LITERAL_EVAL_ERRORS = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)


def _dedupe_header(header):
    """Return *header* with repeated names suffixed ``_<n>`` to make them unique.

    The counter ``n`` is shared across all names, so ``["a", "a", "b", "b"]``
    becomes ``["a", "a_1", "b", "b_2"]``. This exact scheme is kept unchanged
    because ``question_column`` and the condition columns are looked up in the
    renamed header.
    """
    renamed = []
    seen = set()
    counter = 0
    for name in header:
        if name in seen:
            counter += 1
            renamed.append("{}_{}".format(name, counter))
        else:
            renamed.append(name)
            seen.add(name)
    return renamed


def _format_sql_value(val):
    """Render a condition value as an SQLite literal.

    If ``str(val)`` parses (via ``literal_eval``) as an int/float, it is emitted
    verbatim. If it parses as a Python string literal, its contents are used.
    Anything else is treated as plain text. Text is single-quoted with embedded
    single quotes doubled, so values taken from the user's question cannot
    break out of the SQL string literal.
    """
    text = str(val)
    try:
        parsed = literal_eval(text)
    except _LITERAL_EVAL_ERRORS:
        parsed = text
    if isinstance(parsed, (int, float)):
        return text
    if not isinstance(parsed, str):
        # e.g. None / list / tuple literals: previously inserted raw (invalid SQL).
        parsed = text
    return "'{}'".format(parsed.replace("'", "''"))


def run_sql_query(s, df):
    """Execute the predicted single-table query *s* against *df* in SQLite.

    *df* is loaded into an in-memory SQLite table ``Mem_Table`` whose columns
    are renamed ``col_0 .. col_{n-1}``. The query selects ``s.question_column``
    (wrapped in ``s.agg_pred`` when it is non-blank). It is filtered by
    ``s.total_conds_filtered``, joined with ``s.conn_pred`` (``and`` when blank).

    Args:
        s: Prediction row exposing ``question_column``, ``total_conds_filtered``
            (iterable of ``(column, op, value)`` triples; ``"=="`` is rendered
            as ``=``), ``agg_pred`` and ``conn_pred`` (strings).
        df: The source table.

    Returns:
        dict with ``sql_query`` (the SELECT text), ``cnt_num`` (number of
        matching rows, as a Python int) and ``conclusion`` (flat list of the
        selected values). If ``question_column`` is None the query is empty and
        ``cnt_num`` is 0. If no rows match, ``conclusion`` is empty.

    Raises:
        AssertionError: if *df* is not a DataFrame, a referenced column is not
            in the (de-duplicated) header, or de-duplication produced a clash.
    """
    assert isinstance(df, pd.DataFrame)
    question_column = s.question_column
    if question_column is None:
        return {"sql_query": "", "cnt_num": 0, "conclusion": []}

    # Materialize once: the conditions are both validated and rendered below.
    conds = list(s.total_conds_filtered)
    agg_pred = s.agg_pred
    conn_pred = s.conn_pred

    header = _dedupe_header(df.columns.tolist())
    assert question_column in header
    assert all(cond[0] in header for cond in conds)
    # The suffix scheme can still collide with a real column (e.g. "a", "a", "a_1").
    assert len(header) == len(set(header))
    column_index = {name: i for i, name in enumerate(header)}

    sql_format = "SELECT {} FROM {} {}"
    sql_question_column = "col_{}".format(column_index[question_column])
    # A blank aggregator still wraps the column in parentheses, e.g. "(col_3)".
    agg = agg_pred if agg_pred.strip() else ""
    select_expr = "{}({})".format(agg, sql_question_column)

    where_parts = [
        "{} {} {}".format(
            "col_{}".format(column_index[cond[0]]),
            "=" if cond[1] == "==" else cond[1],
            _format_sql_value(cond[2]),
        )
        for cond in conds
    ]
    joiner = " {} ".format(conn_pred if conn_pred.strip() else "and")
    where_clause = ("WHERE " + joiner.join(where_parts)) if where_parts else ""

    sql_query = sql_format.format(select_expr, "Mem_Table", where_clause)
    cnt_sql_query = sql_format.format("COUNT(*)", "Mem_Table", where_clause).strip()

    table = df.copy()
    table.columns = ["col_{}".format(i) for i in range(len(header))]
    conn = sqlite3.connect(":memory:")
    try:
        table.to_sql("Mem_Table", conn, if_exists="replace", index=False)
        # int(): read_sql yields a numpy integer, which is not JSON-serializable.
        cnt_num = int(pd.read_sql(cnt_sql_query, conn).values.reshape((-1,))[0])
        if cnt_num == 0:
            return {"sql_query": sql_query, "cnt_num": 0, "conclusion": []}
        conclusion = pd.read_sql(sql_query, conn).values.reshape((-1,)).tolist()
    finally:
        conn.close()
    return {"sql_query": sql_query, "cnt_num": cnt_num, "conclusion": conclusion}
def single_table_pred(question, pd_df):
    """Answer a natural-language *question* about the single table *pd_df*.

    The question is wrapped in a one-row DataFrame (column ``question``) and
    passed with the table to ``full_before_cat_decomp``. That function is not
    defined in this file; it presumably comes from the ``tableQA_single_table``
    star import. The resulting prediction row is executed by ``run_sql_query``.

    Args:
        question: Natural-language question text.
        pd_df: Table the question is asked about.

    Returns:
        The dict produced by ``run_sql_query`` (``sql_query``, ``cnt_num``,
        ``conclusion``).

    Raises:
        AssertionError: on wrong argument types, or if the decomposition does
            not yield exactly one prediction row.
    """
    assert isinstance(question, str)
    assert isinstance(pd_df, pd.DataFrame)
    qs_df = pd.DataFrame([[question]], columns=["question"])
    tableqa_df = full_before_cat_decomp(pd_df, qs_df, only_req_columns=False)
    # One question in, so exactly one prediction row is expected out.
    assert tableqa_df.shape[0] == 1
    return run_sql_query(tableqa_df.iloc[0], pd_df)
if __name__ == "__main__":
    # Smoke run: answer one sample question against data/df1.csv and print the
    # prediction next to the hand-written expected answer.
    # NOTE(review): main_path is not defined in this file; it presumably comes
    # from the `tableQA_single_table` star import -- confirm.
    csv_path = os.path.join(main_path, "data/df1.csv")
    szse_summary_df = pd.read_csv(csv_path)
    data = {
        # "What is the average market cap where EPS > 0 and weekly change > 5?"
        "tqa_question": "EPS大于0且周涨跌大于5的平均市值是多少?",
        "tqa_header": szse_summary_df.columns.tolist(),
        "tqa_rows": szse_summary_df.values.tolist(),
        "tqa_data_path": csv_path,
        "tqa_answer": {
            "sql_query": "SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5",
            "cnt_num": 2,
            "conclusion": [57.645],
        },
    }
    pd_df = pd.DataFrame(data["tqa_rows"], columns=data["tqa_header"])
    question = data["tqa_question"]
    prediction = single_table_pred(question, pd_df)
    expected = data["tqa_answer"]
    print("question :", question)
    print("predicted:", prediction)
    print("expected :", expected)
    print(
        "sql_query/cnt_num match:",
        prediction["sql_query"] == expected["sql_query"]
        and prediction["cnt_num"] == expected["cnt_num"],
    )