import pandas as pd
import plotly.express as px
import streamlit as st
from pandas.io.formats.style import Styler

from utils import get_leaderboard, get_model_ranks


def header(title: str) -> None:
    """Render the page title, a short EnFoBench introduction and a divider.

    Args:
        title: Text shown as the page title.
    """
    introduction = """
    [EnFoBench](https://github.com/attila-balint-kul/energy-forecast-benchmark-toolkit) 
    is a community driven benchmarking framework for energy forecasting models. 
    """
    st.title(title)
    st.markdown(introduction)
    st.divider()


def logos() -> None:
    """Show the KU Leuven and EnergyVille logos side by side in two columns."""
    logo_paths = (
        "./images/ku_leuven_logo.png",
        "./images/energyville_logo.png",
    )
    for column, logo_path in zip(st.columns(2), logo_paths):
        with column:
            st.image(logo_path)


def links(current: str) -> None:
    """Render the "Sources" link buttons and links to the sibling dashboards.

    Args:
        current: Identifier of the dashboard being rendered
            (``"ElectricityDemand"``, ``"GasDemand"`` or ``"PVGeneration"``);
            the matching entry is left out of the "Other Dashboards" list.
    """
    source_links = (
        (
            "GitHub Repository",
            "https://github.com/attila-balint-kul/energy-forecast-benchmark-toolkit",
        ),
        (
            "Documentation",
            "https://attila-balint-kul.github.io/energy-forecast-benchmark-toolkit/",
        ),
        (
            "Electricity Demand Dataset",
            "https://huggingface.co/datasets/EDS-lab/electricity-demand",
        ),
        ("HuggingFace Organization", "https://huggingface.co/EDS-lab"),
    )
    # (space identifier, button label); the identifier doubles as the Space URL suffix.
    dashboards = (
        ("ElectricityDemand", "Electricity Demand"),
        ("GasDemand", "Gas Demand"),
        ("PVGeneration", "PVGeneration"),
    )

    st.header("Sources")
    for label, url in source_links:
        st.link_button(label, url=url, use_container_width=True)

    st.header("Other Dashboards")
    for space_id, label in dashboards:
        if space_id == current:
            continue
        st.link_button(
            label,
            url=f"https://huggingface.co/spaces/EDS-lab/EnFoBench-{space_id}",
            use_container_width=True,
        )


def _set_model_selection(models: list[str], selected: set[str]) -> None:
    """Tick exactly the model checkboxes whose key is in *selected*."""
    for model in models:
        st.session_state[model] = model in selected


def _best_models(data: pd.DataFrame, metric: str, n: int = 10) -> set[str]:
    """Return the first *n* models of ``get_model_ranks(data, metric)``."""
    return set(get_model_ranks(data, metric).head(n).model.tolist())


def model_selector(models: list[str], data: pd.DataFrame) -> set[str]:
    """Render one checkbox per model plus quick-selection buttons.

    Models are grouped by the prefix before their first ``"."`` (e.g. the
    group of ``"statsforecast.naive"`` is ``"statsforecast"``). The checkbox
    key of each model is its full identifier, so the buttons can tick/untick
    boxes through ``st.session_state``.

    Args:
        models: Model identifiers of the form ``"<group>.<name>"``; an
            identifier without a ``"."`` raises ``ValueError``.
        data: Benchmark results passed to ``get_model_ranks`` for the
            "Best by …" buttons.

    Returns:
        The identifiers of the models whose checkbox is ticked.
    """
    model_groups: dict[str, list[str]] = {}
    for model in models:
        group, model_name = model.split(".", maxsplit=1)
        model_groups.setdefault(group, []).append(model_name)

    st.header("Models to include")

    rank_buttons = (
        ("Best by MAE", "MAE.mean"),
        ("Best by RMSE", "RMSE.mean"),
        ("Best by rMAE", "rMAE.mean"),
    )
    for column, (label, metric) in zip(st.columns(len(rank_buttons)), rank_buttons):
        with column:
            if st.button(label, use_container_width=True):
                _set_model_selection(models, _best_models(data, metric))

    left, right = st.columns(2)
    with left:
        if st.button("Select None", use_container_width=True):
            _set_model_selection(models, set())
    with right:
        if st.button("Select All", use_container_width=True):
            _set_model_selection(models, set(models))

    models_to_plot: set[str] = set()
    for model_group, model_names in model_groups.items():
        st.text(model_group)
        for model_name in model_names:
            key = f"{model_group}.{model_name}"
            # Seed the "selected by default" state through Session State rather
            # than `value=True`: giving the widget a default while the buttons
            # above also write the same key makes Streamlit show a
            # "created with a default value but also had its value set via the
            # Session State API" warning.
            if key not in st.session_state:
                st.session_state[key] = True
            if st.checkbox(model_name, key=key):
                models_to_plot.add(key)
    return models_to_plot


def overview_view(data: pd.DataFrame):
    """Render the leaderboard: three top-10 bar charts and the full table.

    For each of ``MAE.mean``, ``RMSE.mean`` and ``rMAE.mean`` the ten rows of
    the leaderboard with the largest value are plotted as a horizontal bar
    chart (re-sorted ascending so plotly draws the largest bar at the top).

    Args:
        data: Benchmark results handed to ``get_leaderboard``.
    """
    st.markdown("## Leaderboard")

    leaderboard = get_leaderboard(data, ["MAE.mean", "RMSE.mean", "rMAE.mean"])

    # (leaderboard column, chart title, y-axis title)
    chart_specs = (
        ("MAE.mean", "Top 10 models by MAE", "Model"),
        ("RMSE.mean", "Top 10 models by RMSE", ""),
        ("rMAE.mean", "Top 10 models by rMAE", ""),
    )
    for column, (score, title, y_title) in zip(st.columns(3), chart_specs):
        with column:
            top_ten = leaderboard.sort_values(score, ascending=False).head(10)
            top_ten = top_ten.sort_values(score)
            fig = px.bar(top_ten, x=score, y=top_ten.index)
            fig.update_layout(
                title=title,
                xaxis_title="",
                yaxis_title=y_title,
                height=600,
            )
            st.plotly_chart(fig, use_container_width=True)

    st.dataframe(leaderboard, use_container_width=True)


def buildings_view(data: pd.DataFrame):
    """Render the buildings page: headline counts, distribution pies and a table.

    Missing ``metadata.cluster_size`` / ``metadata.building_class`` columns
    are filled with ``1`` / ``"Unknown"`` on a local copy; the caller's
    DataFrame is left untouched.

    Args:
        data: Benchmark results with a ``unique_id`` column plus the
            ``metadata.*`` and ``dataset.available_history.*`` columns selected
            below; rows are de-duplicated per building with ``groupby().first()``.
    """
    defaults = {"metadata.cluster_size": 1, "metadata.building_class": "Unknown"}
    missing = {col: value for col, value in defaults.items() if col not in data.columns}
    if missing:
        # `assign` returns a new frame, so the defaults never leak back into
        # the DataFrame owned by the caller.
        data = data.assign(**missing)

    buildings = (
        data[
            [
                "unique_id",
                "metadata.cluster_size",
                "metadata.building_class",
                "metadata.location_id",
                "metadata.timezone",
                "dataset.available_history.days",
                "dataset.available_history.observations",
                "metadata.freq",
            ]
        ]
        .groupby("unique_id")
        .first()
        .rename(
            columns={
                "metadata.cluster_size": "Cluster size",
                "metadata.building_class": "Building class",
                "metadata.location_id": "Location ID",
                "metadata.timezone": "Timezone",
                "dataset.available_history.days": "Available history (days)",
                "dataset.available_history.observations": "Available history (#)",
                "metadata.freq": "Frequency",
            }
        )
    )

    left, middle, right = st.columns(3)
    with left:
        st.metric("Number of buildings", data["unique_id"].nunique())
    with middle:
        st.metric(
            "Residential",
            data[data["metadata.building_class"] == "Residential"][
                "unique_id"
            ].nunique(),
        )
    with right:
        st.metric(
            "Commercial",
            data[data["metadata.building_class"] == "Commercial"][
                "unique_id"
            ].nunique(),
        )
    st.divider()

    # (section heading, column of `buildings` whose value counts are plotted)
    pie_specs = (
        ("#### Building classes", "Building class"),
        ("#### Timezones", "Timezone"),
        ("#### Frequencies", "Frequency"),
    )
    for column, (heading, field) in zip(st.columns(3, gap="large"), pie_specs):
        with column:
            st.markdown(heading)
            fig = px.pie(
                buildings.groupby(field).size().reset_index(),
                values=0,
                names=field,
            )
            fig.update_layout(
                legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
            )
            st.plotly_chart(fig, use_container_width=True)

    st.divider()

    st.markdown("#### Buildings")
    st.dataframe(
        buildings.sort_values("Available history (days)"),
        use_container_width=True,
        column_config={
            "Available history (days)": st.column_config.ProgressColumn(
                "Available history (days)",
                help="Available training data during the first prediction.",
                format="%f",
                min_value=0,
                max_value=float(buildings["Available history (days)"].max()),
            ),
            "Available history (#)": st.column_config.ProgressColumn(
                "Available history (#)",
                help="Available training data during the first prediction.",
                format="%f",
                min_value=0,
                max_value=float(buildings["Available history (#)"].max()),
            ),
        },
    )


def models_view(data: pd.DataFrame):
    """Render the models page: headline counts, distribution pies and a table.

    Args:
        data: Benchmark results with a ``model`` column plus the
            ``cv_config.*`` and ``model_info.*`` columns selected below; rows
            are de-duplicated per model with ``groupby().first()``.
    """
    models = (
        data[
            [
                "model",
                "cv_config.folds",
                "cv_config.horizon",
                "cv_config.step",
                "cv_config.time",
                "model_info.repository",
                "model_info.tag",
                "model_info.variate_type",
            ]
        ]
        .groupby("model")
        .first()
        .rename(
            columns={
                "cv_config.folds": "CV Folds",
                "cv_config.horizon": "CV Horizon",
                "cv_config.step": "CV Step",
                "cv_config.time": "CV Time",
                "model_info.repository": "Image Repository",
                "model_info.tag": "Image Tag",
                "model_info.variate_type": "Variate type",
            }
        )
    )

    def count_models(variate_type: str) -> int:
        """Number of distinct models whose variate type equals *variate_type*."""
        return data[data["model_info.variate_type"] == variate_type]["model"].nunique()

    left, middle, right = st.columns(3)
    with left:
        st.metric("Models", len(models))
    with middle:
        st.metric("Univariate", count_models("univariate"))
    with right:
        # Previously labelled "Univariate" although it counts multivariate models.
        st.metric("Multivariate", count_models("multivariate"))
    st.divider()

    left, right = st.columns(2, gap="large")
    with left:
        st.markdown("#### Variate types")
        fig = px.pie(
            models.groupby("Variate type").size().reset_index(),
            values=0,
            names="Variate type",
        )
        st.plotly_chart(fig, use_container_width=True)

    with right:
        st.markdown("#### Frameworks")
        # The framework is the model-id prefix before the first ".".
        frameworks = models.index.str.split(".").str[0]
        fig = px.pie(
            models.assign(Framework=frameworks).groupby("Framework").size().reset_index(),
            values=0,
            names="Framework",
        )
        st.plotly_chart(fig, use_container_width=True)

    st.divider()
    st.markdown("### Models")
    st.dataframe(models, use_container_width=True)


def accuracy_view(data: pd.DataFrame, models_to_plot: set[str]):
    """Render the accuracy page for the selected models.

    Shows a per-building box plot of one metric, a scatter plot comparing two
    metric/aggregation pairs, and two heat-map tables (stats per model and
    stats per building).

    Args:
        data: Benchmark results with ``model`` and ``unique_id`` columns and
            metric columns named ``"<metric>.<aggregation>"`` (e.g. ``"MAE.mean"``).
        models_to_plot: Model identifiers to include.
    """
    metric_options = ["MAE", "RMSE", "MBE", "rMAE"]
    aggregation_options = ["min", "mean", "median", "max", "std"]

    data_to_plot = data[data["model"].isin(models_to_plot)].sort_values(
        by="model", ascending=True
    )

    left, right = st.columns(2, gap="small")
    with left:
        metric = st.selectbox("Metric", metric_options, index=0)
    with right:
        aggregation = st.selectbox("Aggregation", aggregation_options, index=1)
    st.markdown(f"#### {aggregation.capitalize()} {metric} per building")

    if data_to_plot.empty:
        st.warning("No data to display.")
    else:
        metric_column = f"{metric}.{aggregation}"
        model_ranks = get_model_ranks(data_to_plot, metric_column)

        fig = px.box(
            data_to_plot.merge(model_ranks, on="model").sort_values(by="rank"),
            x=metric_column,
            y="model",
            color="model",
            points="all",
        )
        fig.update_layout(showlegend=False, height=50 * len(models_to_plot))
        st.plotly_chart(fig, use_container_width=True)

    st.divider()

    left, right = st.columns(2, gap="large")
    with left:
        x_metric = st.selectbox("Metric", metric_options, index=0, key="x_metric")
        x_aggregation = st.selectbox(
            "Aggregation", aggregation_options, index=1, key="x_aggregation"
        )
    with right:
        # This widget chooses the y-axis *metric*; it was mislabelled "Aggregation".
        y_metric = st.selectbox("Metric", metric_options, index=1, key="y_metric")
        y_aggregation = st.selectbox(
            "Aggregation", aggregation_options, index=1, key="y_aggregation"
        )

    st.markdown(
        f"#### {x_aggregation.capitalize()} {x_metric} vs {y_aggregation.capitalize()} {y_metric}"
    )
    if data_to_plot.empty:
        st.warning("No data to display.")
    else:
        fig = px.scatter(
            data_to_plot,
            x=f"{x_metric}.{x_aggregation}",
            y=f"{y_metric}.{y_aggregation}",
            color="model",
        )
        fig.update_layout(height=600)
        st.plotly_chart(fig, use_container_width=True)

    st.divider()

    left, right = st.columns(2, gap="small")
    with left:
        metric = st.selectbox("Metric", metric_options, index=0, key="table_metric")
    with right:
        aggregation = st.selectbox(
            "Aggregation across folds",
            aggregation_options,
            index=1,
            key="table_aggregation",
        )

    if data_to_plot.empty:
        # Nothing to aggregate, pivot or colour; skip both tables.
        st.warning("No data to display.")
        return

    def style_table(styler: Styler, axis=0) -> Styler:
        """Apply a diverging colour gradient, 2-decimal format and centred text.

        ``axis`` is forwarded to ``background_gradient``: ``0`` scales each
        column independently, ``None`` scales over the whole table.
        """
        styler.background_gradient(cmap="seismic", axis=axis)
        styler.format(precision=2)
        # center text and increase font size
        styler.map(lambda _: "text-align: center; font-size: 14px;")
        return styler

    stat_columns = [f"{metric}.{stat}" for stat in aggregation_options]
    metrics_table = (
        data_to_plot.groupby(["model"])
        .agg(aggregation, numeric_only=True)[stat_columns]
        .sort_values(by=f"{metric}.mean")
    )

    st.markdown(f"#### {aggregation.capitalize()} {metric} stats per model")
    st.dataframe(metrics_table.style.pipe(style_table, axis=0), use_container_width=True)

    metrics_per_building_table = (
        data_to_plot.groupby(["model", "unique_id"])
        .agg(aggregation, numeric_only=True)
        .reset_index()
        .pivot(index="model", columns="unique_id", values=f"{metric}.{aggregation}")
    )
    # Order models by their mean across buildings, ascending. Sorting via a
    # separate Series avoids inserting a helper column that could clash with a
    # building id.
    row_order = metrics_per_building_table.mean(axis=1).sort_values().index
    metrics_per_building_table = metrics_per_building_table.loc[row_order]

    st.markdown(f"#### {aggregation.capitalize()} {metric} stats per building")
    st.dataframe(
        metrics_per_building_table.style.pipe(style_table, axis=None),
        use_container_width=True,
    )


def relative_performance_view(data: pd.DataFrame, models_to_plot: set[str]):
    """Render how often each selected model beats a chosen baseline.

    The baseline candidates are taken from the ``better_than.<baseline>``
    columns of *data*. Each such cell is expected to hold a mapping with a
    ``"percentage"`` key (flattened with ``pd.json_normalize``) — presumably a
    fraction in [0, 1], since it is multiplied by 100 and plotted on a 0-100
    axis (TODO confirm).

    Args:
        data: Benchmark results with a ``model`` column and ``better_than.*`` columns.
        models_to_plot: Model identifiers to include.
    """
    data_to_plot = data[data["model"].isin(models_to_plot)].sort_values(
        by="model", ascending=True
    )

    st.markdown("#### Relative performance")
    if data_to_plot.empty:
        st.warning("No data to display.")
        return

    baseline_choices = sorted(
        data.filter(like="better_than")
        .columns.str.removeprefix("better_than.")
        .tolist()
    )
    if not baseline_choices:
        # Previously `baseline_choices[0]` raised IndexError here.
        st.warning("No baseline comparison available.")
        return
    if len(baseline_choices) > 1:
        better_than_baseline = st.selectbox("Baseline model", options=baseline_choices)
    else:
        better_than_baseline = baseline_choices[0]

    source_column = f"better_than.{better_than_baseline}"
    percentage_column = f"{source_column}.percentage"
    # `.values` assigns positionally; json_normalize returns a fresh RangeIndex.
    data_to_plot.loc[:, percentage_column] = (
        pd.json_normalize(data_to_plot[source_column])["percentage"].values * 100
    )
    model_rank = get_model_ranks(data_to_plot, percentage_column)

    fig = px.box(
        data_to_plot.merge(model_rank, on="model").sort_values(by="rank"),
        x=percentage_column,
        y="model",
        points="all",
    )
    fig.update_xaxes(range=[0, 100], title_text="Better than baseline (%)")
    fig.update_layout(
        showlegend=False,
        height=50 * len(models_to_plot),
        title=f"Better than {better_than_baseline} on % of days per building",
    )
    st.plotly_chart(fig, use_container_width=True)


def computation_view(data: pd.DataFrame, models_to_plot: set[str]):
    """Render the computational-resources page for the selected models.

    Args:
        data: Benchmark results with ``model`` and ``unique_id`` columns (the
            per-building CPU pivot below requires one row per pair), metric
            columns named ``"<metric>.<aggregation>"`` and a
            ``"resource_usage.CPU"`` column.
        models_to_plot: Model identifiers to include.
    """
    data_to_plot = data[data["model"].isin(models_to_plot)].sort_values(
        by="model", ascending=True
    )
    # Divide by 3600 so the values match the "hours" axis labels below
    # (presumably the raw value is seconds — TODO confirm).
    data_to_plot["resource_usage.CPU"] /= 3600

    st.markdown("#### Computational Resources")

    aggregation_options = ["min", "mean", "median", "max", "std"]
    left, center, right = st.columns(3, gap="small")
    with left:
        # Explicit key: accuracy_view creates a selectbox with the same label,
        # options and index, which can collide (DuplicateWidgetID) if both
        # views are rendered in the same script run.
        metric = st.selectbox(
            "Metric", ["MAE", "RMSE", "MBE", "rMAE"], index=0, key="computation_metric"
        )
    with center:
        aggregation_per_building = st.selectbox(
            "Aggregation per building", aggregation_options, index=1
        )
    with right:
        aggregation_per_model = st.selectbox(
            "Aggregation per model", aggregation_options, index=1
        )

    st.markdown(
        f"#### {aggregation_per_model.capitalize()} {aggregation_per_building.capitalize()} {metric} vs CPU usage"
    )
    if data_to_plot.empty:
        st.warning("No data to display.")
    else:
        # Each row holds one model's result on one building, so the metric
        # column suffix (e.g. "MAE.median") is the per-building aggregation,
        # and grouping by model collapses the buildings with the per-model
        # aggregation. These two were previously swapped.
        aggregated_data = (
            data_to_plot.groupby("model")
            .agg(aggregation_per_model, numeric_only=True)
            .reset_index()
        )
        fig = px.scatter(
            aggregated_data,
            x="resource_usage.CPU",
            y=f"{metric}.{aggregation_per_building}",
            color="model",
            log_x=True,
        )
        fig.update_layout(height=600)
        fig.update_xaxes(title_text="CPU usage (hours)")
        fig.update_yaxes(
            title_text=f"{metric} ({aggregation_per_building}, {aggregation_per_model})"
        )
        st.plotly_chart(fig, use_container_width=True)

    st.divider()

    st.markdown("#### Computational time vs historical data")
    if data_to_plot.empty:
        st.warning("No data to display.")
    else:
        fig = px.scatter(
            data_to_plot,
            x="dataset.available_history.observations",
            y="resource_usage.CPU",
            color="model",
            trendline="ols",
            hover_data=["model", "unique_id"],
        )
        fig.update_layout(height=600)
        fig.update_xaxes(title_text="Available historical observations (#)")
        fig.update_yaxes(title_text="CPU usage (hours)")
        st.plotly_chart(fig, use_container_width=True)

    st.divider()

    st.markdown("#### Computational time per building")
    if data_to_plot.empty:
        # Nothing to pivot or colour.
        st.warning("No data to display.")
        return

    cpu_per_building_table = data_to_plot.pivot(
        index="model", columns="unique_id", values="resource_usage.CPU"
    )

    def custom_table(styler: Styler) -> Styler:
        """Colour the whole table with one gradient, round to 2 decimals, centre text."""
        styler.background_gradient(cmap="seismic", axis=None)
        styler.format(precision=2)
        # center text and increase font size
        styler.map(lambda _: "text-align: center; font-size: 14px;")
        return styler

    st.dataframe(cpu_per_building_table.style.pipe(custom_table), use_container_width=True)