blesspearl commited on
Commit
5933e78
·
verified ·
1 Parent(s): 9767170

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -88
app.py CHANGED
@@ -12,9 +12,6 @@ from dotenv import load_dotenv
12
  load_dotenv()
13
  userdata = os.environ
14
 
15
- DATA_DIR = Path(os.getcwd()) / "data"
16
- DATA_DIR.mkdir(parents=True, exist_ok=True)
17
-
18
 
19
  def chat_with_groq(client:groq.Groq,
20
  prompt:str,
@@ -49,19 +46,29 @@ def chat_with_groq(client:groq.Groq,
49
  # logger.info(f"Completion: {completion}")
50
  return completion.choices[0].message.content
51
 
52
- def execute_duckdb_query(query: str) -> pd.DataFrame:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  try:
54
  conn = duckdb.connect(database=":memory:", read_only=False)
55
-
56
- # Load all CSV files from the data directory
57
- for csv_file in DATA_DIR.glob("*.csv"):
58
- table_name = csv_file.stem
59
- conn.execute(f"CREATE TABLE {table_name} AS SELECT * FROM read_csv_auto('{csv_file}')")
60
-
61
  query_result = conn.execute(query).fetch_df().reset_index()
 
62
  return query_result
63
  except Exception as e:
64
- print(f"Error executing query: {e}")
 
65
  raise e
66
  def get_summarization(client:groq.Groq,
67
  use_question:str,
@@ -251,6 +258,10 @@ base_prompt = """
251
  * Ensure that the entire output is returned on only one single line
252
  * Keep your query as simple and straightforward as possible; do not use subqueries
253
  """
 
 
 
 
254
  user_question = """"""
255
 
256
  # And some rules for querying the dataset:
@@ -260,32 +271,39 @@ user_question = """"""
260
  # * Valid values for product_name include 'Tesla','iPhone' and 'Humane pin'
261
 
262
 
263
- def upload_file(files) -> dict:
 
264
  model = "llama3-8b-8192"
265
- api_key: str = userdata.get("GROQ_API_KEY")
266
-
267
- if isinstance(files, str):
 
268
  files = [files]
269
-
270
- uploaded_files = []
271
  stored_table_descriptions = []
272
-
273
  for file in files:
274
  filename = Path(file.name).name
275
- path = DATA_DIR / filename
276
 
277
  # Copy the content of the temporary file to our destination
278
- shutil.copy2(file.name, path)
 
279
 
280
- uploaded_files.append(str(path))
281
- table_description = identify_column_datatypes_to_SQL_DEF(pd.read_csv(path), api_key, model)
282
- desc = f"Table: {filename}\nColumns:\n{table_description}"
283
  stored_table_descriptions.append(desc)
284
-
285
- return {
286
- "files": uploaded_files,
287
- "descriptions": "\n".join(stored_table_descriptions)
288
- }
 
 
 
 
 
289
 
290
  def user_prompt_sanitization(user_prompt:str)->str:
291
  guide = """
@@ -308,52 +326,56 @@ def user_prompt_sanitization(user_prompt:str)->str:
308
  client = groq.Groq(api_key=api_key)
309
  return chat_with_groq(client,formatted_guide,"llama3-70b-8192",None)
310
 
311
- def queryModel(user_prompt: str, files: List[str], model: str = "llama3-70b-8192", api_key: str = userdata.get("GROQ_API_KEY")):
312
- client = groq.Groq(api_key=api_key)
313
- user_prompt = user_prompt_sanitization(user_prompt)
314
- print(user_prompt)
315
-
316
- # Update this part to use the actual table names from the uploaded files
317
- table_names = [Path(file).stem for file in files]
318
- tables_string = join_with_and(table_names)
319
- table_1 = table_names[0] if table_names else ""
320
- table_1_wt_xt = table_1
321
-
322
- full_prompt = base_prompt.format(
323
- user_question=user_prompt,
324
- table_description=table_description,
325
- tables=tables_string,
326
- table_1=table_1,
327
- table_1_wt_xt=table_1_wt_xt
328
- )
329
-
 
 
 
 
330
  try:
331
- response = chat_with_groq(client, full_prompt, model, {
332
- "type": "json_object"
333
- })
334
  except Exception as e:
335
- return [("Groq Advisor", "Error: " + str(e))]
336
-
337
- response = json.loads(response)
338
- if "sql" in response:
339
- sql_query = response["sql"]
340
- try:
341
- results_df = execute_duckdb_query(sql_query, files)
342
- except Exception as e:
343
- return [("Groq Advisor", "Error: " + str(e))]
344
-
345
- formatted_sql_query = sqlparse.format(sql_query, reindent=True, keyword_case='upper')
346
- query_n_results = "SQL Query: " + formatted_sql_query + "\n\n" + results_df.to_markdown()
347
- summarization = get_summarization(client, user_prompt, results_df, model)
348
- query_n_results += "\n\n" + summarization
349
-
350
- return [("Groq Advisor", query_n_results)]
351
- elif "error" in response:
352
- return [("Groq Advisor", "Error: " + response["error"])]
353
- else:
354
- return [("Groq Advisor", "Error: Unknown error")]
355
-
356
- uploaded_files = gr.State([])
 
 
357
 
358
  with gr.Blocks() as demo:
359
  gr.Markdown("# CSV Database Query Interface")
@@ -363,28 +385,16 @@ with gr.Blocks() as demo:
363
  upload_button = gr.Button("Load CSV Files")
364
  upload_output = gr.Textbox(label="Upload Status", lines=5)
365
 
366
- def handle_upload(files, uploaded_files):
367
- result = upload_file(files)
368
- uploaded_files.extend(result["files"])
369
- return result["descriptions"], uploaded_files
370
-
371
- upload_button.click(handle_upload,
372
- inputs=[file_output, uploaded_files],
373
- outputs=[upload_output, uploaded_files])
374
-
375
  with gr.Tab("Query Interface"):
376
  chatbot = gr.Chatbot()
377
  with gr.Row():
378
  user_input = gr.Textbox(label="Enter your question")
379
  submit_button = gr.Button("Submit")
380
-
381
- def query_model_with_files(user_prompt, files):
382
- return queryModel(user_prompt, files)
383
 
384
- submit_button.click(query_model_with_files,
385
- inputs=[user_input, uploaded_files],
386
- outputs=chatbot)
387
 
388
- demo.launch()
389
 
390
 
 
12
  load_dotenv()
13
  userdata = os.environ
14
 
 
 
 
15
 
16
  def chat_with_groq(client:groq.Groq,
17
  prompt:str,
 
46
  # logger.info(f"Completion: {completion}")
47
  return completion.choices[0].message.content
48
 
49
+ def execute_duckdb_query(query:str)->pd.DataFrame:
50
+ """
51
+ Execute a DuckDB query and return the result as a pandas DataFrame.
52
+
53
+ Args:
54
+ query (str): The DuckDB query to execute.
55
+
56
+ Returns:
57
+ pd.DataFrame: The result of the query as a pandas DataFrame.
58
+ """
59
+ original_cwd = os.getcwd()
60
+ print(f"PATH:{original_cwd}")
61
+ os.chdir('data')
62
+ print(f"PATH:{os.getcwd()}")
63
+
64
  try:
65
  conn = duckdb.connect(database=":memory:", read_only=False)
 
 
 
 
 
 
66
  query_result = conn.execute(query).fetch_df().reset_index()
67
+ os.chdir(original_cwd)
68
  return query_result
69
  except Exception as e:
70
+ print(f"Error: {e}")
71
+ os.chdir(original_cwd)
72
  raise e
73
  def get_summarization(client:groq.Groq,
74
  use_question:str,
 
258
  * Ensure that the entire output is returned on only one single line
259
  * Keep your query as simple and straightforward as possible; do not use subqueries
260
  """
261
+ table_description = """"""
262
+ tables_string = """"""
263
+ table_1 = """"""
264
+ table_1_wt_xt = """"""
265
  user_question = """"""
266
 
267
  # And some rules for querying the dataset:
 
271
  # * Valid values for product_name include 'Tesla','iPhone' and 'Humane pin'
272
 
273
 
274
+ def upload_file(files) -> List[str]:
275
+ # will have to change to the private system is initiializes
276
  model = "llama3-8b-8192"
277
+ api_key:str=userdata.get("GROQ_API_KEY")
278
+ data_dir = Path("data")
279
+ data_dir.mkdir(parents=True, exist_ok=True)
280
+ if type(files) == str:
281
  files = [files]
282
+ stored_paths = []
 
283
  stored_table_descriptions = []
284
+ tables = []
285
  for file in files:
286
  filename = Path(file.name).name
287
+ path = data_dir / filename
288
 
289
  # Copy the content of the temporary file to our destination
290
+ with open(file.name, "rb") as source, open(path, "wb") as destination:
291
+ destination.write(source.read())
292
 
293
+ stored_paths.append(str(path.absolute()))
294
+ table_description = identify_column_datatypes_to_SQL_DEF(pd.read_csv(path),api_key,model)
295
+ desc = "Table: " + filename + "\n Columns:\n" + table_description
296
  stored_table_descriptions.append(desc)
297
+ tables.append(filename)
298
+ # constructing a string
299
+ tables_string = join_with_and(tables)
300
+ final = "\n".join(stored_table_descriptions)
301
+ table_1_wt_xt = tables[0].split('.')[0]
302
+ table_description = final
303
+ tables_string = tables_string
304
+ table_1 = tables[0]
305
+ table_1_wt_xt = table_1_wt_xt
306
+ return final
307
 
308
  def user_prompt_sanitization(user_prompt:str)->str:
309
  guide = """
 
326
  client = groq.Groq(api_key=api_key)
327
  return chat_with_groq(client,formatted_guide,"llama3-70b-8192",None)
328
 
329
+ def queryModel(user_prompt:str,model:str = "llama3-70b-8192",api_key:str=userdata.get("GROQ_API_KEY")):
330
+ client = groq.Groq(api_key=api_key)
331
+ user_prompt = user_prompt_sanitization(user_prompt)
332
+ print(user_prompt)
333
+ full_prompt = base_prompt.format(
334
+ user_question=user_prompt,
335
+ table_description=table_description,
336
+ tables=tables_string,
337
+ table_1=table_1,
338
+ table_1_wt_xt=table_1_wt_xt
339
+ )
340
+ try:
341
+ response = chat_with_groq(client,full_prompt,model,{
342
+ "type":"json_object"
343
+ })
344
+ except Exception as e:
345
+ return [(
346
+ "Groq Advisor",
347
+ "Error: " + str(e)
348
+ )]
349
+ response = json.loads(response)
350
+ if "sql" in response:
351
+ sql_query = response["sql"]
352
  try:
353
+ results_df = execute_duckdb_query(sql_query)
 
 
354
  except Exception as e:
355
+ return [(
356
+ "Groq Advisor",
357
+ "Error: " + str(e)
358
+ )]
359
+
360
+ fotmatted_sql_query = sqlparse.format(sql_query, reindent=True, keyword_case='upper')
361
+ query_n_results = "SQL Query: " + fotmatted_sql_query + "\n\n" + results_df.to_markdown()
362
+ summarization = get_summarization(client,user_prompt,results_df,model)
363
+ query_n_results += "\n\n" + summarization
364
+
365
+ return [(
366
+ "Groq Advisor",
367
+ query_n_results
368
+ )]
369
+ elif "error" in response:
370
+ return [(
371
+ "Groq Advisor",
372
+ "Error: " + response["error"]
373
+ )]
374
+ else:
375
+ return [(
376
+ "Groq Advisor",
377
+ "Error: Unknown error"
378
+ )]
379
 
380
  with gr.Blocks() as demo:
381
  gr.Markdown("# CSV Database Query Interface")
 
385
  upload_button = gr.Button("Load CSV Files")
386
  upload_output = gr.Textbox(label="Upload Status", lines=5)
387
 
388
+ upload_button.click(upload_file, inputs=file_output, outputs=upload_output)
 
 
 
 
 
 
 
 
389
  with gr.Tab("Query Interface"):
390
  chatbot = gr.Chatbot()
391
  with gr.Row():
392
  user_input = gr.Textbox(label="Enter your question")
393
  submit_button = gr.Button("Submit")
394
+ submit_button.click(queryModel, inputs=[user_input], outputs=chatbot)
395
+
 
396
 
 
 
 
397
 
398
+ demo.launch(share=True)
399
 
400