def ask_gpt4o_for_visualization(query, df, llm):
    """
    Ask the LLM which visualization(s) would best answer ``query`` for ``df``.

    Args:
        query: The user's natural-language question.
        df: DataFrame whose column names are listed in the prompt.
        llm: Object exposing ``generate(prompt)`` that returns the model's text reply.

    Returns:
        The parsed JSON reply. The prompt requests a list of dicts with
        "chart_type", "x_axis", "y_axis" and optional "group_by"; the model
        may also answer with a single dict. Returns ``None`` (after showing a
        Streamlit error) when the reply is not valid JSON.
    """
    import json
    import re

    # str() so non-string column labels (e.g. integers) do not break the join.
    columns = ', '.join(map(str, df.columns))
    prompt = f"""
    Analyze the query and suggest one or more relevant visualizations.
    Query: "{query}"
    Available Columns: {columns}
    Respond in this JSON format (as a list if multiple suggestions):
    [
      {{
        "chart_type": "bar/box/line/scatter",
        "x_axis": "column_name",
        "y_axis": "column_name",
        "group_by": "optional_column_name"
      }}
    ]
    """
    response = llm.generate(prompt)

    # Chat models frequently wrap JSON in a ```json ... ``` markdown fence,
    # which json.loads rejects; unwrap it before parsing.
    if isinstance(response, str):
        fenced = re.match(r"^\s*```(?:json)?\s*(.*?)\s*```\s*$", response, re.DOTALL | re.IGNORECASE)
        if fenced:
            response = fenced.group(1)

    try:
        return json.loads(response)
    except json.JSONDecodeError:
        st.error("⚠️ GPT-4o failed to generate a valid suggestion.")
        return None

def add_stats_to_figure(fig, df, y_axis, chart_type, value_prefix="$"):
    """
    Overlay summary statistics of ``df[y_axis]`` on a Plotly figure.

    Behaviour by chart type (case-insensitive):

    - ``"bar"`` / ``"line"``: a statistics box (min, max, average, median,
      standard deviation) to the right of the plot, plus horizontal
      reference lines at the min, median, average and max.
    - ``"scatter"``: the statistics box only.
    - ``"box"``: nothing; the box plot already shows the distribution.
    - ``"pie"`` / ``"heatmap"``: nothing; an informational message is shown.
    - anything else: nothing; a warning is shown.

    Args:
        fig: Plotly figure to annotate; it is modified in place.
        df: DataFrame containing the plotted data.
        y_axis: Column to summarise. When it is ``None`` or not a column of
            ``df`` the figure is returned untouched.
        chart_type: Chart type name, e.g. as suggested by GPT-4o.
        value_prefix: Text placed before every formatted value. Defaults to
            ``"$"`` (the original output); pass ``""`` for non-monetary columns.

    Returns:
        The same ``fig`` object.
    """
    kind = chart_type.lower() if isinstance(chart_type, str) else chart_type

    # Chart types that never receive overlays are resolved first, so e.g. a
    # heatmap with a categorical y-axis does not trigger a "non-numeric" warning.
    if kind == "box":
        # Box plots inherently show distribution; no extra stats needed
        return fig
    if kind == "pie":
        st.info("📊 Pie charts represent proportions. Additional stats are not applicable.")
        return fig
    if kind == "heatmap":
        st.info("📊 Heatmaps inherently reflect distribution. No additional stats added.")
        return fig
    if kind not in ("bar", "line", "scatter"):
        st.warning(f"⚠️ No statistical overlays applied for unsupported chart type: '{chart_type}'.")
        return fig

    # A suggestion may carry no (or an unknown) y-axis; there is nothing to summarise.
    if y_axis is None or y_axis not in df.columns:
        return fig

    column = df[y_axis]
    if not pd.api.types.is_numeric_dtype(column):
        st.warning(f"⚠️ Cannot compute statistics for non-numeric column: {y_axis}")
        return fig
    if column.dropna().empty:
        st.warning(f"⚠️ No values available to compute statistics for column: {y_axis}")
        return fig

    # pandas reductions skip NaN by default
    min_val = column.min()
    max_val = column.max()
    avg_val = column.mean()
    median_val = column.median()
    std_dev_val = column.std()  # sample std: NaN when only one value is present

    def fmt(value):
        return f"{value_prefix}{value:,.2f}"

    # Plotly annotation text supports a small HTML subset (<b>, <br>), not
    # Markdown, and does not treat newline characters as line breaks.
    stats_text = "<br>".join([
        "📊 <b>Statistics</b>",
        "",
        f"- <b>Min:</b> {fmt(min_val)}",
        f"- <b>Max:</b> {fmt(max_val)}",
        f"- <b>Average:</b> {fmt(avg_val)}",
        f"- <b>Median:</b> {fmt(median_val)}",
        f"- <b>Std Dev:</b> {fmt(std_dev_val)}",
    ])

    # Stats box just outside the right edge of the plotting area
    fig.add_annotation(
        text=stats_text,
        xref="paper", yref="paper",
        x=1.02, y=1,
        showarrow=False,
        align="left",
        font=dict(size=12, color="black"),
        bordercolor="gray",
        borderwidth=1,
        bgcolor="rgba(255, 255, 255, 0.85)"
    )

    if kind in ("bar", "line"):
        # Horizontal reference lines; scatter plots only get the box
        fig.add_hline(y=min_val, line_dash="dot", line_color="red", annotation_text="Min", annotation_position="bottom right")
        fig.add_hline(y=median_val, line_dash="dash", line_color="orange", annotation_text="Median", annotation_position="top right")
        fig.add_hline(y=avg_val, line_dash="dashdot", line_color="green", annotation_text="Avg", annotation_position="top right")
        fig.add_hline(y=max_val, line_dash="dot", line_color="blue", annotation_text="Max", annotation_position="top right")

    return fig


# Dynamically generate Plotly visualizations based on GPT-4o suggestions
def generate_visualization(suggestion, df):
    """
    Generate a Plotly figure from one GPT-4o visualization suggestion.

    Args:
        suggestion: Dict with "chart_type" and "x_axis", and optionally
            "y_axis", "group_by" and "title".
        df: DataFrame to plot.

    Returns:
        The figure with statistical overlays applied, or ``None`` if the
        suggestion cannot be plotted (a Streamlit warning/error says why).

    Details:
        - A missing or null "chart_type" defaults to "bar".
        - A missing "y_axis" is inferred: the first available of a few
          money-like columns, otherwise the first numeric column that is not
          the x-axis.
        - "pie" plots x as slice names and y as slice values; "heatmap" is
          drawn with ``px.density_heatmap``; any other type is looked up by
          name on ``plotly.express`` and called with ``x``/``y``.
        - "group_by" becomes the colour dimension when it is a column of
          ``df`` (not applied to heatmaps, whose function has no colour argument).
    """
    chart_type = str(suggestion.get("chart_type") or "bar").lower()
    x_axis = suggestion.get("x_axis")
    y_axis = suggestion.get("y_axis")
    group_by = suggestion.get("group_by")

    def label(name):
        # str() because column labels are not necessarily strings
        return str(name).replace('_', ' ').title()

    # Step 1: Infer Y-axis if not provided
    if not y_axis:
        # Avoid using the same column for both axes
        numeric_columns = [
            col for col in df.select_dtypes(include='number').columns if col != x_axis
        ]
        # Smart guess: prioritize salary or relevant metrics if available
        priority_columns = ["salary_in_usd", "income", "earnings", "revenue"]
        y_axis = next((col for col in priority_columns if col in numeric_columns), None)
        # Fallback to the first numeric column if no priority columns exist
        if y_axis is None and numeric_columns:
            y_axis = numeric_columns[0]

    # Step 2: Validate axes
    if not x_axis or not y_axis:
        st.warning("⚠️ Unable to determine appropriate columns for visualization.")
        return None
    missing = [col for col in (x_axis, y_axis) if isinstance(col, str) and col not in df.columns]
    if missing:
        st.warning(f"⚠️ Suggested column(s) not found in the data: {', '.join(missing)}")
        return None

    # Step 3: Map the chart type onto a plotly.express function and arguments.
    # px.pie takes names/values instead of x/y, and plotly.express has no
    # function literally called "heatmap".
    plot_args = {"data_frame": df}
    if chart_type == "pie":
        plotly_function = px.pie
        plot_args.update(names=x_axis, values=y_axis)
    elif chart_type == "heatmap":
        plotly_function = px.density_heatmap
        plot_args.update(x=x_axis, y=y_axis)
    else:
        # The name comes from model output: refuse private names, and the
        # callable() check below rejects non-function attributes such as submodules.
        plotly_function = None if chart_type.startswith("_") else getattr(px, chart_type, None)
        plot_args.update(x=x_axis, y=y_axis)
    if not callable(plotly_function):
        st.warning(f"⚠️ Unsupported chart type '{chart_type}' suggested by GPT-4o.")
        return None

    if group_by and group_by in df.columns and chart_type != "heatmap":
        plot_args["color"] = group_by

    x_label, y_label = label(x_axis), label(y_axis)
    title = suggestion.get("title") or f"{chart_type.title()} Plot of {y_label} by {x_label}"

    try:
        # Step 4: Generate the visualization
        fig = plotly_function(**plot_args)
        fig.update_layout(title=title, xaxis_title=x_label, yaxis_title=y_label)

        # Step 5: Apply chart-specific statistics
        return add_stats_to_figure(fig, df, y_axis, chart_type)

    except Exception as e:
        # UI boundary: surface any plotting failure instead of crashing the app
        st.error(f"⚠️ Failed to generate visualization: {e}")
        return None


def generate_multiple_visualizations(suggestions, df):
    """
    Generate one figure per GPT-4o suggestion.

    Args:
        suggestions: List of suggestion dicts (a single dict is also
            accepted). ``None`` or an empty list yields no figures.
        df: DataFrame to plot.

    Returns:
        List of the figures that could be generated (possibly empty).
        Suggestions that fail are skipped; ``generate_visualization`` reports
        the reason for each failure.
    """
    if not suggestions:
        return []
    if isinstance(suggestions, dict):
        suggestions = [suggestions]

    # generate_visualization already applies the statistical overlays itself;
    # applying them again here would duplicate the stats box and reference lines.
    # Retrying a failed suggestion is pointless: the same input fails the same way.
    visualizations = [
        fig for fig in (generate_visualization(suggestion, df) for suggestion in suggestions) if fig
    ]

    if not visualizations:
        st.warning("⚠️ No valid visualization found for any of the suggestions.")

    return visualizations


def handle_visualization_suggestions(suggestions, df):
    """
    Render the chart(s) described by ``suggestions`` in Streamlit.

    A list with more than one entry is delegated to
    ``generate_multiple_visualizations``. A single dict, or a one-element
    list, is rendered with ``generate_visualization``. Any other input
    (e.g. ``None`` or an empty list) produces no figures.

    A warning is shown when nothing could be generated. Otherwise every figure
    is drawn at the full container width.
    """
    is_list = isinstance(suggestions, list)

    if is_list and len(suggestions) > 1:
        figures = generate_multiple_visualizations(suggestions, df)
    elif isinstance(suggestions, dict) or (is_list and len(suggestions) == 1):
        chosen = suggestions if isinstance(suggestions, dict) else suggestions[0]
        single_figure = generate_visualization(chosen, df)
        figures = [single_figure] if single_figure else []
    else:
        figures = []

    if not figures:
        st.warning("⚠️ Unable to generate any visualization based on the suggestion.")

    for figure in figures:
        st.plotly_chart(figure, use_container_width=True)





# -----------------

def ask_gpt4o_for_visualization(query, df, llm, retries=2):
    """
    Ask the LLM for chart suggestions, retrying when the reply cannot be parsed.

    Args:
        query: The user's natural-language question.
        df: DataFrame whose numeric / non-numeric column names are listed in
            the prompt.
        llm: Object exposing ``generate(prompt)`` that returns the model's text reply.
        retries: Extra attempts after the first one. A retry happens when the
            reply is not valid JSON, is neither a JSON list nor a JSON object,
            or the call raises.

    Returns:
        A non-empty list of suggestion dicts, each containing at least the
        keys "chart_type", "x_axis" and "y_axis" (values may be null).
        Returns ``None`` when a parsed reply holds no complete suggestion, or
        when every attempt failed; a Streamlit message explains which.
    """
    import json
    import re

    required_keys = ("chart_type", "x_axis", "y_axis")

    def is_complete(item):
        # isinstance guard: `key in item` on a string would be a substring test
        return isinstance(item, dict) and all(key in item for key in required_keys)

    # Identify numeric and categorical columns
    numeric_columns = df.select_dtypes(include='number').columns.tolist()
    categorical_columns = df.select_dtypes(exclude='number').columns.tolist()
    # str() so non-string column labels do not break the join
    numeric_list = ', '.join(map(str, numeric_columns)) if numeric_columns else 'None'
    categorical_list = ', '.join(map(str, categorical_columns)) if categorical_columns else 'None'

    # Enhanced Prompt with More Examples
    prompt = f"""
    Analyze the following query and suggest the most suitable visualization(s) using the dataset.

    **Query:** "{query}"

    **Numeric Columns (for Y-axis):** {numeric_list}
    **Categorical Columns (for X-axis or grouping):** {categorical_list}

    Suggest visualizations in this exact JSON format:
    [
      {{
        "chart_type": "bar/box/line/scatter/pie/heatmap",
        "x_axis": "categorical_or_time_column",
        "y_axis": "numeric_column",
        "group_by": "optional_column_for_grouping",
        "title": "Title of the chart",
        "description": "Why this chart is suitable"
      }}
    ]

    **Examples:**  
    - For salary distribution:  
      {{
        "chart_type": "box",
        "x_axis": "job_title",
        "y_axis": "salary_in_usd",
        "group_by": "experience_level",
        "title": "Salary Distribution by Job Title and Experience",
        "description": "A box plot showing salary ranges across job titles and experience levels."
      }}

    - For company size comparison:  
      {{
        "chart_type": "bar",
        "x_axis": "company_size",
        "y_axis": "salary_in_usd",
        "group_by": null,
        "title": "Average Salary by Company Size",
        "description": "A bar chart comparing the average salaries across different company sizes."
      }}

    - For revenue trends over time:  
      {{
        "chart_type": "line",
        "x_axis": "year",
        "y_axis": "revenue",
        "group_by": null,
        "title": "Revenue Growth Over Years",
        "description": "A line chart showing the trend of revenue over the years."
      }}

    - For market share breakdown:  
      {{
        "chart_type": "pie",
        "x_axis": "market_segment",
        "y_axis": null,
        "group_by": null,
        "title": "Market Share by Segment",
        "description": "A pie chart showing the distribution of market share across various segments."
      }}

    - For correlation analysis:  
      {{
        "chart_type": "scatter",
        "x_axis": "years_of_experience",
        "y_axis": "salary_in_usd",
        "group_by": "job_title",
        "title": "Experience vs Salary by Job Title",
        "description": "A scatter plot showing the relationship between years of experience and salary across job titles."
      }}

    - For data density:  
      {{
        "chart_type": "heatmap",
        "x_axis": "department",
        "y_axis": "region",
        "group_by": null,
        "title": "Employee Distribution by Department and Region",
        "description": "A heatmap showing the concentration of employees across departments and regions."
      }}

    Only suggest visualizations that make sense for the data and the query.
    Return only the JSON array, without markdown code fences or any other text.
    """

    for attempt in range(retries + 1):
        try:
            # Generate response from the model
            response = llm.generate(prompt)

            # Chat models frequently wrap JSON in a ```json ... ``` fence,
            # which json.loads rejects; unwrap it before parsing.
            if isinstance(response, str):
                fenced = re.match(r"^\s*```(?:json)?\s*(.*?)\s*```\s*$", response, re.DOTALL | re.IGNORECASE)
                if fenced:
                    response = fenced.group(1)

            suggestions = json.loads(response)

            # Validate response structure
            if isinstance(suggestions, list):
                valid_suggestions = [s for s in suggestions if is_complete(s)]
                if valid_suggestions:
                    return valid_suggestions
                st.warning("⚠️ GPT-4o did not suggest valid visualizations.")
                return None

            if isinstance(suggestions, dict):
                if is_complete(suggestions):
                    return [suggestions]
                st.warning("⚠️ GPT-4o's suggestion is incomplete.")
                return None

            # Any other JSON value (string, number, null): report it and retry
            st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned an unexpected JSON structure.")

        except json.JSONDecodeError:
            st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned invalid JSON.")
        except Exception as e:
            st.error(f"⚠️ Error during GPT-4o call: {e}")

        # Retry if necessary
        if attempt < retries:
            st.info("🔄 Retrying visualization suggestion...")

    st.error("❌ Failed to generate a valid visualization after multiple attempts.")
    return None