Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -260,16 +260,15 @@ user_question = """"""
|
|
260 |
# * Valid values for product_name include 'Tesla','iPhone' and 'Humane pin'
|
261 |
|
262 |
|
263 |
-
def upload_file(files) ->
|
264 |
model = "llama3-8b-8192"
|
265 |
api_key: str = userdata.get("GROQ_API_KEY")
|
266 |
|
267 |
if isinstance(files, str):
|
268 |
files = [files]
|
269 |
|
270 |
-
|
271 |
stored_table_descriptions = []
|
272 |
-
tables = []
|
273 |
|
274 |
for file in files:
|
275 |
filename = Path(file.name).name
|
@@ -278,20 +277,15 @@ def upload_file(files) -> List[str]:
|
|
278 |
# Copy the content of the temporary file to our destination
|
279 |
shutil.copy2(file.name, path)
|
280 |
|
281 |
-
|
282 |
-
|
283 |
-
desc = f"Table: {filename}\nColumns:\n{
|
284 |
stored_table_descriptions.append(desc)
|
285 |
-
tables.append(filename)
|
286 |
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
table_1 = tables[0] if tables else ""
|
292 |
-
table_1_wt_xt = table_1.split('.')[0] if table_1 else ""
|
293 |
-
|
294 |
-
return "\n".join(stored_table_descriptions)
|
295 |
|
296 |
def user_prompt_sanitization(user_prompt:str)->str:
|
297 |
guide = """
|
@@ -314,58 +308,52 @@ def user_prompt_sanitization(user_prompt:str)->str:
|
|
314 |
client = groq.Groq(api_key=api_key)
|
315 |
return chat_with_groq(client,formatted_guide,"llama3-70b-8192",None)
|
316 |
|
317 |
-
def queryModel(user_prompt:str,model:str = "llama3-70b-8192",api_key:str=userdata.get("GROQ_API_KEY")):
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
)]
|
337 |
-
response = json.loads(response)
|
338 |
-
if "sql" in response:
|
339 |
-
sql_query = response["sql"]
|
340 |
try:
|
341 |
-
|
|
|
|
|
342 |
except Exception as e:
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
"
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
"Error:
|
363 |
-
|
364 |
-
|
365 |
-
return [(
|
366 |
-
"Groq Advisor",
|
367 |
-
"Error: Unknown error"
|
368 |
-
)]
|
369 |
|
370 |
with gr.Blocks() as demo:
|
371 |
gr.Markdown("# CSV Database Query Interface")
|
@@ -375,14 +363,27 @@ with gr.Blocks() as demo:
|
|
375 |
upload_button = gr.Button("Load CSV Files")
|
376 |
upload_output = gr.Textbox(label="Upload Status", lines=5)
|
377 |
|
378 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
379 |
|
380 |
with gr.Tab("Query Interface"):
|
381 |
chatbot = gr.Chatbot()
|
382 |
with gr.Row():
|
383 |
user_input = gr.Textbox(label="Enter your question")
|
384 |
submit_button = gr.Button("Submit")
|
385 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
386 |
|
387 |
demo.launch()
|
388 |
|
|
|
260 |
# * Valid values for product_name include 'Tesla','iPhone' and 'Humane pin'
|
261 |
|
262 |
|
263 |
+
def upload_file(files) -> dict:
|
264 |
model = "llama3-8b-8192"
|
265 |
api_key: str = userdata.get("GROQ_API_KEY")
|
266 |
|
267 |
if isinstance(files, str):
|
268 |
files = [files]
|
269 |
|
270 |
+
uploaded_files = []
|
271 |
stored_table_descriptions = []
|
|
|
272 |
|
273 |
for file in files:
|
274 |
filename = Path(file.name).name
|
|
|
277 |
# Copy the content of the temporary file to our destination
|
278 |
shutil.copy2(file.name, path)
|
279 |
|
280 |
+
uploaded_files.append(str(path))
|
281 |
+
table_description = identify_column_datatypes_to_SQL_DEF(pd.read_csv(path), api_key, model)
|
282 |
+
desc = f"Table: {filename}\nColumns:\n{table_description}"
|
283 |
stored_table_descriptions.append(desc)
|
|
|
284 |
|
285 |
+
return {
|
286 |
+
"files": uploaded_files,
|
287 |
+
"descriptions": "\n".join(stored_table_descriptions)
|
288 |
+
}
|
|
|
|
|
|
|
|
|
289 |
|
290 |
def user_prompt_sanitization(user_prompt:str)->str:
|
291 |
guide = """
|
|
|
308 |
client = groq.Groq(api_key=api_key)
|
309 |
return chat_with_groq(client,formatted_guide,"llama3-70b-8192",None)
|
310 |
|
311 |
+
def queryModel(user_prompt: str, files: List[str], model: str = "llama3-70b-8192", api_key: str = userdata.get("GROQ_API_KEY")):
|
312 |
+
client = groq.Groq(api_key=api_key)
|
313 |
+
user_prompt = user_prompt_sanitization(user_prompt)
|
314 |
+
print(user_prompt)
|
315 |
+
|
316 |
+
# Update this part to use the actual table names from the uploaded files
|
317 |
+
table_names = [Path(file).stem for file in files]
|
318 |
+
tables_string = join_with_and(table_names)
|
319 |
+
table_1 = table_names[0] if table_names else ""
|
320 |
+
table_1_wt_xt = table_1
|
321 |
+
|
322 |
+
full_prompt = base_prompt.format(
|
323 |
+
user_question=user_prompt,
|
324 |
+
table_description=table_description,
|
325 |
+
tables=tables_string,
|
326 |
+
table_1=table_1,
|
327 |
+
table_1_wt_xt=table_1_wt_xt
|
328 |
+
)
|
329 |
+
|
|
|
|
|
|
|
|
|
330 |
try:
|
331 |
+
response = chat_with_groq(client, full_prompt, model, {
|
332 |
+
"type": "json_object"
|
333 |
+
})
|
334 |
except Exception as e:
|
335 |
+
return [("Groq Advisor", "Error: " + str(e))]
|
336 |
+
|
337 |
+
response = json.loads(response)
|
338 |
+
if "sql" in response:
|
339 |
+
sql_query = response["sql"]
|
340 |
+
try:
|
341 |
+
results_df = execute_duckdb_query(sql_query, files)
|
342 |
+
except Exception as e:
|
343 |
+
return [("Groq Advisor", "Error: " + str(e))]
|
344 |
+
|
345 |
+
formatted_sql_query = sqlparse.format(sql_query, reindent=True, keyword_case='upper')
|
346 |
+
query_n_results = "SQL Query: " + formatted_sql_query + "\n\n" + results_df.to_markdown()
|
347 |
+
summarization = get_summarization(client, user_prompt, results_df, model)
|
348 |
+
query_n_results += "\n\n" + summarization
|
349 |
+
|
350 |
+
return [("Groq Advisor", query_n_results)]
|
351 |
+
elif "error" in response:
|
352 |
+
return [("Groq Advisor", "Error: " + response["error"])]
|
353 |
+
else:
|
354 |
+
return [("Groq Advisor", "Error: Unknown error")]
|
355 |
+
|
356 |
+
uploaded_files = gr.State([])
|
|
|
|
|
|
|
|
|
357 |
|
358 |
with gr.Blocks() as demo:
|
359 |
gr.Markdown("# CSV Database Query Interface")
|
|
|
363 |
upload_button = gr.Button("Load CSV Files")
|
364 |
upload_output = gr.Textbox(label="Upload Status", lines=5)
|
365 |
|
366 |
+
def handle_upload(files, uploaded_files):
|
367 |
+
result = upload_file(files)
|
368 |
+
uploaded_files.extend(result["files"])
|
369 |
+
return result["descriptions"], uploaded_files
|
370 |
+
|
371 |
+
upload_button.click(handle_upload,
|
372 |
+
inputs=[file_output, uploaded_files],
|
373 |
+
outputs=[upload_output, uploaded_files])
|
374 |
|
375 |
with gr.Tab("Query Interface"):
|
376 |
chatbot = gr.Chatbot()
|
377 |
with gr.Row():
|
378 |
user_input = gr.Textbox(label="Enter your question")
|
379 |
submit_button = gr.Button("Submit")
|
380 |
+
|
381 |
+
def query_model_with_files(user_prompt, files):
|
382 |
+
return queryModel(user_prompt, files)
|
383 |
+
|
384 |
+
submit_button.click(query_model_with_files,
|
385 |
+
inputs=[user_input, uploaded_files],
|
386 |
+
outputs=chatbot)
|
387 |
|
388 |
demo.launch()
|
389 |
|