import streamlit as st
from main import benchmark_model_multithreaded, benchmark_model_sequential
from prompts import questions as predefined_questions
import requests
# Set the title in the browser tab
st.set_page_config(page_title="Aidan Bench - Generator")
st.title("Aidan Bench - Generator")
# API Key Inputs with Security and User Experience Enhancements
st.warning("Please keep your API keys secure and confidential. This app does not store or log your API keys.")
if "open_router_key" not in st.session_state:
st.session_state.open_router_key = ""
if "openai_api_key" not in st.session_state:
st.session_state.openai_api_key = ""
open_router_key = st.text_input("Enter your Open Router API Key:", type="password", value=st.session_state.open_router_key)
openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password", value=st.session_state.openai_api_key)
if st.button("Confirm API Keys"):
    if open_router_key and openai_api_key:
        st.session_state.open_router_key = open_router_key
        st.session_state.openai_api_key = openai_api_key
        st.success("API keys confirmed!")
    else:
        st.warning("Please enter both API keys.")
# Access API keys from session state
if st.session_state.open_router_key and st.session_state.openai_api_key:
    # Fetch models from the OpenRouter API
    try:
        response = requests.get("https://openrouter.ai/api/v1/models")
        response.raise_for_status()  # Raise an exception for bad status codes
        models = response.json()["data"]
        # Sort models alphabetically by their ID
        models.sort(key=lambda model: model["id"])
        model_names = [model["id"] for model in models]
    except requests.exceptions.RequestException as e:
        st.error(f"Error fetching models from OpenRouter API: {e}")
        model_names = []  # Provide an empty list if the API call fails

    # Model Selection
    if model_names:
        model_name = st.selectbox("Select a Language Model", model_names)
    else:
        st.error("No models available. Please check your API connection.")
        st.stop()  # Stop execution if no models are available
    # Initialize session state for user-defined questions
    if "user_questions" not in st.session_state:
        st.session_state.user_questions = []
    # Workflow Selection
    workflow = st.radio("Select Workflow:", ["Use Predefined Questions", "Use User-Defined Questions"])

    # Handle Predefined Questions
    if workflow == "Use Predefined Questions":
        st.header("Question Selection")
        # Multiselect for predefined questions
        selected_questions = st.multiselect(
            "Select questions to benchmark:",
            predefined_questions,
            predefined_questions  # Select all by default
        )
    # Handle User-Defined Questions
    elif workflow == "Use User-Defined Questions":
        st.header("Question Input")
        # Input for adding a new question
        new_question = st.text_input("Enter a new question:")
        if st.button("Add Question") and new_question:
            new_question = new_question.strip()  # Remove leading/trailing whitespace
            if new_question and new_question not in st.session_state.user_questions:
                st.session_state.user_questions.append(new_question)  # Append to session state
                st.success(f"Question '{new_question}' added successfully.")
            else:
                st.warning("Question already exists or is empty!")
        # Display multiselect with the updated user questions
        selected_questions = st.multiselect(
            "Select your custom questions:",
            options=st.session_state.user_questions,
            default=st.session_state.user_questions
        )
    # Display selected questions
    st.write("Selected Questions:", selected_questions)

    # Choose execution mode
    execution_mode = st.radio("Execution Mode:", ["Sequential", "Multithreaded"])

    # If multithreaded, allow the user to configure the thread pool size
    if execution_mode == "Multithreaded":
        max_threads = st.slider("Maximum Number of Threads:", 1, 10, 4)  # Default to 4 threads
    else:
        max_threads = None  # For sequential mode
    # Benchmark Execution
    if st.button("Start Benchmark"):
        if not selected_questions:
            st.warning("Please select at least one question.")
        else:
            # Initialize progress bar
            progress_bar = st.progress(0)
            num_questions = len(selected_questions)
            results = []

            # Stop button
            stop_button = st.button("Stop Benchmark")

            # Benchmarking loop
            for i, question in enumerate(selected_questions):
                # Display current question
                st.write(f"Processing question {i+1}/{num_questions}: {question}")

                # Benchmark only the current question with the chosen execution mode
                if execution_mode == "Sequential":
                    question_results = benchmark_model_sequential(model_name, [question], st.session_state.open_router_key, st.session_state.openai_api_key)
                else:  # Multithreaded
                    question_results = benchmark_model_multithreaded(model_name, [question], st.session_state.open_router_key, st.session_state.openai_api_key, max_threads)
                results.extend(question_results)

                # Update progress bar
                progress_bar.progress((i + 1) / num_questions)

                # Check if the stop button was clicked
                if stop_button:
                    st.warning("Benchmark stopped!")
                    break  # Exit the loop
            # Report completion status
            if stop_button:
                st.warning("Partial results displayed due to interruption.")
            else:
                st.success("Benchmark completed!")

            # Display results in a table (even if interrupted)
            st.write("Results:")
            results_table = []
            for result in results:
                for answer in result["answers"]:
                    results_table.append({
                        "Question": result["question"],
                        "Answer": answer,
                        "Coherence Score": result["coherence_score"],
                        "Novelty Score": result["novelty_score"]
                    })
            st.table(results_table)
else:
st.warning("Please confirm your API keys first.")