DrishtiSharma commited on
Commit
70acfe7
Β·
verified Β·
1 Parent(s): 9f3c9dc

Update dummy_funcs.py

Browse files
Files changed (1) hide show
  1. dummy_funcs.py +62 -50
dummy_funcs.py CHANGED
@@ -228,16 +228,17 @@ def ask_gpt4o_for_visualization(query, df, llm, retries=2):
228
  numeric_columns = df.select_dtypes(include='number').columns.tolist()
229
  categorical_columns = df.select_dtypes(exclude='number').columns.tolist()
230
 
231
- # Enhanced Prompt with More Examples
232
  prompt = f"""
233
  Analyze the following query and suggest the most suitable visualization(s) using the dataset.
234
 
235
  **Query:** "{query}"
236
 
237
- **Numeric Columns (for Y-axis):** {', '.join(numeric_columns) if numeric_columns else 'None'}
238
- **Categorical Columns (for X-axis or grouping):** {', '.join(categorical_columns) if categorical_columns else 'None'}
 
239
 
240
- Suggest visualizations in this exact JSON format:
241
  [
242
  {{
243
  "chart_type": "bar/box/line/scatter/pie/heatmap",
@@ -249,83 +250,96 @@ def ask_gpt4o_for_visualization(query, df, llm, retries=2):
249
  }}
250
  ]
251
 
252
- **Examples:**
253
- - For salary distribution:
 
 
254
  {{
255
  "chart_type": "box",
256
  "x_axis": "job_title",
257
  "y_axis": "salary_in_usd",
258
  "group_by": "experience_level",
259
  "title": "Salary Distribution by Job Title and Experience",
260
- "description": "A box plot showing salary ranges across job titles and experience levels."
261
  }}
262
 
263
- - For company size comparison:
264
- {{
265
- "chart_type": "bar",
266
- "x_axis": "company_size",
267
- "y_axis": "salary_in_usd",
268
- "group_by": null,
269
- "title": "Average Salary by Company Size",
270
- "description": "A bar chart comparing the average salaries across different company sizes."
271
- }}
272
-
273
- - For revenue trends over time:
 
 
 
 
 
 
 
 
 
 
 
 
274
  {{
275
  "chart_type": "line",
276
- "x_axis": "year",
277
- "y_axis": "revenue",
278
- "group_by": null,
279
- "title": "Revenue Growth Over Years",
280
- "description": "A line chart showing the trend of revenue over the years."
281
  }}
282
 
283
- - For market share breakdown:
 
284
  {{
285
  "chart_type": "pie",
286
- "x_axis": "market_segment",
287
  "y_axis": null,
288
  "group_by": null,
289
- "title": "Market Share by Segment",
290
- "description": "A pie chart showing the distribution of market share across various segments."
291
  }}
292
 
293
- - For correlation analysis:
 
294
  {{
295
  "chart_type": "scatter",
296
- "x_axis": "years_of_experience",
297
  "y_axis": "salary_in_usd",
298
- "group_by": "job_title",
299
- "title": "Experience vs Salary by Job Title",
300
- "description": "A scatter plot showing the relationship between years of experience and salary across job titles."
301
  }}
302
 
303
- - For data density:
 
304
  {{
305
  "chart_type": "heatmap",
306
- "x_axis": "department",
307
- "y_axis": "region",
308
  "group_by": null,
309
- "title": "Employee Distribution by Department and Region",
310
- "description": "A heatmap showing the concentration of employees across departments and regions."
311
  }}
312
 
313
- Only suggest visualizations that make sense for the data and the query.
314
  """
315
 
 
316
  for attempt in range(retries + 1):
317
  try:
318
- # Generate response from the model
319
  response = llm.generate(prompt)
320
-
321
- # Load JSON response
322
  suggestions = json.loads(response)
323
 
324
- # Validate response structure
325
  if isinstance(suggestions, list):
326
- valid_suggestions = [
327
- s for s in suggestions if all(k in s for k in ["chart_type", "x_axis", "y_axis"])
328
- ]
329
  if valid_suggestions:
330
  return valid_suggestions
331
  else:
@@ -333,21 +347,19 @@ def ask_gpt4o_for_visualization(query, df, llm, retries=2):
333
  return None
334
 
335
  elif isinstance(suggestions, dict):
336
- if all(k in suggestions for k in ["chart_type", "x_axis", "y_axis"]):
337
  return [suggestions]
338
  else:
339
- st.warning("⚠️ GPT-4o's suggestion is incomplete.")
340
  return None
341
 
342
  except json.JSONDecodeError:
343
  st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned invalid JSON.")
344
  except Exception as e:
345
  st.error(f"⚠️ Error during GPT-4o call: {e}")
346
-
347
- # Retry if necessary
348
  if attempt < retries:
349
  st.info("πŸ”„ Retrying visualization suggestion...")
350
 
351
  st.error("❌ Failed to generate a valid visualization after multiple attempts.")
352
  return None
353
-
 
228
  numeric_columns = df.select_dtypes(include='number').columns.tolist()
229
  categorical_columns = df.select_dtypes(exclude='number').columns.tolist()
230
 
231
+ # Enhanced Prompt with Dataset-Specific, Query-Based Examples
232
  prompt = f"""
233
  Analyze the following query and suggest the most suitable visualization(s) using the dataset.
234
 
235
  **Query:** "{query}"
236
 
237
+ **Dataset Overview:**
238
+ - **Numeric Columns (for Y-axis):** {', '.join(numeric_columns) if numeric_columns else 'None'}
239
+ - **Categorical Columns (for X-axis or grouping):** {', '.join(categorical_columns) if categorical_columns else 'None'}
240
 
241
+ **Expected JSON Response:**
242
  [
243
  {{
244
  "chart_type": "bar/box/line/scatter/pie/heatmap",
 
250
  }}
251
  ]
252
 
253
+ **Query-Based Examples:**
254
+
255
+ - **Query:** "What is the salary distribution across different job titles?"
256
+ **Suggested Visualization:**
257
  {{
258
  "chart_type": "box",
259
  "x_axis": "job_title",
260
  "y_axis": "salary_in_usd",
261
  "group_by": "experience_level",
262
  "title": "Salary Distribution by Job Title and Experience",
263
+ "description": "A box plot to show how salaries vary across different job titles and experience levels."
264
  }}
265
 
266
+ - **Query:** "Show the average salary by company size and employment type."
267
+ **Suggested Visualizations:**
268
+ [
269
+ {{
270
+ "chart_type": "bar",
271
+ "x_axis": "company_size",
272
+ "y_axis": "salary_in_usd",
273
+ "group_by": "employment_type",
274
+ "title": "Average Salary by Company Size and Employment Type",
275
+ "description": "A grouped bar chart comparing average salaries across company sizes and employment types."
276
+ }},
277
+ {{
278
+ "chart_type": "heatmap",
279
+ "x_axis": "company_size",
280
+ "y_axis": "salary_in_usd",
281
+ "group_by": "employment_type",
282
+ "title": "Salary Heatmap by Company Size and Employment Type",
283
+ "description": "A heatmap showing salary concentration across company sizes and employment types."
284
+ }}
285
+ ]
286
+
287
+ - **Query:** "How has the average salary changed over the years?"
288
+ **Suggested Visualization:**
289
  {{
290
  "chart_type": "line",
291
+ "x_axis": "work_year",
292
+ "y_axis": "salary_in_usd",
293
+ "group_by": "experience_level",
294
+ "title": "Average Salary Trend Over Years",
295
+ "description": "A line chart showing how the average salary has changed across different experience levels over the years."
296
  }}
297
 
298
+ - **Query:** "What is the employee distribution by company location?"
299
+ **Suggested Visualization:**
300
  {{
301
  "chart_type": "pie",
302
+ "x_axis": "company_location",
303
  "y_axis": null,
304
  "group_by": null,
305
+ "title": "Employee Distribution by Company Location",
306
+ "description": "A pie chart showing the distribution of employees across company locations."
307
  }}
308
 
309
+ - **Query:** "Is there a relationship between remote work ratio and salary?"
310
+ **Suggested Visualization:**
311
  {{
312
  "chart_type": "scatter",
313
+ "x_axis": "remote_ratio",
314
  "y_axis": "salary_in_usd",
315
+ "group_by": "experience_level",
316
+ "title": "Remote Work Ratio vs Salary",
317
+ "description": "A scatter plot to analyze the relationship between remote work ratio and salary."
318
  }}
319
 
320
+ - **Query:** "Which job titles have the highest salaries across regions?"
321
+ **Suggested Visualization:**
322
  {{
323
  "chart_type": "heatmap",
324
+ "x_axis": "job_title",
325
+ "y_axis": "employee_residence",
326
  "group_by": null,
327
+ "title": "Salary Heatmap by Job Title and Region",
328
+ "description": "A heatmap showing the concentration of high-paying job titles across regions."
329
  }}
330
 
331
+ Only suggest visualizations that logically match the query and dataset.
332
  """
333
 
334
+ # Attempt LLM Response with Retry
335
  for attempt in range(retries + 1):
336
  try:
 
337
  response = llm.generate(prompt)
 
 
338
  suggestions = json.loads(response)
339
 
340
+ # Validate suggestions using helper
341
  if isinstance(suggestions, list):
342
+ valid_suggestions = [s for s in suggestions if is_valid_suggestion(s)]
 
 
343
  if valid_suggestions:
344
  return valid_suggestions
345
  else:
 
347
  return None
348
 
349
  elif isinstance(suggestions, dict):
350
+ if is_valid_suggestion(suggestions):
351
  return [suggestions]
352
  else:
353
+ st.warning("⚠️ GPT-4o's suggestion is incomplete or invalid.")
354
  return None
355
 
356
  except json.JSONDecodeError:
357
  st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned invalid JSON.")
358
  except Exception as e:
359
  st.error(f"⚠️ Error during GPT-4o call: {e}")
360
+
 
361
  if attempt < retries:
362
  st.info("πŸ”„ Retrying visualization suggestion...")
363
 
364
  st.error("❌ Failed to generate a valid visualization after multiple attempts.")
365
  return None