blesspearl commited on
Commit
9767170
·
verified ·
1 Parent(s): 67ceaa7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -67
app.py CHANGED
@@ -260,16 +260,15 @@ user_question = """"""
260
  # * Valid values for product_name include 'Tesla','iPhone' and 'Humane pin'
261
 
262
 
263
- def upload_file(files) -> List[str]:
264
  model = "llama3-8b-8192"
265
  api_key: str = userdata.get("GROQ_API_KEY")
266
 
267
  if isinstance(files, str):
268
  files = [files]
269
 
270
- stored_paths = []
271
  stored_table_descriptions = []
272
- tables = []
273
 
274
  for file in files:
275
  filename = Path(file.name).name
@@ -278,20 +277,15 @@ def upload_file(files) -> List[str]:
278
  # Copy the content of the temporary file to our destination
279
  shutil.copy2(file.name, path)
280
 
281
- stored_paths.append(str(path))
282
- table_description_temp = identify_column_datatypes_to_SQL_DEF(pd.read_csv(path), api_key, model)
283
- desc = f"Table: {filename}\nColumns:\n{table_description_temp}"
284
  stored_table_descriptions.append(desc)
285
- tables.append(filename)
286
 
287
- # Update global variables
288
- global table_description, tables_string, table_1, table_1_wt_xt
289
- table_description = "\n".join(stored_table_descriptions)
290
- tables_string = join_with_and(tables)
291
- table_1 = tables[0] if tables else ""
292
- table_1_wt_xt = table_1.split('.')[0] if table_1 else ""
293
-
294
- return "\n".join(stored_table_descriptions)
295
 
296
  def user_prompt_sanitization(user_prompt:str)->str:
297
  guide = """
@@ -314,58 +308,52 @@ def user_prompt_sanitization(user_prompt:str)->str:
314
  client = groq.Groq(api_key=api_key)
315
  return chat_with_groq(client,formatted_guide,"llama3-70b-8192",None)
316
 
317
- def queryModel(user_prompt:str,model:str = "llama3-70b-8192",api_key:str=userdata.get("GROQ_API_KEY")):
318
- client = groq.Groq(api_key=api_key)
319
- user_prompt = user_prompt_sanitization(user_prompt)
320
- print(user_prompt)
321
- full_prompt = base_prompt.format(
322
- user_question=user_prompt,
323
- table_description=table_description,
324
- tables=tables_string,
325
- table_1=table_1,
326
- table_1_wt_xt=table_1_wt_xt
327
- )
328
- try:
329
- response = chat_with_groq(client,full_prompt,model,{
330
- "type":"json_object"
331
- })
332
- except Exception as e:
333
- return [(
334
- "Groq Advisor",
335
- "Error: " + str(e)
336
- )]
337
- response = json.loads(response)
338
- if "sql" in response:
339
- sql_query = response["sql"]
340
  try:
341
- results_df = execute_duckdb_query(sql_query)
 
 
342
  except Exception as e:
343
- return [(
344
- "Groq Advisor",
345
- "Error: " + str(e)
346
- )]
347
-
348
- fotmatted_sql_query = sqlparse.format(sql_query, reindent=True, keyword_case='upper')
349
- # print(f"SQL Query: {fotmatted_sql_query}")
350
- # print(results_df.to_markdown())
351
- query_n_results = "SQL Query: " + fotmatted_sql_query + "\n\n" + results_df.to_markdown()
352
- summarization = get_summarization(client,user_prompt,results_df,model)
353
- query_n_results += "\n\n" + summarization
354
-
355
- return [(
356
- "Groq Advisor",
357
- query_n_results
358
- )]
359
- elif "error" in response:
360
- return [(
361
- "Groq Advisor",
362
- "Error: " + response["error"]
363
- )]
364
- else:
365
- return [(
366
- "Groq Advisor",
367
- "Error: Unknown error"
368
- )]
369
 
370
  with gr.Blocks() as demo:
371
  gr.Markdown("# CSV Database Query Interface")
@@ -375,14 +363,27 @@ with gr.Blocks() as demo:
375
  upload_button = gr.Button("Load CSV Files")
376
  upload_output = gr.Textbox(label="Upload Status", lines=5)
377
 
378
- upload_button.click(upload_file, inputs=file_output, outputs=upload_output)
 
 
 
 
 
 
 
379
 
380
  with gr.Tab("Query Interface"):
381
  chatbot = gr.Chatbot()
382
  with gr.Row():
383
  user_input = gr.Textbox(label="Enter your question")
384
  submit_button = gr.Button("Submit")
385
- submit_button.click(queryModel, inputs=[user_input], outputs=chatbot)
 
 
 
 
 
 
386
 
387
  demo.launch()
388
 
 
260
  # * Valid values for product_name include 'Tesla','iPhone' and 'Humane pin'
261
 
262
 
263
+ def upload_file(files) -> dict:
264
  model = "llama3-8b-8192"
265
  api_key: str = userdata.get("GROQ_API_KEY")
266
 
267
  if isinstance(files, str):
268
  files = [files]
269
 
270
+ uploaded_files = []
271
  stored_table_descriptions = []
 
272
 
273
  for file in files:
274
  filename = Path(file.name).name
 
277
  # Copy the content of the temporary file to our destination
278
  shutil.copy2(file.name, path)
279
 
280
+ uploaded_files.append(str(path))
281
+ table_description = identify_column_datatypes_to_SQL_DEF(pd.read_csv(path), api_key, model)
282
+ desc = f"Table: {filename}\nColumns:\n{table_description}"
283
  stored_table_descriptions.append(desc)
 
284
 
285
+ return {
286
+ "files": uploaded_files,
287
+ "descriptions": "\n".join(stored_table_descriptions)
288
+ }
 
 
 
 
289
 
290
  def user_prompt_sanitization(user_prompt:str)->str:
291
  guide = """
 
308
  client = groq.Groq(api_key=api_key)
309
  return chat_with_groq(client,formatted_guide,"llama3-70b-8192",None)
310
 
311
+ def queryModel(user_prompt: str, files: List[str], model: str = "llama3-70b-8192", api_key: str = userdata.get("GROQ_API_KEY")):
312
+ client = groq.Groq(api_key=api_key)
313
+ user_prompt = user_prompt_sanitization(user_prompt)
314
+ print(user_prompt)
315
+
316
+ # Update this part to use the actual table names from the uploaded files
317
+ table_names = [Path(file).stem for file in files]
318
+ tables_string = join_with_and(table_names)
319
+ table_1 = table_names[0] if table_names else ""
320
+ table_1_wt_xt = table_1
321
+
322
+ full_prompt = base_prompt.format(
323
+ user_question=user_prompt,
324
+ table_description=table_description,
325
+ tables=tables_string,
326
+ table_1=table_1,
327
+ table_1_wt_xt=table_1_wt_xt
328
+ )
329
+
 
 
 
 
330
  try:
331
+ response = chat_with_groq(client, full_prompt, model, {
332
+ "type": "json_object"
333
+ })
334
  except Exception as e:
335
+ return [("Groq Advisor", "Error: " + str(e))]
336
+
337
+ response = json.loads(response)
338
+ if "sql" in response:
339
+ sql_query = response["sql"]
340
+ try:
341
+ results_df = execute_duckdb_query(sql_query, files)
342
+ except Exception as e:
343
+ return [("Groq Advisor", "Error: " + str(e))]
344
+
345
+ formatted_sql_query = sqlparse.format(sql_query, reindent=True, keyword_case='upper')
346
+ query_n_results = "SQL Query: " + formatted_sql_query + "\n\n" + results_df.to_markdown()
347
+ summarization = get_summarization(client, user_prompt, results_df, model)
348
+ query_n_results += "\n\n" + summarization
349
+
350
+ return [("Groq Advisor", query_n_results)]
351
+ elif "error" in response:
352
+ return [("Groq Advisor", "Error: " + response["error"])]
353
+ else:
354
+ return [("Groq Advisor", "Error: Unknown error")]
355
+
356
+ uploaded_files = gr.State([])
 
 
 
 
357
 
358
  with gr.Blocks() as demo:
359
  gr.Markdown("# CSV Database Query Interface")
 
363
  upload_button = gr.Button("Load CSV Files")
364
  upload_output = gr.Textbox(label="Upload Status", lines=5)
365
 
366
+ def handle_upload(files, uploaded_files):
367
+ result = upload_file(files)
368
+ uploaded_files.extend(result["files"])
369
+ return result["descriptions"], uploaded_files
370
+
371
+ upload_button.click(handle_upload,
372
+ inputs=[file_output, uploaded_files],
373
+ outputs=[upload_output, uploaded_files])
374
 
375
  with gr.Tab("Query Interface"):
376
  chatbot = gr.Chatbot()
377
  with gr.Row():
378
  user_input = gr.Textbox(label="Enter your question")
379
  submit_button = gr.Button("Submit")
380
+
381
+ def query_model_with_files(user_prompt, files):
382
+ return queryModel(user_prompt, files)
383
+
384
+ submit_button.click(query_model_with_files,
385
+ inputs=[user_input, uploaded_files],
386
+ outputs=chatbot)
387
 
388
  demo.launch()
389