DrishtiSharma commited on
Commit
d325b19
Β·
verified Β·
1 Parent(s): d1f7f7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -10
app.py CHANGED
@@ -246,35 +246,38 @@ def add_stats_to_figure(fig, df, y_axis):
246
  )
247
  return fig
248
 
249
- # Unified visualization function with LLM fallback
250
  def generate_visual_from_query(query, df, llm=None):
251
  try:
 
252
  matched_columns = fuzzy_match_columns(query)
253
 
254
- # Fallback to LLM if fuzzy matching fails
255
  if not matched_columns and llm:
256
  st.info("πŸ€– No match found. Asking AI for suggestions...")
257
  suggestion = ask_llm_for_columns(query, llm, df)
258
  if suggestion:
259
  matched_columns = [suggestion.get("x_axis"), suggestion.get("group_by")]
260
 
261
- # Handle cases when we have columns to plot
262
  if len(matched_columns) >= 2:
263
  x_axis, group_by = matched_columns[0], matched_columns[1]
264
  elif len(matched_columns) == 1:
265
  x_axis, group_by = matched_columns[0], None
266
  else:
267
- st.warning("❓ No matching columns found. Please refine your query.")
268
  return None
269
 
270
- # Handle distribution queries
 
 
271
  if "distribution" in query:
272
  fig = px.box(df, x=x_axis, y="salary_in_usd", color=group_by,
273
  title=f"Salary Distribution by {x_axis.replace('_', ' ').title()}"
274
  + (f" and {group_by.replace('_', ' ').title()}" if group_by else ""))
275
  return add_stats_to_figure(fig, df, "salary_in_usd")
276
 
277
- # Handle average salary queries
278
  elif "average" in query or "mean" in query:
279
  grouped_df = df.groupby([x_axis] + ([group_by] if group_by else []))["salary_in_usd"].mean().reset_index()
280
  fig = px.bar(grouped_df, x=x_axis, y="salary_in_usd", color=group_by,
@@ -282,22 +285,23 @@ def generate_visual_from_query(query, df, llm=None):
282
  + (f" and {group_by.replace('_', ' ').title()}" if group_by else ""))
283
  return add_stats_to_figure(fig, df, "salary_in_usd")
284
 
285
- # Handle salary trends over time
286
  elif "trend" in query and "work_year" in df.columns:
287
  grouped_df = df.groupby(["work_year", x_axis])["salary_in_usd"].mean().reset_index()
288
  fig = px.line(grouped_df, x="work_year", y="salary_in_usd", color=x_axis,
289
  title=f"Salary Trend Over Years by {x_axis.replace('_', ' ').title()}")
290
  return add_stats_to_figure(fig, df, "salary_in_usd")
291
 
292
- # Handle remote work impact
293
  elif "remote" in query:
294
  grouped_df = df.groupby(["remote_ratio"] + ([group_by] if group_by else []))["salary_in_usd"].mean().reset_index()
295
  fig = px.bar(grouped_df, x="remote_ratio", y="salary_in_usd", color=group_by,
296
  title="Remote Work Impact on Salary")
297
  return add_stats_to_figure(fig, df, "salary_in_usd")
298
 
 
299
  else:
300
- st.warning("⚠️ No suitable visualization generated. Please refine your query.")
301
  return None
302
 
303
  except Exception as e:
@@ -305,7 +309,6 @@ def generate_visual_from_query(query, df, llm=None):
305
  return None
306
 
307
 
308
-
309
  # SQL-RAG Analysis
310
  if st.session_state.df is not None:
311
  temp_dir = tempfile.TemporaryDirectory()
 
246
  )
247
  return fig
248
 
249
+ # Unified Visualization Generator with Fuzzy Matching and LLM Fallback
250
  def generate_visual_from_query(query, df, llm=None):
251
  try:
252
+ # Step 1: Attempt Fuzzy Matching
253
  matched_columns = fuzzy_match_columns(query)
254
 
255
+ # Step 2: Fallback to LLM if no columns are matched
256
  if not matched_columns and llm:
257
  st.info("πŸ€– No match found. Asking AI for suggestions...")
258
  suggestion = ask_llm_for_columns(query, llm, df)
259
  if suggestion:
260
  matched_columns = [suggestion.get("x_axis"), suggestion.get("group_by")]
261
 
262
+ # Step 3: Process Matched Columns
263
  if len(matched_columns) >= 2:
264
  x_axis, group_by = matched_columns[0], matched_columns[1]
265
  elif len(matched_columns) == 1:
266
  x_axis, group_by = matched_columns[0], None
267
  else:
268
+ st.warning("❓ No matching columns found. Try rephrasing your query.")
269
  return None
270
 
271
+ # Step 4: Visualization Generation
272
+
273
+ # Distribution Plot
274
  if "distribution" in query:
275
  fig = px.box(df, x=x_axis, y="salary_in_usd", color=group_by,
276
  title=f"Salary Distribution by {x_axis.replace('_', ' ').title()}"
277
  + (f" and {group_by.replace('_', ' ').title()}" if group_by else ""))
278
  return add_stats_to_figure(fig, df, "salary_in_usd")
279
 
280
+ # Average Salary Plot
281
  elif "average" in query or "mean" in query:
282
  grouped_df = df.groupby([x_axis] + ([group_by] if group_by else []))["salary_in_usd"].mean().reset_index()
283
  fig = px.bar(grouped_df, x=x_axis, y="salary_in_usd", color=group_by,
 
285
  + (f" and {group_by.replace('_', ' ').title()}" if group_by else ""))
286
  return add_stats_to_figure(fig, df, "salary_in_usd")
287
 
288
+ # Salary Trends Over Time
289
  elif "trend" in query and "work_year" in df.columns:
290
  grouped_df = df.groupby(["work_year", x_axis])["salary_in_usd"].mean().reset_index()
291
  fig = px.line(grouped_df, x="work_year", y="salary_in_usd", color=x_axis,
292
  title=f"Salary Trend Over Years by {x_axis.replace('_', ' ').title()}")
293
  return add_stats_to_figure(fig, df, "salary_in_usd")
294
 
295
+ # Remote Work Impact
296
  elif "remote" in query:
297
  grouped_df = df.groupby(["remote_ratio"] + ([group_by] if group_by else []))["salary_in_usd"].mean().reset_index()
298
  fig = px.bar(grouped_df, x="remote_ratio", y="salary_in_usd", color=group_by,
299
  title="Remote Work Impact on Salary")
300
  return add_stats_to_figure(fig, df, "salary_in_usd")
301
 
302
+ # No Specific Match
303
  else:
304
+ st.warning("⚠️ No suitable visualization to display!")
305
  return None
306
 
307
  except Exception as e:
 
309
  return None
310
 
311
 
 
312
  # SQL-RAG Analysis
313
  if st.session_state.df is not None:
314
  temp_dir = tempfile.TemporaryDirectory()