Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -12,9 +12,6 @@ from dotenv import load_dotenv
|
|
12 |
load_dotenv()
|
13 |
userdata = os.environ
|
14 |
|
15 |
-
DATA_DIR = Path(os.getcwd()) / "data"
|
16 |
-
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
17 |
-
|
18 |
|
19 |
def chat_with_groq(client:groq.Groq,
|
20 |
prompt:str,
|
@@ -49,19 +46,29 @@ def chat_with_groq(client:groq.Groq,
|
|
49 |
# logger.info(f"Completion: {completion}")
|
50 |
return completion.choices[0].message.content
|
51 |
|
52 |
-
def execute_duckdb_query(query:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
try:
|
54 |
conn = duckdb.connect(database=":memory:", read_only=False)
|
55 |
-
|
56 |
-
# Load all CSV files from the data directory
|
57 |
-
for csv_file in DATA_DIR.glob("*.csv"):
|
58 |
-
table_name = csv_file.stem
|
59 |
-
conn.execute(f"CREATE TABLE {table_name} AS SELECT * FROM read_csv_auto('{csv_file}')")
|
60 |
-
|
61 |
query_result = conn.execute(query).fetch_df().reset_index()
|
|
|
62 |
return query_result
|
63 |
except Exception as e:
|
64 |
-
print(f"Error
|
|
|
65 |
raise e
|
66 |
def get_summarization(client:groq.Groq,
|
67 |
use_question:str,
|
@@ -251,6 +258,10 @@ base_prompt = """
|
|
251 |
* Ensure that the entire output is returned on only one single line
|
252 |
* Keep your query as simple and straightforward as possible; do not use subqueries
|
253 |
"""
|
|
|
|
|
|
|
|
|
254 |
user_question = """"""
|
255 |
|
256 |
# And some rules for querying the dataset:
|
@@ -260,32 +271,39 @@ user_question = """"""
|
|
260 |
# * Valid values for product_name include 'Tesla','iPhone' and 'Humane pin'
|
261 |
|
262 |
|
263 |
-
def upload_file(files) ->
|
|
|
264 |
model = "llama3-8b-8192"
|
265 |
-
api_key:
|
266 |
-
|
267 |
-
|
|
|
268 |
files = [files]
|
269 |
-
|
270 |
-
uploaded_files = []
|
271 |
stored_table_descriptions = []
|
272 |
-
|
273 |
for file in files:
|
274 |
filename = Path(file.name).name
|
275 |
-
path =
|
276 |
|
277 |
# Copy the content of the temporary file to our destination
|
278 |
-
|
|
|
279 |
|
280 |
-
|
281 |
-
table_description = identify_column_datatypes_to_SQL_DEF(pd.read_csv(path),
|
282 |
-
desc =
|
283 |
stored_table_descriptions.append(desc)
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
|
|
|
|
|
|
|
|
|
|
289 |
|
290 |
def user_prompt_sanitization(user_prompt:str)->str:
|
291 |
guide = """
|
@@ -308,52 +326,56 @@ def user_prompt_sanitization(user_prompt:str)->str:
|
|
308 |
client = groq.Groq(api_key=api_key)
|
309 |
return chat_with_groq(client,formatted_guide,"llama3-70b-8192",None)
|
310 |
|
311 |
-
def queryModel(user_prompt:
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
|
|
|
|
|
|
|
|
330 |
try:
|
331 |
-
|
332 |
-
"type": "json_object"
|
333 |
-
})
|
334 |
except Exception as e:
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
|
|
|
|
357 |
|
358 |
with gr.Blocks() as demo:
|
359 |
gr.Markdown("# CSV Database Query Interface")
|
@@ -363,28 +385,16 @@ with gr.Blocks() as demo:
|
|
363 |
upload_button = gr.Button("Load CSV Files")
|
364 |
upload_output = gr.Textbox(label="Upload Status", lines=5)
|
365 |
|
366 |
-
|
367 |
-
result = upload_file(files)
|
368 |
-
uploaded_files.extend(result["files"])
|
369 |
-
return result["descriptions"], uploaded_files
|
370 |
-
|
371 |
-
upload_button.click(handle_upload,
|
372 |
-
inputs=[file_output, uploaded_files],
|
373 |
-
outputs=[upload_output, uploaded_files])
|
374 |
-
|
375 |
with gr.Tab("Query Interface"):
|
376 |
chatbot = gr.Chatbot()
|
377 |
with gr.Row():
|
378 |
user_input = gr.Textbox(label="Enter your question")
|
379 |
submit_button = gr.Button("Submit")
|
380 |
-
|
381 |
-
|
382 |
-
return queryModel(user_prompt, files)
|
383 |
|
384 |
-
submit_button.click(query_model_with_files,
|
385 |
-
inputs=[user_input, uploaded_files],
|
386 |
-
outputs=chatbot)
|
387 |
|
388 |
-
demo.launch()
|
389 |
|
390 |
|
|
|
12 |
load_dotenv()
|
13 |
userdata = os.environ
|
14 |
|
|
|
|
|
|
|
15 |
|
16 |
def chat_with_groq(client:groq.Groq,
|
17 |
prompt:str,
|
|
|
46 |
# logger.info(f"Completion: {completion}")
|
47 |
return completion.choices[0].message.content
|
48 |
|
49 |
+
def execute_duckdb_query(query:str)->pd.DataFrame:
    """
    Execute a DuckDB query and return the result as a pandas DataFrame.

    The query runs on a fresh in-memory DuckDB connection while the process
    working directory is temporarily switched to ``data`` (presumably so
    that relative file names used inside the query resolve there -- confirm
    against the prompt that generates the SQL). The previous working
    directory is always restored and the connection is always closed,
    whether the query succeeds or fails.

    Args:
        query (str): The DuckDB query to execute.

    Returns:
        pd.DataFrame: The result of the query as a pandas DataFrame, after
        ``reset_index()`` has been applied.

    Raises:
        OSError: If the ``data`` directory cannot be entered.
        Exception: Any error raised while connecting or executing the query
            is printed and then re-raised unchanged.
    """
    original_cwd = os.getcwd()
    print(f"PATH:{original_cwd}")
    # Deliberately outside the try block: if this fails the cwd was never
    # changed, so there is nothing to restore.
    os.chdir('data')

    conn = None
    try:
        print(f"PATH:{os.getcwd()}")
        conn = duckdb.connect(database=":memory:", read_only=False)
        return conn.execute(query).fetch_df().reset_index()
    except Exception as e:
        print(f"Error: {e}")
        raise
    finally:
        # Single cleanup path: release the connection and undo the chdir
        # regardless of how the try block exited.
        if conn is not None:
            conn.close()
        os.chdir(original_cwd)
|
73 |
def get_summarization(client:groq.Groq,
|
74 |
use_question:str,
|
|
|
258 |
* Ensure that the entire output is returned on only one single line
|
259 |
* Keep your query as simple and straightforward as possible; do not use subqueries
|
260 |
"""
|
261 |
+
table_description = """"""
|
262 |
+
tables_string = """"""
|
263 |
+
table_1 = """"""
|
264 |
+
table_1_wt_xt = """"""
|
265 |
user_question = """"""
|
266 |
|
267 |
# And some rules for querying the dataset:
|
|
|
271 |
# * Valid values for product_name include 'Tesla','iPhone' and 'Humane pin'
|
272 |
|
273 |
|
274 |
+
def upload_file(files) -> List[str]:
|
275 |
+
# will have to change to the private system is initiializes
|
276 |
model = "llama3-8b-8192"
|
277 |
+
api_key:str=userdata.get("GROQ_API_KEY")
|
278 |
+
data_dir = Path("data")
|
279 |
+
data_dir.mkdir(parents=True, exist_ok=True)
|
280 |
+
if type(files) == str:
|
281 |
files = [files]
|
282 |
+
stored_paths = []
|
|
|
283 |
stored_table_descriptions = []
|
284 |
+
tables = []
|
285 |
for file in files:
|
286 |
filename = Path(file.name).name
|
287 |
+
path = data_dir / filename
|
288 |
|
289 |
# Copy the content of the temporary file to our destination
|
290 |
+
with open(file.name, "rb") as source, open(path, "wb") as destination:
|
291 |
+
destination.write(source.read())
|
292 |
|
293 |
+
stored_paths.append(str(path.absolute()))
|
294 |
+
table_description = identify_column_datatypes_to_SQL_DEF(pd.read_csv(path),api_key,model)
|
295 |
+
desc = "Table: " + filename + "\n Columns:\n" + table_description
|
296 |
stored_table_descriptions.append(desc)
|
297 |
+
tables.append(filename)
|
298 |
+
# constructing a string
|
299 |
+
tables_string = join_with_and(tables)
|
300 |
+
final = "\n".join(stored_table_descriptions)
|
301 |
+
table_1_wt_xt = tables[0].split('.')[0]
|
302 |
+
table_description = final
|
303 |
+
tables_string = tables_string
|
304 |
+
table_1 = tables[0]
|
305 |
+
table_1_wt_xt = table_1_wt_xt
|
306 |
+
return final
|
307 |
|
308 |
def user_prompt_sanitization(user_prompt:str)->str:
|
309 |
guide = """
|
|
|
326 |
client = groq.Groq(api_key=api_key)
|
327 |
return chat_with_groq(client,formatted_guide,"llama3-70b-8192",None)
|
328 |
|
329 |
+
def queryModel(user_prompt:str,model:str = "llama3-70b-8192",api_key:str=userdata.get("GROQ_API_KEY")):
    """
    Answer a user question by generating SQL with Groq and running it.

    The prompt is first passed through ``user_prompt_sanitization``, then
    inserted into ``base_prompt`` together with the module-level table
    metadata. The model is asked for a JSON object; when it contains an
    ``"sql"`` key the query is run with ``execute_duckdb_query`` and the
    formatted SQL, the result table (markdown) and a summary from
    ``get_summarization`` are returned.

    Args:
        user_prompt (str): The question typed by the user.
        model (str): Groq model name used for SQL generation and summary.
        api_key (str): Groq API key; the default is read from ``userdata``
            once, when the function is defined.

    Returns:
        list: A one-element list holding a ``("Groq Advisor", text)``
        tuple. ``text`` is either the query, results and summary, or a
        message starting with ``"Error: "``. Failures in the model calls,
        JSON decoding or query execution are reported this way instead of
        being raised.
    """
    def _reply(text):
        # Every exit path hands the chatbot the same single-message shape.
        return [(
            "Groq Advisor",
            text
        )]

    client = groq.Groq(api_key=api_key)

    try:
        user_prompt = user_prompt_sanitization(user_prompt)
    except Exception as e:
        return _reply("Error: " + str(e))
    print(user_prompt)

    full_prompt = base_prompt.format(
        user_question=user_prompt,
        table_description=table_description,
        tables=tables_string,
        table_1=table_1,
        table_1_wt_xt=table_1_wt_xt
    )

    try:
        response = chat_with_groq(client,full_prompt,model,{
            "type":"json_object"
        })
        # The model may return text that is not valid JSON; report that
        # like any other request failure instead of crashing the handler.
        response = json.loads(response)
    except Exception as e:
        return _reply("Error: " + str(e))

    if not isinstance(response, dict):
        return _reply("Error: Unknown error")

    if "sql" in response:
        sql_query = response["sql"]
        try:
            results_df = execute_duckdb_query(sql_query)
            formatted_sql_query = sqlparse.format(sql_query, reindent=True, keyword_case='upper')
            query_n_results = "SQL Query: " + formatted_sql_query + "\n\n" + results_df.to_markdown()
            summarization = get_summarization(client,user_prompt,results_df,model)
            query_n_results += "\n\n" + summarization
        except Exception as e:
            return _reply("Error: " + str(e))
        return _reply(query_n_results)

    if "error" in response:
        # str() guards against the model returning a non-string error value
        return _reply("Error: " + str(response["error"]))

    return _reply("Error: Unknown error")
|
379 |
|
380 |
with gr.Blocks() as demo:
|
381 |
gr.Markdown("# CSV Database Query Interface")
|
|
|
385 |
upload_button = gr.Button("Load CSV Files")
|
386 |
upload_output = gr.Textbox(label="Upload Status", lines=5)
|
387 |
|
388 |
+
upload_button.click(upload_file, inputs=file_output, outputs=upload_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
389 |
with gr.Tab("Query Interface"):
|
390 |
chatbot = gr.Chatbot()
|
391 |
with gr.Row():
|
392 |
user_input = gr.Textbox(label="Enter your question")
|
393 |
submit_button = gr.Button("Submit")
|
394 |
+
submit_button.click(queryModel, inputs=[user_input], outputs=chatbot)
|
395 |
+
|
|
|
396 |
|
|
|
|
|
|
|
397 |
|
398 |
+
demo.launch(share=True)
|
399 |
|
400 |
|