import json
from typing import Any

import gradio as gr
import numpy as np
import pandas as pd
from datasets import Dataset
from loguru import logger

from app_configs import CONFIGS, UNSELECTED_PIPELINE_NAME
from components import commons
from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
from components.typed_dicts import TossupInterfaceDefaults, TossupPipelineStateDict
from display.formatting import styled_error
from submission import submit
from workflows import factory
from workflows.qb_agents import QuizBowlTossupAgent, TossupResult

from . import populate, validation
from .plotting import (
    create_tossup_confidence_pyplot,
    create_tossup_eval_dashboard,
    create_tossup_eval_table,
    create_tossup_html,
    prepare_tossup_results_df,
)
from .utils import evaluate_prediction


class ScoredTossupResult(TossupResult):
    """Result of a tossup question with evaluation score and buzz position."""

    score: int  # 1 if the predicted answer was judged correct, 0 otherwise
    token_position: int  # index of the question token at which the buzz occurred


def add_model_scores(run_outputs: list[dict], clean_answers: list[str], run_indices: list[int]) -> list[dict]:
    """Add evaluation scores and buzz token positions to the model outputs."""
    for output in run_outputs:
        output["score"] = evaluate_prediction(output["answer"], clean_answers)
        # `position` is 1-indexed over the runs; `run_indices` maps it back to a token index.
        output["token_position"] = run_indices[output["position"] - 1]
    return run_outputs
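
# Example for add_model_scores (hypothetical values): with run_indices = [10, 25, 40], an
# output whose "position" is 2 (i.e. produced on the second run) gets token_position = 25,
# the index of the last token the model saw before answering.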


def prepare_buzz_evals(
    run_indices: list[int], model_outputs: list[dict]
) -> list[tuple[int, dict]]:
    """Pair each model output with the token position at which its run ended."""
    if not run_indices:
        logger.warning("No run indices provided, returning empty results")
        return []
    eval_points = []
    for o in model_outputs:
        token_position = run_indices[o["position"] - 1]
        eval_points.append((token_position, o))

    return eval_points
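
# Sketch of prepare_buzz_evals on hypothetical data: run_indices = [10, 25] with two outputs
# at positions 1 and 2 yield [(10, output_1), (25, output_2)].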


def initialize_eval_interface(
    example: dict,
    run_outputs: list[dict],
    input_vars: list,
    confidence_threshold: float,
    prob_threshold: float | None = None,
):
    """Initialize the interface with the example's question text."""
    try:
        tokens = example["question"].split()
        run_indices = example["run_indices"]
        answer = example["answer_primary"]
        clean_answers = example["clean_answers"]
        eval_points = [(o["token_position"], o) for o in run_outputs]

        if not tokens:
            return "<div>No tokens found in the provided text.</div>", pd.DataFrame(), {}, {}
        html_content = create_tossup_html(tokens, answer, clean_answers, run_indices, eval_points)
        plot_data = create_tossup_confidence_pyplot(tokens, eval_points, confidence_threshold, prob_threshold)

        state = {"tokens": tokens, "values": eval_points}

        # Key each run as "<1-indexed position>:<token>" so entries stay readable in the JSON
        # view, and drop the raw input variables from the displayed step outputs.
        step_outputs = {}
        for output in run_outputs:
            tok_pos = output["token_position"]
            key = "{pos}:{token}".format(pos=tok_pos + 1, token=tokens[tok_pos])
            step_outputs[key] = {k: v for k, v in output["step_outputs"].items() if k not in input_vars}
            if output["logprob"] is not None:
                step_outputs[key]["output_probability"] = float(np.exp(output["logprob"]))

        return html_content, plot_data, state, step_outputs
    except Exception as e:
        logger.exception(f"Error initializing interface: {e}")
        return f"<div>Error initializing interface: {str(e)}</div>", pd.DataFrame(), {}, {}


def process_tossup_results(results: list[dict]) -> pd.DataFrame:
    """Process results from tossup mode and prepare visualization data."""
    data = []
    for r in results:
        entry = {
            "Token Position": r["token_position"],
            "Correct?": "✅" if r["score"] == 1 else "❌",
            "Confidence": r["confidence"],
        }
        if r["logprob"] is not None:
            entry["Probability"] = f"{np.exp(r['logprob']):.3f}"
        entry["Prediction"] = r["answer"]
        data.append(entry)
    return pd.DataFrame(data)
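
# Example row (hypothetical values): {"token_position": 12, "score": 1, "confidence": 0.82,
# "logprob": -0.2, "answer": "Paris"} becomes
#   Token Position: 12, Correct?: ✅, Confidence: 0.82, Probability: "0.819", Prediction: "Paris".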


class TossupInterface:
    """Gradio interface for the Tossup mode."""

    def __init__(
        self,
        app: gr.Blocks,
        browser_state: gr.BrowserState,
        dataset: Dataset,
        model_options: dict,
        defaults: TossupInterfaceDefaults,
    ):
        """Initialize the Tossup interface."""
        logger.info(f"Initializing Tossup interface with dataset size: {len(dataset)}")
        self.browser_state = browser_state
        self.ds = dataset
        self.model_options = model_options
        self.app = app
        self.defaults = defaults
        self.output_state = gr.State(value={})
        self.render()

    def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
        logger.debug(f"Loading presaved pipeline state from browser state:\n{json.dumps(browser_state, indent=4)}")
        try:
            state_dict = browser_state["tossup"].get("pipeline_state", {})
            pipeline_state = TossupPipelineState.model_validate(state_dict)
            pipeline_state_dict = pipeline_state.model_dump()
            output_state = browser_state["tossup"].get("output_state", {})
        except Exception as e:
            logger.warning(f"Error loading presaved pipeline state: {e}")
            output_state = {}
            workflow = self.defaults["init_workflow"]
            pipeline_state_dict = TossupPipelineState.from_workflow(workflow).model_dump()
        # Flipping `pipeline_change` toggles the change-tracker state so listeners that watch
        # it re-render against the freshly loaded pipeline.
        return browser_state, not pipeline_change, pipeline_state_dict, output_state

    def _render_pipeline_interface(self, pipeline_state: TossupPipelineState):
        """Render the model pipeline interface."""
        with gr.Row(elem_classes="bonus-header-row form-inline"):
            self.pipeline_selector = commons.get_pipeline_selector([])
            self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
        self.import_error_display = gr.HTML(label="Import Error", elem_id="import-error-display", visible=False)
        self.pipeline_interface = TossupPipelineInterface(
            self.app,
            pipeline_state.workflow,
            ui_state=pipeline_state.ui_state,
            model_options=list(self.model_options.keys()),
            config=self.defaults,
        )

    def _render_qb_interface(self):
        """Render the quizbowl interface."""
        with gr.Row(elem_classes="bonus-header-row form-inline"):
            self.qid_selector = commons.get_qid_selector(len(self.ds))
            self.early_stop_checkbox = gr.Checkbox(
                value=self.defaults["early_stop"],
                label="Early Stop",
                info="Stop if already buzzed",
                scale=0,
            )
            self.run_btn = gr.Button("Run on Tossup Question", variant="secondary")
        self.question_display = gr.HTML(label="Question", elem_id="tossup-question-display")
        self.error_display = gr.HTML(label="Error", elem_id="tossup-error-display", visible=False)
        with gr.Row():
            self.confidence_plot = gr.Plot(
                label="Buzz Confidence",
                format="webp",
            )
        self.model_outputs_display = gr.JSON(label="Model Outputs", value="{}", show_indices=True, visible=False)
        self.results_table = gr.DataFrame(
            label="Model Outputs",
            value=pd.DataFrame(columns=["Token Position", "Correct?", "Confidence", "Prediction"]),
            visible=False,
        )
        with gr.Row():
            self.eval_btn = gr.Button("Evaluate", variant="primary")

        self.model_name_input, self.description_input, self.submit_btn, self.submit_status = (
            commons.get_model_submission_accordion(self.app)
        )

    def render(self):
        """Create the Gradio interface."""
        workflow = factory.create_empty_tossup_workflow()
        pipeline_state = TossupPipelineState.from_workflow(workflow)

        self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")

        with gr.Row():
            with gr.Column(scale=1):
                self._render_pipeline_interface(pipeline_state)
            with gr.Column(scale=1):
                self._render_qb_interface()

        self._setup_event_listeners()

    def get_new_question_html(self, question_id: int) -> str:
        """Get the HTML for a new question."""
        if question_id is None:
            logger.error("Question ID is None. Setting to 1")
            question_id = 1
        try:
            example = self.ds[question_id - 1]
            question_tokens = example["question"].split()
            return create_tossup_html(
                question_tokens, example["answer_primary"], example["clean_answers"], example["run_indices"]
            )
        except Exception as e:
            return f"Error loading question: {str(e)}"

    def get_pipeline_names(self, profile: gr.OAuthProfile | None) -> dict:
        # Returns a gr.update for the pipeline selector rather than a plain list of names.
        names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("tossup", profile)
        return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)

    def load_pipeline(
        self, model_name: str, pipeline_change: bool, profile: gr.OAuthProfile | None
    ) -> tuple[str, bool, TossupPipelineStateDict, dict]:
        try:
            workflow = populate.load_workflow("tossup", model_name, profile)
            if workflow is None:
                logger.warning(f"Could not load workflow for {model_name}")
                return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=False)
            pipeline_state_dict = TossupPipelineState.from_workflow(workflow).model_dump()
            # Hide the import-error display on a successful load.
            return UNSELECTED_PIPELINE_NAME, not pipeline_change, pipeline_state_dict, gr.update(visible=False)
        except Exception as e:
            logger.exception(e)
            error_msg = styled_error(f"Error loading pipeline: {str(e)}")
            return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)

    def get_agent_outputs(
        self, example: dict, pipeline_state: TossupPipelineState, early_stop: bool
    ) -> list[ScoredTossupResult]:
        """Get the scored model outputs for a given example."""
        question_runs = []
        tokens = example["question"].split()
        # Each run is the question prefix up to and including a run index.
        for run_idx in example["run_indices"]:
            question_runs.append(" ".join(tokens[: run_idx + 1]))
        agent = QuizBowlTossupAgent(pipeline_state.workflow)
        outputs = list(agent.run(question_runs, early_stop=early_stop))
        outputs = add_model_scores(outputs, example["clean_answers"], example["run_indices"])
        return outputs
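    # Sketch of run construction (hypothetical): tokens = ["In", "1789", "this", "event", "began"]
    # with run_indices = [1, 4] produces the runs "In 1789" and "In 1789 this event began".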

    def single_run(
        self,
        question_id: int,
        state_dict: TossupPipelineStateDict,
        early_stop: bool = True,
    ) -> tuple[str, Any, Any, Any, Any, Any]:
        """Run the agent in tossup mode with a system prompt.

        Returns:
            tuple: A tuple containing:
                - tokens_html (str): HTML representation of the tossup question with buzz indicators
                - output_state (gr.update): Update for the output state component
                - plot_data (gr.update): Update for the confidence plot with label and visibility
                - df (gr.update): Update for the dataframe component showing model outputs
                - step_outputs (gr.update): Update for the step outputs component
                - error_msg (gr.update): Update for the error message component (hidden if no errors)
        """
        try:
            pipeline_state = validation.validate_tossup_workflow(state_dict)
            workflow = pipeline_state.workflow

            # Convert the 1-indexed UI question ID to a 0-indexed dataset index.
            question_id = int(question_id - 1)
            if not self.ds or question_id < 0 or question_id >= len(self.ds):
                raise gr.Error("Invalid question ID or dataset not loaded")
            example = self.ds[question_id]
            outputs = self.get_agent_outputs(example, pipeline_state, early_stop)

            confidence_threshold = workflow.buzzer.confidence_threshold
            prob_threshold = workflow.buzzer.prob_threshold
            tokens_html, plot_data, output_state, step_outputs = initialize_eval_interface(
                example, outputs, workflow.inputs, confidence_threshold, prob_threshold
            )
            df = process_tossup_results(outputs)

            return (
                tokens_html,
                gr.update(value=output_state),
                gr.update(value=plot_data, label=f"Buzz Confidence on Question {question_id + 1}", show_label=True),
                gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}", visible=True),
                gr.update(value=step_outputs, label=f"Step Outputs for Question {question_id + 1}", visible=True),
                gr.update(visible=False),
            )
        except Exception as e:
            import traceback

            error_msg = styled_error(f"Error: {str(e)}\n{traceback.format_exc()}")
            return (
                gr.skip(),
                gr.skip(),
                gr.skip(),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=True, value=error_msg),
            )

    def evaluate(self, state_dict: TossupPipelineStateDict, progress: gr.Progress = gr.Progress()):
        """Evaluate the tossup questions."""
        try:
            if not self.ds or not self.ds.num_rows:
                return (
                    gr.skip(),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=True, value=styled_error("No dataset loaded")),
                )
            pipeline_state = validation.validate_tossup_workflow(state_dict)
            model_outputs = []
            for example in progress.tqdm(self.ds, desc="Evaluating tossup questions"):
                run_outputs = self.get_agent_outputs(example, pipeline_state, early_stop=True)
                model_outputs.append(run_outputs)
            eval_df = prepare_tossup_results_df(self.ds["run_indices"], model_outputs)
            plot_data = create_tossup_eval_dashboard(self.ds["run_indices"], eval_df)
            output_df = create_tossup_eval_table(eval_df)
            return (
                gr.update(value=plot_data, label="Buzz Positions on Sample Set", show_label=False),
                gr.update(value=output_df, label="(Mean) Metrics on Sample Set", visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
            )
        except Exception as e:
            logger.exception(f"Error evaluating tossups: {e}")
            return (
                gr.skip(),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=True, value=styled_error(f"Error: {str(e)}")),
            )

    def submit_model(
        self,
        model_name: str,
        description: str,
        state_dict: TossupPipelineStateDict,
        profile: gr.OAuthProfile | None = None,
    ) -> str:
        """Submit the model output."""
        try:
            pipeline_state = validation.validate_tossup_workflow(state_dict)
            return submit.submit_model(model_name, description, pipeline_state.workflow, "tossup", profile)
        except Exception as e:
            logger.exception(f"Error submitting model: {e}")
            return styled_error(f"Error: {str(e)}")

    @property
    def pipeline_state(self):
        return self.pipeline_interface.pipeline_state

    def _setup_event_listeners(self):
        gr.on(
            triggers=[self.app.load, self.qid_selector.change],
            fn=self.get_new_question_html,
            inputs=[self.qid_selector],
            outputs=[self.question_display],
        )

        gr.on(
            triggers=[self.app.load],
            fn=self.get_pipeline_names,
            outputs=[self.pipeline_selector],
        )

        pipeline_change = self.pipeline_interface.pipeline_change

        gr.on(
            triggers=[self.app.load],
            fn=self.load_presaved_pipeline_state,
            inputs=[self.browser_state, pipeline_change],
            outputs=[self.browser_state, pipeline_change, self.pipeline_state, self.output_state],
        )

        self.load_btn.click(
            fn=self.load_pipeline,
            inputs=[self.pipeline_selector, pipeline_change],
            outputs=[self.pipeline_selector, pipeline_change, self.pipeline_state, self.import_error_display],
        )
        self.pipeline_interface.add_triggers_for_pipeline_export([self.pipeline_state.change], self.pipeline_state)

        self.run_btn.click(
            self.single_run,
            inputs=[
                self.qid_selector,
                self.pipeline_state,
                self.early_stop_checkbox,
            ],
            outputs=[
                self.question_display,
                self.output_state,
                self.confidence_plot,
                self.results_table,
                self.model_outputs_display,
                self.error_display,
            ],
        )

        self.eval_btn.click(
            fn=self.evaluate,
            inputs=[self.pipeline_state],
            outputs=[self.confidence_plot, self.results_table, self.model_outputs_display, self.error_display],
        )

        self.submit_btn.click(
            fn=self.submit_model,
            inputs=[
                self.model_name_input,
                self.description_input,
                self.pipeline_state,
            ],
            outputs=[self.submit_status],
        )