import json
from typing import Any

import gradio as gr
import numpy as np
import pandas as pd
from datasets import Dataset
from loguru import logger

from app_configs import CONFIGS, UNSELECTED_PIPELINE_NAME
from components import commons
from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
from components.typed_dicts import TossupInterfaceDefaults, TossupPipelineStateDict
from display.formatting import styled_error
from submission import submit
from workflows import factory
from workflows.qb_agents import QuizBowlTossupAgent, TossupResult

from . import populate, validation
from .plotting import (
    create_tossup_confidence_pyplot,
    create_tossup_eval_dashboard,
    create_tossup_eval_table,
    create_tossup_html,
    prepare_tossup_results_df,
)
from .utils import evaluate_prediction

# TODO: Error handling on run tossup and evaluate tossup, and show correct messages
# TODO: ^^ Same for Bonus


class ScoredTossupResult(TossupResult):
    """Result of a tossup question with evaluation score and position."""

    score: int  # Correctness score of the answer
    token_position: int  # Position in the question where the prediction was made


def add_model_scores(run_outputs: list[dict], clean_answers: list[str], run_indices: list[int]) -> list[dict]:
    """Add evaluation scores and token positions to the model outputs."""
    for output in run_outputs:
        output["score"] = evaluate_prediction(output["answer"], clean_answers)
        output["token_position"] = run_indices[output["position"] - 1]
    return run_outputs


def prepare_buzz_evals(run_indices: list[int], model_outputs: list[dict]) -> list[tuple[int, dict]]:
    """Pair each model output with the token position at which it was produced."""
    if not run_indices:
        logger.warning("No run indices provided, returning empty results")
        return []
    eval_points = []
    for o in model_outputs:
        token_position = run_indices[o["position"] - 1]
        eval_points.append((token_position, o))
    return eval_points


def initialize_eval_interface(
    example: dict,
    run_outputs: list[dict],
    input_vars: list,
    confidence_threshold: float,
    prob_threshold: float | None = None,
):
    """Initialize the evaluation interface with the example's question text."""
    try:
        tokens = example["question"].split()
        run_indices = example["run_indices"]
        answer = example["answer_primary"]
        clean_answers = example["clean_answers"]
        eval_points = [(o["token_position"], o) for o in run_outputs]

        if not tokens:
            return "<div>No tokens found in the provided text.</div>", pd.DataFrame(), {}, {}

        html_content = create_tossup_html(tokens, answer, clean_answers, run_indices, eval_points)
        plot_data = create_tossup_confidence_pyplot(tokens, eval_points, confidence_threshold, prob_threshold)

        # Store tokens and eval points for later use via the output state
        state = {"tokens": tokens, "values": eval_points}

        # Prepare per-step outputs for the model, keyed by "position:token"
        step_outputs = {}
        for output in run_outputs:
            tok_pos = output["token_position"]
            key = "{pos}:{token}".format(pos=tok_pos + 1, token=tokens[tok_pos])
            step_outputs[key] = {k: v for k, v in output["step_outputs"].items() if k not in input_vars}
            if output["logprob"] is not None:
                step_outputs[key]["output_probability"] = float(np.exp(output["logprob"]))

        return html_content, plot_data, state, step_outputs
    except Exception as e:
        logger.exception(f"Error initializing interface: {e}")
        return f"<div>Error initializing interface: {str(e)}</div>", pd.DataFrame(), {}, {}


def process_tossup_results(results: list[dict]) -> pd.DataFrame:
    """Process results from tossup mode and prepare visualization data."""
    data = []
    for r in results:
        entry = {
            "Token Position": r["token_position"],
            "Correct?": "✅" if r["score"] == 1 else "❌",
            "Confidence": r["confidence"],
        }
        if r["logprob"] is not None:
            entry["Probability"] = f"{np.exp(r['logprob']):.3f}"
        entry["Prediction"] = r["answer"]
        data.append(entry)
    return pd.DataFrame(data)


class TossupInterface:
    """Gradio interface for the Tossup mode."""

    def __init__(
        self,
        app: gr.Blocks,
        browser_state: gr.BrowserState,
        dataset: Dataset,
        model_options: dict,
        defaults: TossupInterfaceDefaults,
    ):
        """Initialize the Tossup interface."""
        logger.info(f"Initializing Tossup interface with dataset size: {len(dataset)}")
        self.browser_state = browser_state
        self.ds = dataset
        self.model_options = model_options
        self.app = app
        self.defaults = defaults
        self.output_state = gr.State(value={})
        self.render()

    # ------------------------------------- LOAD PIPELINE STATE FROM BROWSER STATE -------------------------------------

    def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
        logger.debug(f"Loading presaved pipeline state from browser state:\n{json.dumps(browser_state, indent=4)}")
        try:
            state_dict = browser_state["tossup"].get("pipeline_state", {})
            pipeline_state = TossupPipelineState.model_validate(state_dict)
            pipeline_state_dict = pipeline_state.model_dump()
            output_state = browser_state["tossup"].get("output_state", {})
        except Exception as e:
            logger.warning(f"Error loading presaved pipeline state: {e}")
            # Fall back to a fresh pipeline state built from the default workflow
            output_state = {}
            workflow = self.defaults["init_workflow"]
            pipeline_state_dict = TossupPipelineState.from_workflow(workflow).model_dump()
        return browser_state, not pipeline_change, pipeline_state_dict, output_state

    # ------------------------------------------ INTERFACE RENDER FUNCTIONS -------------------------------------------

    def _render_pipeline_interface(self, pipeline_state: TossupPipelineState):
        """Render the model pipeline interface."""
        with gr.Row(elem_classes="bonus-header-row form-inline"):
            self.pipeline_selector = commons.get_pipeline_selector([])
            self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
        self.import_error_display = gr.HTML(label="Import Error", elem_id="import-error-display", visible=False)
        self.pipeline_interface = TossupPipelineInterface(
            self.app,
            pipeline_state.workflow,
            ui_state=pipeline_state.ui_state,
            model_options=list(self.model_options.keys()),
            config=self.defaults,
        )

    def _render_qb_interface(self):
        """Render the quizbowl interface."""
        with gr.Row(elem_classes="bonus-header-row form-inline"):
            self.qid_selector = commons.get_qid_selector(len(self.ds))
            self.early_stop_checkbox = gr.Checkbox(
                value=self.defaults["early_stop"],
                label="Early Stop",
                info="Stop if already buzzed",
                scale=0,
            )
            self.run_btn = gr.Button("Run on Tossup Question", variant="secondary")
        self.question_display = gr.HTML(label="Question", elem_id="tossup-question-display")
        self.error_display = gr.HTML(label="Error", elem_id="tossup-error-display", visible=False)
        with gr.Row():
            self.confidence_plot = gr.Plot(
                label="Buzz Confidence",
                format="webp",
            )
        self.model_outputs_display = gr.JSON(label="Model Outputs", value="{}", show_indices=True, visible=False)
        self.results_table = gr.DataFrame(
            label="Model Outputs",
            value=pd.DataFrame(columns=["Token Position", "Correct?", "Confidence", "Prediction"]),
            visible=False,
        )
        with gr.Row():
            self.eval_btn = gr.Button("Evaluate", variant="primary")
        self.model_name_input, self.description_input, self.submit_btn, self.submit_status = (
            commons.get_model_submission_accordion(self.app)
        )

    def render(self):
        """Create the Gradio interface."""
        workflow = factory.create_empty_tossup_workflow()
        pipeline_state = TossupPipelineState.from_workflow(workflow)
        self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")
        with gr.Row():
            # Model Panel
            with gr.Column(scale=1):
                self._render_pipeline_interface(pipeline_state)
            # Quizbowl Panel
            with gr.Column(scale=1):
                self._render_qb_interface()

        self._setup_event_listeners()

    # ------------------------------------- Component Updates Functions ---------------------------------------------

    def get_new_question_html(self, question_id: int) -> str:
        """Get the HTML for a new question."""
        if question_id is None:
            logger.error("Question ID is None. Defaulting to 1.")
            question_id = 1
        try:
            example = self.ds[question_id - 1]
            question_tokens = example["question"].split()
            return create_tossup_html(
                question_tokens, example["answer_primary"], example["clean_answers"], example["run_indices"]
            )
        except Exception as e:
            return f"Error loading question: {str(e)}"

    def get_pipeline_names(self, profile: gr.OAuthProfile | None):
        names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("tossup", profile)
        return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)

    def load_pipeline(
        self, model_name: str, pipeline_change: bool, profile: gr.OAuthProfile | None
    ) -> tuple[str, bool, TossupPipelineStateDict, dict]:
        try:
            workflow = populate.load_workflow("tossup", model_name, profile)
            if workflow is None:
                logger.warning(f"Could not load workflow for {model_name}")
                return (
                    UNSELECTED_PIPELINE_NAME,
                    gr.skip(),
                    gr.skip(),
                    gr.update(visible=True, value=styled_error(f"Could not load workflow for {model_name}")),
                )
            pipeline_state_dict = TossupPipelineState.from_workflow(workflow).model_dump()
            # Hide any stale import error on a successful load
            return UNSELECTED_PIPELINE_NAME, not pipeline_change, pipeline_state_dict, gr.update(visible=False)
        except Exception as e:
            logger.exception(e)
            error_msg = styled_error(f"Error loading pipeline: {str(e)}")
            return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)

    # ------------------------------------- Agent Functions -----------------------------------------------------------

    def get_agent_outputs(
        self, example: dict, pipeline_state: TossupPipelineState, early_stop: bool
    ) -> list[ScoredTossupResult]:
        """Get the scored model outputs for a given example."""
        question_runs = []
        tokens = example["question"].split()
        for run_idx in example["run_indices"]:
            question_runs.append(" ".join(tokens[: run_idx + 1]))
        agent = QuizBowlTossupAgent(pipeline_state.workflow)
        outputs = list(agent.run(question_runs, early_stop=early_stop))
        outputs = add_model_scores(outputs, example["clean_answers"], example["run_indices"])
        return outputs

    def single_run(
        self,
        question_id: int,
        state_dict: TossupPipelineStateDict,
        early_stop: bool = True,
    ) -> tuple[str, Any, Any, Any, Any, Any]:
        """Run the agent on a single tossup question.

        Returns:
            tuple: A tuple containing:
                - tokens_html (str): HTML representation of the tossup question with buzz indicators
                - output_state (gr.update): Update for the output state component
                - plot_data (gr.update): Update for the confidence plot with label and visibility
                - df (gr.update): Update for the dataframe component showing model outputs
                - step_outputs (gr.update): Update for the step outputs component
                - error_msg (gr.update): Update for the error message component (hidden if no errors)
        """
        try:
            pipeline_state = validation.validate_tossup_workflow(state_dict)
            workflow = pipeline_state.workflow

            # Validate inputs
            question_id = int(question_id - 1)
            if not self.ds or question_id < 0 or question_id >= len(self.ds):
                raise gr.Error("Invalid question ID or dataset not loaded")
            example = self.ds[question_id]
            outputs = self.get_agent_outputs(example, pipeline_state, early_stop)

            # Process results and prepare visualization data
            confidence_threshold = workflow.buzzer.confidence_threshold
            prob_threshold = workflow.buzzer.prob_threshold
            tokens_html, plot_data, output_state, step_outputs = initialize_eval_interface(
                example, outputs, workflow.inputs, confidence_threshold, prob_threshold
            )
            df = process_tossup_results(outputs)
            return (
                tokens_html,
                gr.update(value=output_state),
                gr.update(value=plot_data, label=f"Buzz Confidence on Question {question_id + 1}", show_label=True),
                gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}", visible=True),
                gr.update(value=step_outputs, label=f"Step Outputs for Question {question_id + 1}", visible=True),
                gr.update(visible=False),
            )
        except Exception as e:
            import traceback

            error_msg = styled_error(f"Error: {str(e)}\n{traceback.format_exc()}")
            return (
                gr.skip(),
                gr.skip(),
                gr.skip(),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=True, value=error_msg),
            )

    def evaluate(self, state_dict: TossupPipelineStateDict, progress: gr.Progress = gr.Progress()):
        """Evaluate the pipeline on the tossup questions in the dataset."""
        try:
            # Validate inputs; return updates for all four wired outputs, routing the message to the error display
            if not self.ds or not self.ds.num_rows:
                return (
                    gr.skip(),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=True, value=styled_error("No dataset loaded")),
                )
            pipeline_state = validation.validate_tossup_workflow(state_dict)
            model_outputs = []
            for example in progress.tqdm(self.ds, desc="Evaluating tossup questions"):
                run_outputs = self.get_agent_outputs(example, pipeline_state, early_stop=True)
                model_outputs.append(run_outputs)
            eval_df = prepare_tossup_results_df(self.ds["run_indices"], model_outputs)
            plot_data = create_tossup_eval_dashboard(self.ds["run_indices"], eval_df)
            output_df = create_tossup_eval_table(eval_df)
            return (
                gr.update(value=plot_data, label="Buzz Positions on Sample Set", show_label=False),
                gr.update(value=output_df, label="(Mean) Metrics on Sample Set", visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
            )
        except Exception as e:
            logger.exception(f"Error evaluating tossups: {e}")
            return (
                gr.skip(),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=True, value=styled_error(f"Error: {str(e)}")),
            )

    def submit_model(
        self,
        model_name: str,
        description: str,
        state_dict: TossupPipelineStateDict,
        profile: gr.OAuthProfile = None,
    ) -> str:
        """Submit the model pipeline for the tossup task."""
        try:
            pipeline_state = validation.validate_tossup_workflow(state_dict)
            return submit.submit_model(model_name, description, pipeline_state.workflow, "tossup", profile)
        except Exception as e:
            logger.exception(f"Error submitting model: {e}")
            return styled_error(f"Error: {str(e)}")

    @property
    def pipeline_state(self):
        return self.pipeline_interface.pipeline_state

    # ------------------------------------- Event Listeners -----------------------------------------------------------

    def _setup_event_listeners(self):
        gr.on(
            triggers=[self.app.load, self.qid_selector.change],
            fn=self.get_new_question_html,
            inputs=[self.qid_selector],
            outputs=[self.question_display],
        )
        gr.on(
            triggers=[self.app.load],
            fn=self.get_pipeline_names,
            outputs=[self.pipeline_selector],
        )
        pipeline_change = self.pipeline_interface.pipeline_change
        gr.on(
            triggers=[self.app.load],
            fn=self.load_presaved_pipeline_state,
            inputs=[self.browser_state, pipeline_change],
            outputs=[self.browser_state, pipeline_change, self.pipeline_state, self.output_state],
        )
        self.load_btn.click(
            fn=self.load_pipeline,
            inputs=[self.pipeline_selector, pipeline_change],
            outputs=[self.pipeline_selector, pipeline_change, self.pipeline_state, self.import_error_display],
        )
        self.pipeline_interface.add_triggers_for_pipeline_export([self.pipeline_state.change], self.pipeline_state)
        self.run_btn.click(
            self.single_run,
            inputs=[
                self.qid_selector,
                self.pipeline_state,
                self.early_stop_checkbox,
            ],
            outputs=[
                self.question_display,
                self.output_state,
                self.confidence_plot,
                self.results_table,
                self.model_outputs_display,
                self.error_display,
            ],
        )
        self.eval_btn.click(
            fn=self.evaluate,
            inputs=[self.pipeline_state],
            outputs=[self.confidence_plot, self.results_table, self.model_outputs_display, self.error_display],
        )
        self.submit_btn.click(
            fn=self.submit_model,
            inputs=[
                self.model_name_input,
                self.description_input,
                self.pipeline_state,
            ],
            outputs=[self.submit_status],
        )
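

# --------------------------------------------------- USAGE SKETCH ---------------------------------------------------
# A minimal sketch of how TossupInterface might be mounted in a standalone app, kept behind a __main__ guard so it
# never runs on import. The dataset path, model registry, and defaults below are illustrative placeholders, not values
# defined by this module: the dataset file must provide the "question", "run_indices", "answer_primary", and
# "clean_answers" fields this interface reads, and TossupInterfaceDefaults may require more keys than the two shown.
if __name__ == "__main__":
    from datasets import load_dataset

    demo_dataset = load_dataset("json", data_files="data/tossup_sample.jsonl", split="train")  # placeholder path
    demo_model_options = {"gpt-4o-mini": {}}  # placeholder model registry
    demo_defaults = {
        "init_workflow": factory.create_empty_tossup_workflow(),
        "early_stop": True,
    }  # only the keys this module reads directly; see TossupInterfaceDefaults for the full schema

    with gr.Blocks() as demo_app:
        # BrowserState persists the pipeline across page reloads; an empty "tossup" entry is a valid starting point
        demo_browser_state = gr.BrowserState({"tossup": {}})
        TossupInterface(demo_app, demo_browser_state, demo_dataset, demo_model_options, demo_defaults)
    demo_app.launch()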