import streamlit as st
from main import benchmark_model_multithreaded, benchmark_model_sequential
from prompts import questions as predefined_questions
import requests
import pandas as pd

# Set the title in the browser tab
st.set_page_config(page_title="Aidan Bench - Generator")

st.title("Aidan Bench - Generator")

# API Key Inputs with Security and User Experience Enhancements
st.warning("Please keep your API keys secure and confidential. This app does not store or log your API keys.")

if "open_router_key" not in st.session_state:
    st.session_state.open_router_key = ""
if "openai_api_key" not in st.session_state:
    st.session_state.openai_api_key = ""

open_router_key = st.text_input("Enter your Open Router API Key:", type="password", value=st.session_state.open_router_key)
openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password", value=st.session_state.openai_api_key)

if st.button("Confirm API Keys"):
    if open_router_key and openai_api_key:
        st.session_state.open_router_key = open_router_key
        st.session_state.openai_api_key = openai_api_key
        st.success("API keys confirmed!")
    else:
        st.warning("Please enter both API keys.")

# Access API keys from session state
if st.session_state.open_router_key and st.session_state.openai_api_key:
    # Fetch models from OpenRouter API
    try:
        response = requests.get("https://openrouter.ai/api/v1/models")
        response.raise_for_status()  # Raise an exception for bad status codes
        all_models = response.json()["data"]
        # Sort models alphabetically by their ID
        all_models.sort(key=lambda model: model["id"])

        # --- Create dictionaries for easy model lookup ---
        models_by_id = {model["id"]: model for model in all_models}
        judge_models = [model["id"] for model in all_models if "gpt" in model["id"]]
        judge_models.sort()

        model_names = list(models_by_id.keys())
    except requests.exceptions.RequestException as e:
        st.error(f"Error fetching models from OpenRouter API: {e}")
        model_names = []  # Provide an empty list if API call fails
        judge_models = []

    # Model Selection
    if model_names:
        model_name = st.selectbox("Select a Contestant Model", model_names)
        # --- Display pricing for the selected model ---
        selected_model = models_by_id.get(model_name)
        if selected_model:
            pricing_info = selected_model.get('pricing', {})
            prompt_price = float(pricing_info.get("prompt", 0)) * 1000000
            completion_price = float(pricing_info.get("completion", 0)) * 1000000

            # Display pricing information with increased precision
            st.write(f"**Prompt Pricing:** ${prompt_price:.2f}/Million tokens (if applicable)")
            st.write(f"**Completion Pricing:** ${completion_price:.2f}/Million tokens")
        else:
            st.write("**Pricing:** N/A")
    else:
        st.error("No models available. Please check your API connection.")
        st.stop()

    # Judge Model Selection
    if judge_models:
        judge_model_name = st.selectbox("Select a Judge Model", judge_models)
        # --- Display pricing for the selected judge model ---
        selected_judge_model = models_by_id.get(judge_model_name)
        if selected_judge_model:
            pricing_info = selected_judge_model.get('pricing', {})
            prompt_price = float(pricing_info.get("prompt", 0)) * 1000000
            completion_price = float(pricing_info.get("completion", 0)) * 1000000

            # Display pricing information with increased precision
            st.write(f"**Prompt Pricing:** ${prompt_price:.2f}/Million tokens (if applicable)")
            st.write(f"**Completion Pricing:** ${completion_price:.2f}/Million tokens")
        else:
            st.write("**Pricing:** N/A")
    else:
        st.error("No judge models available. Please check your API connection.")
        st.stop()


    # Initialize session state for user_questions and predefined_questions
    if "user_questions" not in st.session_state:
        st.session_state.user_questions = []

    # Threshold Sliders
    st.sidebar.subheader("Threshold Sliders")
    coherence_threshold = st.sidebar.slider("Coherence Threshold (0-5):", 0, 5, 3)
    novelty_threshold = st.sidebar.slider("Novelty Threshold (0-1):", 0.0, 1.0, 0.1)

    st.sidebar.subheader("Temp Sliders")
    temp_threshold = st.sidebar.slider("Temperature (0-2):", 0.0, 2.0, 1.0)
    top_p = st.sidebar.slider("Top P (0-1):", 0.0, 1.0, 1.0)

    # Workflow Selection
    workflow = st.radio("Select Workflow:", ["Use Predefined Questions", "Use User-Defined Questions"])

    # Handle Predefined Questions
    if workflow == "Use Predefined Questions":
        st.header("Question Selection")
        # Multiselect for predefined questions
        selected_questions = st.multiselect(
            "Select questions to benchmark:",
            predefined_questions,
            predefined_questions  # Select all by default
        )

    # Handle User-Defined Questions
    elif workflow == "Use User-Defined Questions":
        st.header("Question Input")

        # Input for adding a new question
        new_question = st.text_input("Enter a new question:")
        if st.button("Add Question") and new_question:
            new_question = new_question.strip()  # Remove leading/trailing whitespace
            if new_question and new_question not in st.session_state.user_questions:
                st.session_state.user_questions.append(new_question)  # Append to session state
                st.success(f"Question '{new_question}' added successfully.")
            else:
                st.warning("Question already exists or is empty!")

        # Display multiselect with updated user questions
        selected_questions = st.multiselect(
            "Select your custom questions:",
            options=st.session_state.user_questions,
            default=st.session_state.user_questions
        )

    # Display selected questions
    st.write("Selected Questions:", selected_questions)

    # Choose execution mode
    execution_mode = st.radio("Execution Mode:", ["Sequential", "Multithreaded"])

    # If multithreaded, allow user to configure thread pool size
    if execution_mode == "Multithreaded":
        max_threads = st.slider("Maximum Number of Threads:", 1, 10, 4)  # Default to 4 threads
    else:
        max_threads = None  # For sequential mode



    # Benchmark Execution
    if st.button("Start Benchmark"):
        if not selected_questions:
            st.warning("Please select at least one question.")
        else:
            num_questions = len(selected_questions)
            results = []

            # Stop button (not implemented yet)
            stop_button = st.button("Stop Benchmark")

            # Benchmarking logic using the chosen execution mode
            if execution_mode == "Sequential":
                question_results = benchmark_model_sequential(model_name, selected_questions, st.session_state.open_router_key, st.session_state.openai_api_key,judge_model_name,coherence_threshold,novelty_threshold,temp_threshold,top_p)
            else:  # Multithreaded
                question_results = benchmark_model_multithreaded(model_name, selected_questions, st.session_state.open_router_key, st.session_state.openai_api_key, max_threads, judge_model_name, coherence_threshold,novelty_threshold,temp_threshold,top_p)

            results.extend(question_results)

            # Display results in a table
            st.write("Results:")
            results_table = []
            for result in results:
                for answer in result["answers"]:
                    results_table.append({
                        "Question": result["question"],
                        "Answer": answer,
                        "Contestant Model": model_name,
                        "Judge Model": judge_model_name,
                        "Coherence Score": result["coherence_score"],
                        "Novelty Score": result["novelty_score"]
                    })
            st.table(results_table)

            df = pd.DataFrame(results_table)  # Create a Pandas DataFrame from the results
            csv = df.to_csv(index=False).encode('utf-8')  # Convert DataFrame to CSV
            st.download_button(
                label="Export Results as CSV",
                data=csv,
                file_name="benchmark_results.csv",
                mime='text/csv'
                )

            if stop_button:
                st.warning("Partial results displayed due to interruption.")
            else:
                st.success("Benchmark completed!")

else:
    st.warning("Please confirm your API keys first.")