DrishtiSharma committed on
Commit
e4ab33c
·
verified ·
1 Parent(s): 40b3f9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -19
app.py CHANGED
@@ -110,6 +110,25 @@ if st.session_state.df is not None and st.session_state.show_preview:
110
  # st.error("⚠️ GPT-4o failed to generate a valid suggestion.")
111
  # return None
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def ask_gpt4o_for_visualization(query, df, llm, retries=2):
114
  import json
115
 
@@ -117,14 +136,15 @@ def ask_gpt4o_for_visualization(query, df, llm, retries=2):
117
  numeric_columns = df.select_dtypes(include='number').columns.tolist()
118
  categorical_columns = df.select_dtypes(exclude='number').columns.tolist()
119
 
120
- # Enhanced Prompt with Clear Instructions
121
  prompt = f"""
122
  Analyze the following query and suggest the most suitable visualization(s) using the dataset.
123
 
124
  **Query:** "{query}"
125
 
126
- **Numeric Columns (for Y-axis):** {', '.join(numeric_columns) if numeric_columns else 'None'}
127
- **Categorical Columns (for X-axis or grouping):** {', '.join(categorical_columns) if categorical_columns else 'None'}
 
128
 
129
  Suggest visualizations in this exact JSON format:
130
  [
@@ -138,28 +158,85 @@ def ask_gpt4o_for_visualization(query, df, llm, retries=2):
138
  }}
139
  ]
140
 
141
- **Examples:**
142
- - For salary distribution:
 
 
143
  {{
144
  "chart_type": "box",
145
  "x_axis": "job_title",
146
  "y_axis": "salary_in_usd",
147
  "group_by": "experience_level",
148
  "title": "Salary Distribution by Job Title and Experience",
149
- "description": "A box plot showing salary ranges across job titles and experience levels."
150
  }}
151
 
152
- - For trend analysis:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  {{
154
  "chart_type": "line",
155
  "x_axis": "year",
156
  "y_axis": "revenue",
157
  "group_by": null,
158
- "title": "Revenue Growth Over Years",
159
- "description": "A line chart showing the trend of revenue over the years."
160
  }}
161
 
162
- Only suggest visualizations that make sense for the data and the query.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  """
164
 
165
  for attempt in range(retries + 1):
@@ -170,11 +247,9 @@ def ask_gpt4o_for_visualization(query, df, llm, retries=2):
170
  # Load JSON response
171
  suggestions = json.loads(response)
172
 
173
- # Validate response structure
174
  if isinstance(suggestions, list):
175
- valid_suggestions = [
176
- s for s in suggestions if all(k in s for k in ["chart_type", "x_axis", "y_axis"])
177
- ]
178
  if valid_suggestions:
179
  return valid_suggestions
180
  else:
@@ -182,18 +257,17 @@ def ask_gpt4o_for_visualization(query, df, llm, retries=2):
182
  return None
183
 
184
  elif isinstance(suggestions, dict):
185
- if all(k in suggestions for k in ["chart_type", "x_axis", "y_axis"]):
186
  return [suggestions]
187
  else:
188
- st.warning("⚠️ GPT-4o's suggestion is incomplete.")
189
  return None
190
 
191
  except json.JSONDecodeError:
192
  st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned invalid JSON.")
193
  except Exception as e:
194
  st.error(f"⚠️ Error during GPT-4o call: {e}")
195
-
196
- # Retry if necessary
197
  if attempt < retries:
198
  st.info("πŸ”„ Retrying visualization suggestion...")
199
 
@@ -201,7 +275,6 @@ def ask_gpt4o_for_visualization(query, df, llm, retries=2):
201
  return None
202
 
203
 
204
-
205
  def add_stats_to_figure(fig, df, y_axis, chart_type):
206
  """
207
  Add relevant statistical annotations to the visualization
 
110
  # st.error("⚠️ GPT-4o failed to generate a valid suggestion.")
111
  # return None
112
 
113
+
114
+
115
+
116
+ # Helper Function for Validation
117
def is_valid_suggestion(suggestion):
    """Check whether a chart suggestion has a supported type and its required keys.

    A suggestion is valid when it is a dict whose ``chart_type`` is a string
    naming a supported chart (compared case-insensitively) and every key
    required for that chart is present.  Only key *presence* is checked; the
    values stored under ``x_axis`` / ``y_axis`` are not inspected.

    Args:
        suggestion: One parsed suggestion, normally a dict decoded from JSON.

    Returns:
        True if the suggestion is usable, False otherwise.  Malformed input
        (a non-dict value, or a missing / null / non-string ``chart_type``)
        yields False instead of raising.
    """
    # Parsed JSON list elements may be strings, numbers, or None rather than
    # objects; none of those can describe a chart.
    if not isinstance(suggestion, dict):
        return False

    # A JSON ``null`` decodes to None, which has no ``.lower()``; reject any
    # non-string chart type before normalizing it.
    chart_type = suggestion.get("chart_type")
    if not isinstance(chart_type, str):
        return False
    chart_type = chart_type.lower()

    if chart_type == "pie":
        # Pie charts only need the category column.
        required_keys = ("chart_type", "x_axis")
    elif chart_type in ("bar", "line", "box", "scatter", "heatmap"):
        required_keys = ("chart_type", "x_axis", "y_axis")
    else:
        # Unknown or empty chart type.
        return False

    return all(key in suggestion for key in required_keys)
131
+
132
  def ask_gpt4o_for_visualization(query, df, llm, retries=2):
133
  import json
134
 
 
136
  numeric_columns = df.select_dtypes(include='number').columns.tolist()
137
  categorical_columns = df.select_dtypes(exclude='number').columns.tolist()
138
 
139
+ # Enhanced Prompt with Diverse, Query-Based Examples
140
  prompt = f"""
141
  Analyze the following query and suggest the most suitable visualization(s) using the dataset.
142
 
143
  **Query:** "{query}"
144
 
145
+ **Dataset Overview:**
146
+ - **Numeric Columns (for Y-axis):** {', '.join(numeric_columns) if numeric_columns else 'None'}
147
+ - **Categorical Columns (for X-axis or grouping):** {', '.join(categorical_columns) if categorical_columns else 'None'}
148
 
149
  Suggest visualizations in this exact JSON format:
150
  [
 
158
  }}
159
  ]
160
 
161
+ **Query-Based Examples:**
162
+
163
+ - **Query:** "What is the salary distribution across different job titles?"
164
+ **Suggested Visualization:**
165
  {{
166
  "chart_type": "box",
167
  "x_axis": "job_title",
168
  "y_axis": "salary_in_usd",
169
  "group_by": "experience_level",
170
  "title": "Salary Distribution by Job Title and Experience",
171
+ "description": "A box plot to show how salaries vary across different job titles and experience levels."
172
  }}
173
 
174
+ - **Query:** "Show the average salary by company size and industry."
175
+ **Suggested Visualizations:**
176
+ [
177
+ {{
178
+ "chart_type": "bar",
179
+ "x_axis": "company_size",
180
+ "y_axis": "salary_in_usd",
181
+ "group_by": "industry",
182
+ "title": "Average Salary by Company Size and Industry",
183
+ "description": "A grouped bar chart comparing average salaries across company sizes and industries."
184
+ }},
185
+ {{
186
+ "chart_type": "heatmap",
187
+ "x_axis": "industry",
188
+ "y_axis": "company_size",
189
+ "group_by": null,
190
+ "title": "Salary Heatmap by Industry and Company Size",
191
+ "description": "A heatmap showing salary concentration across industries and company sizes."
192
+ }}
193
+ ]
194
+
195
+ - **Query:** "How has the company's revenue changed over the years?"
196
+ **Suggested Visualization:**
197
  {{
198
  "chart_type": "line",
199
  "x_axis": "year",
200
  "y_axis": "revenue",
201
  "group_by": null,
202
+ "title": "Yearly Revenue Growth",
203
+ "description": "A line chart showing revenue growth over time."
204
  }}
205
 
206
+ - **Query:** "What is the market share of each product category?"
207
+ **Suggested Visualization:**
208
+ {{
209
+ "chart_type": "pie",
210
+ "x_axis": "product_category",
211
+ "y_axis": null,
212
+ "group_by": null,
213
+ "title": "Market Share by Product Category",
214
+ "description": "A pie chart to show the market share distribution across different product categories."
215
+ }}
216
+
217
+ - **Query:** "Is there a correlation between years of experience and salary?"
218
+ **Suggested Visualization:**
219
+ {{
220
+ "chart_type": "scatter",
221
+ "x_axis": "years_of_experience",
222
+ "y_axis": "salary_in_usd",
223
+ "group_by": "job_title",
224
+ "title": "Experience vs Salary by Job Title",
225
+ "description": "A scatter plot to analyze the relationship between experience and salary across different job titles."
226
+ }}
227
+
228
+ - **Query:** "Which departments have the highest concentration of employees across regions?"
229
+ **Suggested Visualization:**
230
+ {{
231
+ "chart_type": "heatmap",
232
+ "x_axis": "department",
233
+ "y_axis": "region",
234
+ "group_by": null,
235
+ "title": "Employee Distribution by Department and Region",
236
+ "description": "A heatmap to visualize employee density across departments and regions."
237
+ }}
238
+
239
+ Only suggest visualizations that logically match the query and dataset.
240
  """
241
 
242
  for attempt in range(retries + 1):
 
247
  # Load JSON response
248
  suggestions = json.loads(response)
249
 
250
+ # Validate response structure using the helper function
251
  if isinstance(suggestions, list):
252
+ valid_suggestions = [s for s in suggestions if is_valid_suggestion(s)]
 
 
253
  if valid_suggestions:
254
  return valid_suggestions
255
  else:
 
257
  return None
258
 
259
  elif isinstance(suggestions, dict):
260
+ if is_valid_suggestion(suggestions):
261
  return [suggestions]
262
  else:
263
+ st.warning("⚠️ GPT-4o's suggestion is incomplete or invalid.")
264
  return None
265
 
266
  except json.JSONDecodeError:
267
  st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned invalid JSON.")
268
  except Exception as e:
269
  st.error(f"⚠️ Error during GPT-4o call: {e}")
270
+
 
271
  if attempt < retries:
272
  st.info("πŸ”„ Retrying visualization suggestion...")
273
 
 
275
  return None
276
 
277
 
 
278
  def add_stats_to_figure(fig, df, y_axis, chart_type):
279
  """
280
  Add relevant statistical annotations to the visualization