# NOTE: this module relies on names expected at file level and not visible in
# this section: `st` (streamlit), `px` (plotly.express) and `pd` (pandas).

# Keys every suggestion must carry to be considered usable.
_REQUIRED_SUGGESTION_KEYS = ("chart_type", "x_axis", "y_axis")

# Preferred numeric columns when the Y-axis has to be guessed, in priority order.
_PRIORITY_Y_COLUMNS = ("salary_in_usd", "income", "earnings", "revenue")

# Chart types that get statistical overlays.
_STATS_CHART_TYPES = ("bar", "line", "scatter")

# chart_type -> (plotly.express function name, keyword for the x column,
# keyword for the y column, whether the function accepts ``color``).
# Chart types not listed here fall back to a plotly.express function of the
# same name called with ``x``/``y``/``color``.
_CHART_SPECS = {
    "bar": ("bar", "x", "y", True),
    "box": ("box", "x", "y", True),
    "line": ("line", "x", "y", True),
    "scatter": ("scatter", "x", "y", True),
    # px.pie has no x/y; it takes names/values instead.
    "pie": ("pie", "names", "values", True),
    # plotly.express has no ``heatmap``; density_heatmap takes x/y but no color.
    "heatmap": ("density_heatmap", "x", "y", False),
}


def _build_visualization_prompt(query, df):
    """Return the prompt asking the model for chart suggestions for *query*.

    Columns of *df* are split into numeric and non-numeric groups so the
    model can be told which ones suit the Y-axis and which suit the X-axis
    or grouping.
    """
    numeric_columns = [str(c) for c in df.select_dtypes(include='number').columns]
    categorical_columns = [str(c) for c in df.select_dtypes(exclude='number').columns]

    return f"""
    Analyze the following query and suggest the most suitable visualization(s) using the dataset.

    **Query:** "{query}"

    **Numeric Columns (for Y-axis):** {', '.join(numeric_columns) if numeric_columns else 'None'}
    **Categorical Columns (for X-axis or grouping):** {', '.join(categorical_columns) if categorical_columns else 'None'}

    Suggest visualizations in this exact JSON format:
    [
        {{
            "chart_type": "bar/box/line/scatter/pie/heatmap",
            "x_axis": "categorical_or_time_column",
            "y_axis": "numeric_column",
            "group_by": "optional_column_for_grouping",
            "title": "Title of the chart",
            "description": "Why this chart is suitable"
        }}
    ]

    **Examples:**
    - For salary distribution:
    {{
        "chart_type": "box",
        "x_axis": "job_title",
        "y_axis": "salary_in_usd",
        "group_by": "experience_level",
        "title": "Salary Distribution by Job Title and Experience",
        "description": "A box plot showing salary ranges across job titles and experience levels."
    }}
    - For company size comparison:
    {{
        "chart_type": "bar",
        "x_axis": "company_size",
        "y_axis": "salary_in_usd",
        "group_by": null,
        "title": "Average Salary by Company Size",
        "description": "A bar chart comparing the average salaries across different company sizes."
    }}
    - For revenue trends over time:
    {{
        "chart_type": "line",
        "x_axis": "year",
        "y_axis": "revenue",
        "group_by": null,
        "title": "Revenue Growth Over Years",
        "description": "A line chart showing the trend of revenue over the years."
    }}
    - For market share breakdown:
    {{
        "chart_type": "pie",
        "x_axis": "market_segment",
        "y_axis": null,
        "group_by": null,
        "title": "Market Share by Segment",
        "description": "A pie chart showing the distribution of market share across various segments."
    }}
    - For correlation analysis:
    {{
        "chart_type": "scatter",
        "x_axis": "years_of_experience",
        "y_axis": "salary_in_usd",
        "group_by": "job_title",
        "title": "Experience vs Salary by Job Title",
        "description": "A scatter plot showing the relationship between years of experience and salary across job titles."
    }}
    - For data density:
    {{
        "chart_type": "heatmap",
        "x_axis": "department",
        "y_axis": "region",
        "group_by": null,
        "title": "Employee Distribution by Department and Region",
        "description": "A heatmap showing the concentration of employees across departments and regions."
    }}

    Only suggest visualizations that make sense for the data and the query.
    """


def _parse_json_payload(raw):
    """Parse *raw* as JSON, tolerating code fences or prose around the payload.

    The text is first parsed as-is. If that fails, the span from the first
    ``[`` or ``{`` to the last ``]`` or ``}`` is parsed instead, which covers
    replies wrapped in a Markdown code fence.

    Raises:
        json.JSONDecodeError: if no parseable JSON payload is found.
    """
    # Local import: the file-level import block is not visible from here.
    import json

    text = raw.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        starts = [i for i in (text.find("["), text.find("{")) if i != -1]
        end = max(text.rfind("]"), text.rfind("}"))
        if not starts or end <= min(starts):
            raise
        return json.loads(text[min(starts):end + 1])


def _is_valid_suggestion(candidate):
    """Return True if *candidate* is a dict carrying every required key."""
    return isinstance(candidate, dict) and all(
        key in candidate for key in _REQUIRED_SUGGESTION_KEYS
    )


def ask_gpt4o_for_visualization(query, df, llm, retries=2):
    """Ask the model for chart suggestions matching *query* and validate them.

    Args:
        query: The user's natural-language question.
        df: DataFrame whose columns are offered to the model.
        llm: Object exposing ``generate(prompt)``; the return value is parsed
            as JSON text (assumed to be a plain string — TODO confirm against
            the caller's client).
        retries: Number of additional attempts after the first one when the
            call fails or the reply is not usable JSON.

    Returns:
        A list of suggestion dicts (a single dict reply is wrapped in a list),
        or ``None`` when the reply is structurally invalid or every attempt
        failed. Failures are reported through Streamlit messages.
    """
    # Local import: the file-level import block is not visible from here.
    import json

    prompt = _build_visualization_prompt(query, df)

    for attempt in range(retries + 1):
        try:
            suggestions = _parse_json_payload(llm.generate(prompt))

            if isinstance(suggestions, list):
                valid_suggestions = [s for s in suggestions if _is_valid_suggestion(s)]
                if valid_suggestions:
                    return valid_suggestions
                st.warning("⚠️ GPT-4o did not suggest valid visualizations.")
                return None

            if isinstance(suggestions, dict):
                if _is_valid_suggestion(suggestions):
                    return [suggestions]
                st.warning("⚠️ GPT-4o's suggestion is incomplete.")
                return None

            # Valid JSON that is neither a list nor an object: say so, then retry.
            st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned an unexpected JSON structure.")
        except json.JSONDecodeError:
            st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned invalid JSON.")
        except Exception as e:  # UI boundary: report the failure instead of crashing the app.
            st.error(f"⚠️ Error during GPT-4o call: {e}")

        if attempt < retries:
            st.info("🔄 Retrying visualization suggestion...")

    st.error("❌ Failed to generate a valid visualization after multiple attempts.")
    return None


def _add_stats_annotation(fig, stats_text):
    """Attach *stats_text* as a boxed annotation just right of the plot area."""
    fig.add_annotation(
        text=stats_text,
        xref="paper",
        yref="paper",
        x=1.02,
        y=1,
        showarrow=False,
        align="left",
        font=dict(size=12, color="black"),
        bordercolor="gray",
        borderwidth=1,
        bgcolor="rgba(255, 255, 255, 0.85)",
    )


def add_stats_to_figure(fig, df, y_axis, chart_type):
    """Add statistical annotations to *fig* where they suit the chart type.

    Bar and line charts get a statistics box plus horizontal reference lines
    for min, median, average and max. Scatter plots get the box only. Box
    plots, pie charts and heatmaps are returned unchanged (the latter two
    with an informational message), as is any other chart type (with a
    warning).

    Args:
        fig: The Plotly figure to annotate; it is modified in place.
        df: DataFrame holding the plotted data.
        y_axis: Name of the column the statistics are computed from.
        chart_type: Chart type name; compared case-insensitively.

    Returns:
        The same figure object that was passed in.
    """
    chart_type = str(chart_type or "").lower()

    # Decide on the chart type first so charts that never receive stats do
    # not trigger a misleading "non-numeric column" warning.
    if chart_type == "box":
        # Box plots inherently show distribution; no extra stats needed.
        return fig
    if chart_type == "pie":
        st.info("📊 Pie charts represent proportions. Additional stats are not applicable.")
        return fig
    if chart_type == "heatmap":
        st.info("📊 Heatmaps inherently reflect distribution. No additional stats added.")
        return fig
    if chart_type not in _STATS_CHART_TYPES:
        st.warning(f"⚠️ No statistical overlays applied for unsupported chart type: '{chart_type}'.")
        return fig

    if y_axis not in df.columns:
        st.warning(f"⚠️ Cannot compute statistics: column '{y_axis}' not found.")
        return fig
    if not pd.api.types.is_numeric_dtype(df[y_axis]):
        st.warning(f"⚠️ Cannot compute statistics for non-numeric column: {y_axis}")
        return fig

    values = df[y_axis]
    if values.dropna().empty:
        # No data points: every statistic would be NaN.
        return fig

    min_val = values.min()
    max_val = values.max()
    avg_val = values.mean()
    median_val = values.median()
    std_dev_val = values.std()

    # Plotly annotations render a small HTML subset (<b>, <br>), not Markdown
    # or "\n". The "$" prefix presumes currency values — TODO confirm.
    stats_text = (
        f"📊 <b>Statistics</b><br><br>"
        f"- <b>Min:</b> ${min_val:,.2f}<br>"
        f"- <b>Max:</b> ${max_val:,.2f}<br>"
        f"- <b>Average:</b> ${avg_val:,.2f}<br>"
        f"- <b>Median:</b> ${median_val:,.2f}<br>"
        f"- <b>Std Dev:</b> ${std_dev_val:,.2f}"
    )
    _add_stats_annotation(fig, stats_text)

    if chart_type in ("bar", "line"):
        # (value, dash style, colour, label, label position)
        reference_lines = (
            (min_val, "dot", "red", "Min", "bottom right"),
            (median_val, "dash", "orange", "Median", "top right"),
            (avg_val, "dashdot", "green", "Avg", "top right"),
            (max_val, "dot", "blue", "Max", "top right"),
        )
        for value, dash, color, label, position in reference_lines:
            fig.add_hline(
                y=value,
                line_dash=dash,
                line_color=color,
                annotation_text=label,
                annotation_position=position,
            )

    return fig


def _infer_y_axis(df, x_axis):
    """Pick a numeric column of *df* for the Y-axis, or None if there is none.

    The X-axis column is excluded. Columns named in ``_PRIORITY_Y_COLUMNS``
    win over the first remaining numeric column.
    """
    numeric_columns = [
        col for col in df.select_dtypes(include='number').columns if col != x_axis
    ]
    fallback = numeric_columns[0] if numeric_columns else None
    return next((col for col in _PRIORITY_Y_COLUMNS if col in numeric_columns), fallback)


# Dynamically generate Plotly visualizations based on GPT-4o suggestions
def generate_visualization(suggestion, df):
    """Build a Plotly figure from one suggestion dict.

    A missing Y-axis is inferred from the numeric columns of *df*. Statistical
    overlays are applied through ``add_stats_to_figure``.

    Args:
        suggestion: Dict with ``chart_type``, ``x_axis``, ``y_axis`` and
            optionally ``group_by`` and ``title``.
        df: DataFrame to plot.

    Returns:
        The figure, or ``None`` when the axes cannot be determined, the chart
        type is unsupported, or Plotly raises. Each failure is reported via
        Streamlit.
    """
    # ``or "bar"`` also covers an explicit null chart_type from the model.
    chart_type = str(suggestion.get("chart_type") or "bar").strip().lower()
    x_axis = suggestion.get("x_axis")
    y_axis = suggestion.get("y_axis") or _infer_y_axis(df, x_axis)
    group_by = suggestion.get("group_by")

    if not x_axis or not y_axis:
        st.warning("⚠️ Unable to determine appropriate columns for visualization.")
        return None

    func_name, x_kw, y_kw, supports_color = _CHART_SPECS.get(
        chart_type, (chart_type, "x", "y", True)
    )
    # The chart type comes from the model: refuse private attributes and
    # anything on plotly.express that is not callable.
    plotly_function = None if func_name.startswith("_") else getattr(px, func_name, None)
    if not callable(plotly_function):
        st.warning(f"⚠️ Unsupported chart type '{chart_type}' suggested by GPT-4o.")
        return None

    plot_args = {"data_frame": df, x_kw: x_axis, y_kw: y_axis}
    if supports_color and group_by and group_by in df.columns:
        plot_args["color"] = group_by

    try:
        fig = plotly_function(**plot_args)

        pretty_x = str(x_axis).replace('_', ' ').title()
        pretty_y = str(y_axis).replace('_', ' ').title()
        # Prefer the model-provided title when the suggestion carries one.
        title = suggestion.get("title") or f"{chart_type.title()} Plot of {pretty_y} by {pretty_x}"
        fig.update_layout(title=title, xaxis_title=pretty_x, yaxis_title=pretty_y)

        return add_stats_to_figure(fig, df, y_axis, chart_type)
    except Exception as e:  # UI boundary: surface Plotly errors instead of crashing the app.
        st.error(f"⚠️ Failed to generate visualization: {e}")
        return None


def generate_multiple_visualizations(suggestions, df):
    """Generate a figure for every suggestion that can be rendered.

    Statistics are applied inside ``generate_visualization`` and are not added
    again here. Suggestions that fail to render are skipped; they are not
    retried, because re-running the same suggestion would fail the same way.

    Args:
        suggestions: Iterable of suggestion dicts; ``None`` is treated as empty.
        df: DataFrame to plot.

    Returns:
        A list of figures, possibly empty.
    """
    figures = []
    for suggestion in suggestions or []:
        if not isinstance(suggestion, dict):
            continue
        fig = generate_visualization(suggestion, df)
        if fig is not None:
            figures.append(fig)
    return figures


def handle_visualization_suggestions(suggestions, df):
    """Render every chart that can be built from *suggestions* in Streamlit.

    Args:
        suggestions: A single suggestion dict, a list of them, or anything
            else (e.g. ``None``), which is treated as "no suggestions".
        df: DataFrame to plot.
    """
    if isinstance(suggestions, dict):
        suggestions = [suggestions]
    elif not isinstance(suggestions, list):
        suggestions = []

    visualizations = generate_multiple_visualizations(suggestions, df)

    if not visualizations:
        st.warning("⚠️ Unable to generate any visualization based on the suggestion.")

    for fig in visualizations:
        st.plotly_chart(fig, use_container_width=True)