"""Gradio demo: ask a natural-language question about an editable data table."""

# Imported explicitly so this module does not rely on the star-imports below
# happening to re-export them. The star-imports still run last, so any name
# they provide keeps taking precedence exactly as before.
import json
import os

import gradio as gr
import pandas as pd

# ``JSON``, ``main_path`` and ``single_table_pred`` are not defined in this
# file; they are presumably supplied by these star-imports (verify in
# ``gradio`` / ``run``).
from gradio import *  # noqa: F401,F403
from run import *  # noqa: F401,F403

# Sample table shown in the UI when the page first loads.
_sample_csv_path = os.path.join(main_path, "data/df1.csv")
szse_summary_df = pd.read_csv(_sample_csv_path)

tableqa_ = "数据表问答(编辑数据)"
default_val_dict = {
    tableqa_: {
        "tqa_question": "EPS大于0且周涨跌大于5的平均市值是多少?",
        "tqa_header": szse_summary_df.columns.tolist(),
        "tqa_rows": szse_summary_df.values.tolist(),
        "tqa_data_path": _sample_csv_path,
        "tqa_answer": {
            "sql_query": "SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5",
            "cnt_num": 2,
            "conclusion": [57.645],
        },
    }
}


def tableqa_layer(post_data):
    """Answer a question about a JSON-encoded table.

    Args:
        post_data: Mapping with three string values:
            ``"question"`` - the question text,
            ``"table_header"`` - JSON array of column names,
            ``"table_rows"`` - JSON array of rows (each row a JSON array).

    Returns:
        Whatever ``single_table_pred(question, df)`` returns for the
        DataFrame built from the decoded header and rows.

    Raises:
        KeyError: If one of the three keys is missing.
        TypeError: If a value is not a string, or the decoded header/rows
            are not JSON arrays.
        ValueError: If the first row's length differs from the header's
            (also raised by ``json.loads`` on malformed JSON).
    """
    question = post_data["question"]
    table_rows = post_data["table_rows"]
    table_header = post_data["table_header"]

    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    raw_fields = (
        ("question", question),
        ("table_rows", table_rows),
        ("table_header", table_header),
    )
    for field_name, value in raw_fields:
        if not isinstance(value, str):
            raise TypeError(
                f"{field_name} must be a str, got {type(value).__name__}"
            )

    table_rows = json.loads(table_rows)
    table_header = json.loads(table_header)
    decoded_fields = (
        ("table_rows", table_rows),
        ("table_header", table_header),
    )
    for field_name, value in decoded_fields:
        if not isinstance(value, list):
            raise TypeError(
                f"{field_name} must decode to a JSON array, "
                f"got {type(value).__name__}"
            )

    # Only the first row is compared with the header, as in the original
    # check; an empty header or empty row list skips the comparison.
    if table_rows and table_header and len(table_header) != len(table_rows[0]):
        raise ValueError(
            f"table_header has {len(table_header)} columns but the first row "
            f"has {len(table_rows[0])}"
        )

    df = pd.DataFrame(table_rows, columns=table_header)
    return single_table_pred(question, df)


def _is_blank_cell(cell):
    """Return True when *cell* holds no value: ``None``, ``""`` or NaN."""
    if cell is None or cell == "":
        return True
    # NaN is the only float that is not equal to itself.
    return isinstance(cell, float) and cell != cell


def run_tableqa(*args):
    """Gradio click handler: answer ``question`` over the edited table.

    Args:
        *args: Exactly two positional values, ``(question, data)``, where
            ``data`` exposes ``.columns.tolist()`` and ``.values.tolist()``
            (e.g. a pandas DataFrame).

    Returns:
        The response of ``tableqa_layer`` after a JSON round-trip, so it
        contains only JSON-native Python types.
    """
    question, data = args
    header = data.columns.tolist()

    # Drop fully blank rows (the UI grid is created with one spare row).
    # Blank means None / "" / NaN, so a genuine row of zeros is kept.
    rows = [
        row
        for row in data.values.tolist()
        if not all(_is_blank_cell(cell) for cell in row)
    ]

    resp = tableqa_layer(
        {
            "question": question,
            "table_header": json.dumps(header),
            "table_rows": json.dumps(rows),
        }
    )

    # Values exposing ``tolist`` (e.g. numpy arrays/scalars) are not JSON
    # serializable; convert them to plain Python objects first.
    for key in ("cnt_num", "conclusion"):
        if key in resp and hasattr(resp[key], "tolist"):
            resp[key] = resp[key].tolist()

    return json.loads(json.dumps(resp))


demo = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")

with demo:
    gr.Markdown("")
    gr.Markdown(
        "This _example_ was **drive** from\n"
        "\n"
        "[https://github.com/svjack/tableQA-Chinese](https://github.com/svjack/tableQA-Chinese)\n"
        "\n"
        "\n"
    )

    with gr.Tabs():
        # tableqa
        with gr.TabItem("数据表问答(TableQA)"):
            with gr.Tabs():
                with gr.TabItem(tableqa_):
                    _defaults = default_val_dict[tableqa_]
                    tqa_question = gr.Textbox(
                        _defaults["tqa_question"],
                        label="问句:(输入)",
                    )
                    tqa_data = gr.Dataframe(
                        headers=_defaults["tqa_header"],
                        value=_defaults["tqa_rows"],
                        # One extra row beyond the sample data.
                        row_count=len(_defaults["tqa_rows"]) + 1,
                    )
                    tqa_answer = JSON(
                        _defaults["tqa_answer"],
                        label="问句:(输出)",
                    )
                    tqa_button = gr.Button("得到答案")
                    tqa_button.click(
                        run_tableqa,
                        inputs=[tqa_question, tqa_data],
                        outputs=tqa_answer,
                    )

# "0.0.0.0" makes the server listen on all network interfaces.
demo.launch(server_name="0.0.0.0")