import json
import traceback
from typing import Any
import gradio as gr
import numpy as np
import pandas as pd
from datasets import Dataset
from loguru import logger
from app_configs import CONFIGS, UNSELECTED_PIPELINE_NAME
from components import commons
from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
from components.typed_dicts import TossupInterfaceDefaults, TossupPipelineStateDict
from display.formatting import styled_error
from submission import submit
from workflows import factory
from workflows.qb_agents import QuizBowlTossupAgent, TossupResult
from . import populate, validation
from .plotting import (
create_tossup_confidence_pyplot,
create_tossup_eval_dashboard,
create_tossup_eval_table,
create_tossup_html,
prepare_tossup_results_df,
)
from .utils import evaluate_prediction
# TODO: Add error handling to the run/evaluate tossup flows and surface the correct messages
# TODO: ^^ Same for Bonus
class ScoredTossupResult(TossupResult):
"""Result of a tossup question with evaluation score and position."""
score: int # Correctness score of the answer
token_position: int # Position in the question where prediction was made
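    # Illustrative shape (hypothetical values):
    #   {"answer": "Paris", "confidence": 0.82, "logprob": -0.2, "position": 3,
    #    "score": 1, "token_position": 57, "step_outputs": {...}}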
def add_model_scores(run_outputs: list[dict], clean_answers: list[str], run_indices: list[int]) -> list[dict]:
"""Add model scores to the model outputs."""
for output in run_outputs:
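        # `position` is 1-indexed over the runs; run_indices[position - 1] is
        # the token index at which that run ended.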
output["score"] = evaluate_prediction(output["answer"], clean_answers)
output["token_position"] = run_indices[output["position"] - 1]
return run_outputs
def prepare_buzz_evals(
    run_indices: list[int], model_outputs: list[dict]
) -> list[tuple[int, dict]]:
    """Pair each model output with the token index at which its run ended."""
    if not run_indices:
        logger.warning("No run indices provided, returning empty results")
        return []
    eval_points = []
    for o in model_outputs:
        # Map the 1-indexed run position back to its token index.
        token_position = run_indices[o["position"] - 1]
        eval_points.append((token_position, o))
    return eval_points
def initialize_eval_interface(
example: dict,
run_outputs: list[dict],
    input_vars: list[str],
confidence_threshold: float,
prob_threshold: float | None = None,
):
"""Initialize the interface with example text."""
try:
tokens = example["question"].split()
run_indices = example["run_indices"]
answer = example["answer_primary"]
clean_answers = example["clean_answers"]
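        # Pair each run output with the token index where its run ended.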
eval_points = [(o["token_position"], o) for o in run_outputs]
        if not tokens:
            return "No tokens found in the provided text.", pd.DataFrame(), {}, {}
html_content = create_tossup_html(tokens, answer, clean_answers, run_indices, eval_points)
plot_data = create_tossup_confidence_pyplot(tokens, eval_points, confidence_threshold, prob_threshold)
        # Store tokens and evaluation points in the state for later reuse
        state = {"tokens": tokens, "values": eval_points}
        # Prepare per-buzz step outputs, dropping the workflow's input variables
step_outputs = {}
for output in run_outputs:
tok_pos = output["token_position"]
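            # Key format "<1-indexed position>:<token>", e.g. "58:novel".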
            key = f"{tok_pos + 1}:{tokens[tok_pos]}"
step_outputs[key] = {k: v for k, v in output["step_outputs"].items() if k not in input_vars}
if output["logprob"] is not None:
step_outputs[key]["output_probability"] = float(np.exp(output["logprob"]))
return html_content, plot_data, state, step_outputs
    except Exception as e:
        logger.exception(f"Error initializing interface: {e}")
        return f"Error initializing interface: {e}", pd.DataFrame(), {}, {}
def process_tossup_results(results: list[dict]) -> pd.DataFrame:
"""Process results from tossup mode and prepare visualization data."""
data = []
for r in results:
entry = {
"Token Position": r["token_position"],
"Correct?": "✅" if r["score"] == 1 else "❌",
"Confidence": r["confidence"],
}
if r["logprob"] is not None:
entry["Probability"] = f"{np.exp(r['logprob']):.3f}"
entry["Prediction"] = r["answer"]
data.append(entry)
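    # Each row looks like (hypothetical values):
    #   {"Token Position": 57, "Correct?": "✅", "Confidence": 0.82,
    #    "Probability": "0.819", "Prediction": "Paris"}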
return pd.DataFrame(data)
class TossupInterface:
"""Gradio interface for the Tossup mode."""
def __init__(
self,
app: gr.Blocks,
browser_state: gr.BrowserState,
dataset: Dataset,
model_options: dict,
defaults: TossupInterfaceDefaults,
):
"""Initialize the Tossup interface."""
logger.info(f"Initializing Tossup interface with dataset size: {len(dataset)}")
self.browser_state = browser_state
self.ds = dataset
self.model_options = model_options
self.app = app
self.defaults = defaults
self.output_state = gr.State(value={})
self.render()
# ------------------------------------- LOAD PIPELINE STATE FROM BROWSER STATE -------------------------------------
def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
logger.debug(f"Loading presaved pipeline state from browser state:\n{json.dumps(browser_state, indent=4)}")
try:
state_dict = browser_state["tossup"].get("pipeline_state", {})
pipeline_state = TossupPipelineState.model_validate(state_dict)
pipeline_state_dict = pipeline_state.model_dump()
output_state = browser_state["tossup"].get("output_state", {})
except Exception as e:
logger.warning(f"Error loading presaved pipeline state: {e}")
output_state = {}
workflow = self.defaults["init_workflow"]
pipeline_state_dict = TossupPipelineState.from_workflow(workflow).model_dump()
return browser_state, not pipeline_change, pipeline_state_dict, output_state
# ------------------------------------------ INTERFACE RENDER FUNCTIONS -------------------------------------------
def _render_pipeline_interface(self, pipeline_state: TossupPipelineState):
"""Render the model interface."""
with gr.Row(elem_classes="bonus-header-row form-inline"):
self.pipeline_selector = commons.get_pipeline_selector([])
self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
self.import_error_display = gr.HTML(label="Import Error", elem_id="import-error-display", visible=False)
self.pipeline_interface = TossupPipelineInterface(
self.app,
pipeline_state.workflow,
ui_state=pipeline_state.ui_state,
model_options=list(self.model_options.keys()),
config=self.defaults,
)
def _render_qb_interface(self):
"""Render the quizbowl interface."""
with gr.Row(elem_classes="bonus-header-row form-inline"):
self.qid_selector = commons.get_qid_selector(len(self.ds))
self.early_stop_checkbox = gr.Checkbox(
value=self.defaults["early_stop"],
label="Early Stop",
info="Stop if already buzzed",
scale=0,
)
self.run_btn = gr.Button("Run on Tossup Question", variant="secondary")
self.question_display = gr.HTML(label="Question", elem_id="tossup-question-display")
self.error_display = gr.HTML(label="Error", elem_id="tossup-error-display", visible=False)
with gr.Row():
self.confidence_plot = gr.Plot(
label="Buzz Confidence",
format="webp",
)
        self.model_outputs_display = gr.JSON(label="Step Outputs", value="{}", show_indices=True, visible=False)
self.results_table = gr.DataFrame(
label="Model Outputs",
value=pd.DataFrame(columns=["Token Position", "Correct?", "Confidence", "Prediction"]),
visible=False,
)
with gr.Row():
self.eval_btn = gr.Button("Evaluate", variant="primary")
self.model_name_input, self.description_input, self.submit_btn, self.submit_status = (
commons.get_model_submission_accordion(self.app)
)
def render(self):
"""Create the Gradio interface."""
workflow = factory.create_empty_tossup_workflow()
pipeline_state = TossupPipelineState.from_workflow(workflow)
self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")
with gr.Row():
# Model Panel
with gr.Column(scale=1):
self._render_pipeline_interface(pipeline_state)
with gr.Column(scale=1):
self._render_qb_interface()
self._setup_event_listeners()
# ------------------------------------- Component Updates Functions ---------------------------------------------
def get_new_question_html(self, question_id: int) -> str:
"""Get the HTML for a new question."""
if question_id is None:
logger.error("Question ID is None. Setting to 1")
question_id = 1
try:
example = self.ds[question_id - 1]
question_tokens = example["question"].split()
return create_tossup_html(
question_tokens, example["answer_primary"], example["clean_answers"], example["run_indices"]
)
except Exception as e:
return f"Error loading question: {str(e)}"
def get_pipeline_names(self, profile: gr.OAuthProfile | None) -> list[str]:
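        # Keep the placeholder entry first so the dropdown can reset to "unselected".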
names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("tossup", profile)
return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)
def load_pipeline(
self, model_name: str, pipeline_change: bool, profile: gr.OAuthProfile | None
) -> tuple[str, bool, TossupPipelineStateDict, dict]:
try:
workflow = populate.load_workflow("tossup", model_name, profile)
if workflow is None:
logger.warning(f"Could not load workflow for {model_name}")
return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=False)
pipeline_state_dict = TossupPipelineState.from_workflow(workflow).model_dump()
return UNSELECTED_PIPELINE_NAME, not pipeline_change, pipeline_state_dict, gr.update(visible=True)
except Exception as e:
logger.exception(e)
error_msg = styled_error(f"Error loading pipeline: {str(e)}")
return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)
# ------------------------------------- Agent Functions -----------------------------------------------------------
def get_agent_outputs(
self, example: dict, pipeline_state: TossupPipelineState, early_stop: bool
) -> list[ScoredTossupResult]:
"""Get the model outputs for a given question ID."""
question_runs = []
tokens = example["question"].split()
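        # Build one cumulative question prefix per run index, e.g.
        # run_indices [10, 25] -> prefixes of the first 11 and 26 tokens.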
for run_idx in example["run_indices"]:
question_runs.append(" ".join(tokens[: run_idx + 1]))
agent = QuizBowlTossupAgent(pipeline_state.workflow)
outputs = list(agent.run(question_runs, early_stop=early_stop))
outputs = add_model_scores(outputs, example["clean_answers"], example["run_indices"])
return outputs
def single_run(
self,
question_id: int,
state_dict: TossupPipelineStateDict,
early_stop: bool = True,
    ) -> tuple[str, Any, Any, Any, Any, Any]:
        """Run the agent on a single tossup question.
Returns:
tuple: A tuple containing:
- tokens_html (str): HTML representation of the tossup question with buzz indicators
- output_state (gr.update): Update for the output state component
- plot_data (gr.update): Update for the confidence plot with label and visibility
- df (gr.update): Update for the dataframe component showing model outputs
- step_outputs (gr.update): Update for the step outputs component
- error_msg (gr.update): Update for the error message component (hidden if no errors)
"""
try:
pipeline_state = validation.validate_tossup_workflow(state_dict)
workflow = pipeline_state.workflow
# Validate inputs
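            # The qid selector is 1-indexed; the dataset rows are 0-indexed.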
question_id = int(question_id - 1)
if not self.ds or question_id < 0 or question_id >= len(self.ds):
raise gr.Error("Invalid question ID or dataset not loaded")
example = self.ds[question_id]
outputs = self.get_agent_outputs(example, pipeline_state, early_stop)
# Process results and prepare visualization data
confidence_threshold = workflow.buzzer.confidence_threshold
prob_threshold = workflow.buzzer.prob_threshold
tokens_html, plot_data, output_state, step_outputs = initialize_eval_interface(
example, outputs, workflow.inputs, confidence_threshold, prob_threshold
)
df = process_tossup_results(outputs)
return (
tokens_html,
gr.update(value=output_state),
gr.update(value=plot_data, label=f"Buzz Confidence on Question {question_id + 1}", show_label=True),
gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}", visible=True),
gr.update(value=step_outputs, label=f"Step Outputs for Question {question_id + 1}", visible=True),
gr.update(visible=False),
)
        except Exception as e:
            error_msg = styled_error(f"Error: {e}\n{traceback.format_exc()}")
return (
gr.skip(),
gr.skip(),
gr.skip(),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=True, value=error_msg),
)
def evaluate(self, state_dict: TossupPipelineStateDict, progress: gr.Progress = gr.Progress()):
"""Evaluate the tossup questions."""
try:
# Validate inputs
            if not self.ds or not self.ds.num_rows:
                return (
                    gr.skip(),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=True, value=styled_error("No dataset loaded")),
                )
pipeline_state = validation.validate_tossup_workflow(state_dict)
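            # One list of scored run outputs per question; early_stop ends a
            # question's runs once the agent buzzes (mirrors the Early Stop checkbox).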
model_outputs = []
for example in progress.tqdm(self.ds, desc="Evaluating tossup questions"):
run_outputs = self.get_agent_outputs(example, pipeline_state, early_stop=True)
model_outputs.append(run_outputs)
eval_df = prepare_tossup_results_df(self.ds["run_indices"], model_outputs)
plot_data = create_tossup_eval_dashboard(self.ds["run_indices"], eval_df)
output_df = create_tossup_eval_table(eval_df)
return (
gr.update(value=plot_data, label="Buzz Positions on Sample Set", show_label=False),
gr.update(value=output_df, label="(Mean) Metrics on Sample Set", visible=True),
gr.update(visible=False),
gr.update(visible=False),
)
        except Exception as e:
            logger.exception(f"Error evaluating tossups: {e}")
return (
gr.skip(),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=True, value=styled_error(f"Error: {str(e)}")),
)
def submit_model(
self,
model_name: str,
description: str,
state_dict: TossupPipelineStateDict,
        profile: gr.OAuthProfile | None = None,
) -> str:
"""Submit the model output."""
try:
pipeline_state = validation.validate_tossup_workflow(state_dict)
return submit.submit_model(model_name, description, pipeline_state.workflow, "tossup", profile)
except Exception as e:
            logger.exception(f"Error submitting model: {e}")
return styled_error(f"Error: {str(e)}")
@property
def pipeline_state(self):
return self.pipeline_interface.pipeline_state
# ------------------------------------- Event Listeners -----------------------------------------------------------
def _setup_event_listeners(self):
gr.on(
triggers=[self.app.load, self.qid_selector.change],
fn=self.get_new_question_html,
inputs=[self.qid_selector],
outputs=[self.question_display],
)
gr.on(
triggers=[self.app.load],
fn=self.get_pipeline_names,
outputs=[self.pipeline_selector],
)
pipeline_change = self.pipeline_interface.pipeline_change
gr.on(
triggers=[self.app.load],
fn=self.load_presaved_pipeline_state,
inputs=[self.browser_state, pipeline_change],
outputs=[self.browser_state, pipeline_change, self.pipeline_state, self.output_state],
)
self.load_btn.click(
fn=self.load_pipeline,
inputs=[self.pipeline_selector, pipeline_change],
outputs=[self.pipeline_selector, pipeline_change, self.pipeline_state, self.import_error_display],
)
self.pipeline_interface.add_triggers_for_pipeline_export([self.pipeline_state.change], self.pipeline_state)
self.run_btn.click(
self.single_run,
inputs=[
self.qid_selector,
self.pipeline_state,
self.early_stop_checkbox,
],
outputs=[
self.question_display,
self.output_state,
self.confidence_plot,
self.results_table,
self.model_outputs_display,
self.error_display,
],
)
self.eval_btn.click(
fn=self.evaluate,
inputs=[self.pipeline_state],
outputs=[self.confidence_plot, self.results_table, self.model_outputs_display, self.error_display],
)
self.submit_btn.click(
fn=self.submit_model,
inputs=[
self.model_name_input,
self.description_input,
self.pipeline_state,
],
outputs=[self.submit_status],
)