Spaces · Running

Maharshi Gor committed · f10a835 · 1 Parent(s): 9b07040
Major update:

* Moved login to the top.
* Browser state for pipeline states, to retain user changes to the pipeline from before login.
* Plots and metrics for tossup single runs and eval.
* Refactored the model_submission pane.
* TypedDicts for pipeline interface defaults.
- app.py +65 -10
- src/components/commons.py +22 -0
- src/components/quizbowl/bonus.py +84 -45
- src/components/quizbowl/plotting.py +335 -64
- src/components/quizbowl/tossup.py +141 -91
- src/components/quizbowl/utils.py +0 -1
- src/components/typed_dicts.py +16 -0
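The 16 new lines in src/components/typed_dicts.py back the `defaults` objects constructed in app.py below. A minimal sketch of what these TypedDicts plausibly look like, inferred from the call sites (only `init_workflow` is confirmed by the diff; the remaining keys are splatted in from `DEFAULT_SELECTIONS` and are assumptions here):

    from typing import TypedDict

    from workflows.structs import TossupWorkflow, Workflow


    class PipelineInterfaceDefaults(TypedDict, total=False):
        # Confirmed by the BonusInterface call site in app.py; the real
        # definition presumably also declares the DEFAULT_SELECTIONS keys.
        init_workflow: Workflow


    class TossupInterfaceDefaults(TypedDict, total=False):
        # Tossup workflows are a distinct type (see the TossupWorkflow
        # import in tossup.py), hence the separate defaults type.
        init_workflow: TossupWorkflow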
app.py
CHANGED
@@ -1,3 +1,5 @@
+import json
+
 import datasets
 import gradio as gr
 from apscheduler.schedulers.background import BackgroundScheduler
@@ -9,6 +11,7 @@ from about import LEADERBOARD_INTRODUCTION_TEXT, LEADERBOARD_TITLE
 from app_configs import DEFAULT_SELECTIONS, THEME
 from components.quizbowl.bonus import BonusInterface
 from components.quizbowl.tossup import TossupInterface
+from components.typed_dicts import PipelineInterfaceDefaults, TossupInterfaceDefaults
 from display.css_html_js import fonts_header, js_head, leaderboard_css
 from display.custom_css import css_bonus, css_pipeline, css_tossup
 from display.guide import BUILDING_MARKDOWN, GUIDE_MARKDOWN, QUICKSTART_MARKDOWN
@@ -76,6 +79,26 @@ def get_default_tab_id(request: gr.Request):
     return gr.update(selected=tab_key_value)


+def presave_pipeline_state(
+    login_btn,
+    browser_state: dict,
+    tossup_pipeline_state: dict,
+    tossup_output_state: dict,
+    bonus_pipeline_state: dict,
+    bonus_output_state: dict,
+):
+    browser_state.setdefault("tossup", {})
+    browser_state["tossup"]["pipeline_state"] = tossup_pipeline_state
+    browser_state["tossup"]["output_state"] = tossup_output_state
+    browser_state.setdefault("bonus", {})
+    browser_state["bonus"]["pipeline_state"] = bonus_pipeline_state
+    browser_state["bonus"]["output_state"] = bonus_output_state
+    logger.debug(
+        f"Pipeline state before login. Login button: {login_btn}, browser state: {json.dumps(browser_state, indent=4)}"
+    )
+    return login_btn, browser_state
+
+
 if __name__ == "__main__":
     scheduler = BackgroundScheduler()
     scheduler.add_job(restart_space, "interval", seconds=SERVER_REFRESH_INTERVAL)
@@ -91,19 +114,36 @@ if __name__ == "__main__":
         theme=THEME,
         title="Quizbowl Bot",
     ) as demo:
+        browser_state = gr.BrowserState(
+            {
+                "tossup": {"pipeline_state": None, "output_state": None},
+                "bonus": {"pipeline_state": None, "output_state": None},
+            }
+        )
         with gr.Row():
-            gr.…
+            with gr.Column(scale=5):
+                gr.Markdown(
+                    "## Welcome to Quizbowl Arena! \n### Create, play around, and submit your quizbowl agents.",
+                    elem_classes="welcome-text",
+                )
+            login_btn = gr.LoginButton(scale=1)
+        gr.Markdown(
+            "**First time here?** Check out the [❓ Help](#help) tab for a quick introduction and the "
+            "[walkthrough documentation](https://github.com/stanford-crfm/quizbowl-arena/blob/main/docs/walkthrough.md) "
+            "for detailed examples and tutorials on how to create and compete with your own QuizBowl agents.",
+            elem_classes="help-text",
+        )
         with gr.Tabs() as gtab:
             with gr.Tab("🛎️ Tossup Agents", id="tossup"):
-                defaults = …
-                    "…
-                …
-                tossup_interface = TossupInterface(demo, tossup_ds, AVAILABLE_MODELS, defaults)
+                defaults = TossupInterfaceDefaults(
+                    **DEFAULT_SELECTIONS["tossup"], init_workflow=factory.create_simple_qb_tossup_workflow()
+                )
+                tossup_interface = TossupInterface(demo, browser_state, tossup_ds, AVAILABLE_MODELS, defaults)
             with gr.Tab("🙋🏻‍♂️ Bonus Round Agents", id="bonus"):
-                defaults = …
-                    "…
-                …
-                bonus_interface = BonusInterface(demo, bonus_ds, AVAILABLE_MODELS, defaults)
+                defaults = PipelineInterfaceDefaults(
+                    **DEFAULT_SELECTIONS["bonus"], init_workflow=factory.create_simple_qb_bonus_workflow()
+                )
+                bonus_interface = BonusInterface(demo, browser_state, bonus_ds, AVAILABLE_MODELS, defaults)
             with gr.Tab("🏅 Leaderboard", elem_id="llm-benchmark-tab-table", id="leaderboard"):
                 leaderboard_timer = gr.Timer(LEADERBOARD_REFRESH_INTERVAL)
                 gr.Markdown("<a id='leaderboard' href='#leaderboard'>QANTA Leaderboard</a>")
@@ -126,4 +166,19 @@ if __name__ == "__main__":
             with gr.Column():
                 gr.Markdown(BUILDING_MARKDOWN)

-    …
+        # Event Listeners
+
+        login_btn.click(
+            fn=presave_pipeline_state,
+            inputs=[
+                login_btn,
+                browser_state,
+                tossup_interface.pipeline_state,
+                tossup_interface.output_state,
+                bonus_interface.pipeline_state,
+                bonus_interface.output_state,
+            ],
+            outputs=[login_btn, browser_state],
+        )
+
+    demo.queue(default_concurrency_limit=40).launch()
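What makes the pre-login save/restore work is that `gr.BrowserState`, unlike `gr.State`, keeps its value in the browser's local storage, so it survives the full-page redirect of an OAuth login. A minimal self-contained sketch of the same pattern, outside this repo (requires a Gradio version that ships BrowserState):

    import gradio as gr

    with gr.Blocks() as demo:
        # Persisted in localStorage, so it survives reloads and OAuth redirects.
        saved = gr.BrowserState({"draft": ""})
        box = gr.Textbox(label="Draft")

        def save(text: str, state: dict) -> dict:
            # Analogous to presave_pipeline_state: stash UI state before navigating away.
            state["draft"] = text
            return state

        def restore(state: dict) -> str:
            # Analogous to load_presaved_pipeline_state: rehydrate the UI on app load.
            return state.get("draft", "")

        box.change(save, inputs=[box, saved], outputs=[saved])
        demo.load(restore, inputs=[saved], outputs=[box])

    demo.launch()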
src/components/commons.py
CHANGED
@@ -33,3 +33,25 @@ def get_panel_header(header: str, subheader: str | None = None):
     with gr.Row(elem_classes="md panel-header-container") as row:
         gr.HTML(html)
     return row
+
+
+def get_model_submission_accordion(app: gr.Blocks):
+    with gr.Accordion(
+        "Feel happy with your agent? Make a submission!", elem_classes="model-submission-accordion", open=True
+    ):
+        with gr.Row():
+            model_name_input = gr.Textbox(label="Submission Name")
+            description_input = gr.Textbox(label="Submission Description")
+        with gr.Row():
+            # login_btn = gr.LoginButton()
+            submit_btn = gr.Button("Submit", variant="primary", interactive=False)
+
+        submit_status = gr.HTML(label="Submission Status")
+
+        def check_user_login(profile: gr.OAuthProfile | None):
+            if profile is not None:
+                return gr.update(interactive=True, value="Submit Agent")
+            return gr.update(interactive=False, value="Login to submit your agent")
+
+        gr.on(triggers=app.load, fn=check_user_login, inputs=[], outputs=[submit_btn])
+    return model_name_input, description_input, submit_btn, submit_status
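A sketch of how this helper is wired into a Blocks app. The `submit_fn` handler here is hypothetical (the real interfaces call `submit.submit_model`); the point is that the button stays disabled until `check_user_login` observes an OAuth profile on page load:

    import gradio as gr

    from components import commons


    def submit_fn(name: str, description: str) -> str:
        # Hypothetical stand-in for the real submission handler.
        return f"<p>Submitted {name}: {description}</p>"


    with gr.Blocks() as app:
        name_in, desc_in, submit_btn, status = commons.get_model_submission_accordion(app)
        submit_btn.click(fn=submit_fn, inputs=[name_in, desc_in], outputs=[status])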
src/components/quizbowl/bonus.py
CHANGED
@@ -8,12 +8,12 @@ from loguru import logger

 from app_configs import CONFIGS, UNSELECTED_PIPELINE_NAME
 from components import commons
-from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
+from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
 from components.typed_dicts import PipelineStateDict
 from display.formatting import styled_error
 from submission import submit
+from workflows import factory
 from workflows.qb_agents import QuizBowlBonusAgent
-from workflows.structs import ModelStep, Workflow

 from . import populate, validation
 from .plotting import create_bonus_confidence_plot, create_bonus_html
@@ -56,9 +56,10 @@ def initialize_eval_interface(example: dict, model_outputs: list[dict]):
 class BonusInterface:
     """Gradio interface for the Bonus mode."""

-    def __init__(self, app: gr.Blocks, dataset: Dataset, model_options: dict, defaults: dict):
+    def __init__(self, app: gr.Blocks, browser_state: dict, dataset: Dataset, model_options: dict, defaults: dict):
         """Initialize the Bonus interface."""
         logger.info(f"Initializing Bonus interface with dataset size: {len(dataset)}")
+        self.browser_state = browser_state
         self.ds = dataset
         self.model_options = model_options
         self.app = app
@@ -66,7 +67,24 @@ class BonusInterface:
         self.output_state = gr.State(value="{}")
         self.render()

-    def …
+    # ------------------------------------- LOAD PIPELINE STATE FROM BROWSER STATE -------------------------------------
+
+    def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
+        logger.debug(f"Loading presaved pipeline state from browser state:\n{json.dumps(browser_state, indent=4)}")
+        try:
+            state_dict = browser_state["bonus"].get("pipeline_state", {})
+            pipeline_state = PipelineState.model_validate(state_dict)
+            pipeline_state_dict = pipeline_state.model_dump()
+            output_state = browser_state["bonus"].get("output_state", "{}")
+        except Exception as e:
+            logger.warning(f"Error loading presaved pipeline state: {e}")
+            output_state = "{}"
+            workflow = self.defaults["init_workflow"]
+            pipeline_state_dict = PipelineState.from_workflow(workflow).model_dump()
+        return browser_state, not pipeline_change, pipeline_state_dict, output_state
+
+    # ------------------------------------------ INTERFACE RENDER FUNCTIONS -------------------------------------------
+    def _render_pipeline_interface(self, pipeline_state: PipelineState):
         """Render the model interface."""
         with gr.Row(elem_classes="bonus-header-row form-inline"):
             self.pipeline_selector = commons.get_pipeline_selector([])
@@ -74,7 +92,8 @@ class BonusInterface:
         self.import_error_display = gr.HTML(label="Import Error", elem_id="import-error-display", visible=False)
         self.pipeline_interface = PipelineInterface(
             self.app,
-            workflow,
+            pipeline_state.workflow,
+            ui_state=pipeline_state.ui_state,
             model_options=list(self.model_options.keys()),
             config=self.defaults,
         )
@@ -97,24 +116,20 @@ class BonusInterface:
         with gr.Row():
             self.eval_btn = gr.Button("Evaluate", variant="primary")

-        …
-        …
-        …
-        self.description_input = gr.Textbox(label="Description")
-        with gr.Row():
-            gr.LoginButton()
-            self.submit_btn = gr.Button("Submit", variant="primary")
-        self.submit_status = gr.HTML(label="Submission Status")
+        self.model_name_input, self.description_input, self.submit_btn, self.submit_status = (
+            commons.get_model_submission_accordion(self.app)
+        )

     def render(self):
         """Create the Gradio interface."""
         self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")
-        workflow = …
+        workflow = factory.create_empty_tossup_workflow()
+        pipeline_state = PipelineState.from_workflow(workflow)

         with gr.Row():
             # Model Panel
             with gr.Column(scale=1):
-                self._render_pipeline_interface(…
+                self._render_pipeline_interface(pipeline_state)

             with gr.Column(scale=1):
                 self._render_qb_interface()
@@ -150,7 +165,27 @@ class BonusInterface:
         except Exception as e:
             return f"Error loading question: {str(e)}"

-    def …
+    def get_pipeline_names(self, profile: gr.OAuthProfile | None) -> list[str]:
+        names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("bonus", profile)
+        return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)
+
+    def load_pipeline(
+        self, model_name: str, pipeline_change: bool, profile: gr.OAuthProfile | None
+    ) -> tuple[str, bool, PipelineStateDict, dict]:
+        try:
+            workflow = populate.load_workflow("bonus", model_name, profile)
+            if workflow is None:
+                logger.warning(f"Could not load workflow for {model_name}")
+                return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=False)
+            pipeline_state_dict = PipelineState.from_workflow(workflow).model_dump()
+            return UNSELECTED_PIPELINE_NAME, not pipeline_change, pipeline_state_dict, gr.update(visible=True)
+        except Exception as e:
+            error_msg = styled_error(f"Error loading pipeline: {str(e)}")
+            return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)
+
+    # ------------------------------------- Agent Functions -----------------------------------------------------------
+
+    def get_agent_outputs(self, example: dict, pipeline_state: PipelineState):
         """Get the model outputs for a given question ID."""
         outputs = []
         leadin = example["leadin"]
@@ -168,30 +203,21 @@ class BonusInterface:

         return outputs

-    def get_pipeline_names(self, profile: gr.OAuthProfile | None) -> list[str]:
-        names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("bonus", profile)
-        return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)
-
-    def load_pipeline(
-        self, model_name: str, pipeline_change: bool, profile: gr.OAuthProfile | None
-    ) -> tuple[str, PipelineStateDict, bool, dict]:
-        try:
-            workflow = populate.load_workflow("bonus", model_name, profile)
-            if workflow is None:
-                logger.warning(f"Could not load workflow for {model_name}")
-                return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=False)
-            pipeline_state_dict = PipelineState.from_workflow(workflow).model_dump()
-            return UNSELECTED_PIPELINE_NAME, pipeline_state_dict, not pipeline_change, gr.update(visible=True)
-        except Exception as e:
-            error_msg = styled_error(f"Error loading pipeline: {str(e)}")
-            return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)
-
     def single_run(
         self,
         question_id: int,
         state_dict: PipelineStateDict,
     ) -> tuple[str, Any, Any]:
-        """Run the agent in bonus mode.
-        …
+        """Run the agent in bonus mode and updates the interface.
+
+        Returns:
+            tuple: Contains the following components:
+                - question_display: HTML display content of the question
+                - output_state: Updated state with question parts and outputs
+                - results_table: DataFrame with model predictions and scores
+                - model_outputs_display: Detailed step outputs from the model
+                - error_display: Any error messages (if applicable)
+        """
         try:
             pipeline_state = validation.validate_bonus_workflow(state_dict)
             question_id = int(question_id - 1)
@@ -199,7 +225,7 @@ class BonusInterface:
                 raise gr.Error("Invalid question ID or dataset not loaded")

             example = self.ds[question_id]
-            outputs = self.…
+            outputs = self.get_agent_outputs(example, pipeline_state)

             # Process results and prepare visualization data
             html_content, plot_data, output_state = initialize_eval_interface(example, outputs)
@@ -239,7 +265,7 @@ class BonusInterface:
             part_numbers = []

             for example in progress.tqdm(self.ds, desc="Evaluating bonus questions"):
-                model_outputs = self.…
+                model_outputs = self.get_agent_outputs(example, pipeline_state)

                 for output in model_outputs:
                     total_parts += 1
@@ -263,11 +289,12 @@ class BonusInterface:
             return (
                 gr.update(value=df, label="Scores on Sample Set"),
                 gr.update(visible=False),
+                gr.update(visible=False),
             )
         except Exception as e:
             error_msg = styled_error(f"Error evaluating bonus: {e.args}")
             logger.exception(f"Error evaluating bonus: {e.args}")
-            return gr.skip(), gr.update(visible=True, value=error_msg)
+            return gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)
@@ -280,6 +307,12 @@ class BonusInterface:
         pipeline_state = PipelineState(**state_dict)
         return submit.submit_model(model_name, description, pipeline_state.workflow, "bonus", profile)

+    @property
+    def pipeline_state(self):
+        return self.pipeline_interface.pipeline_state
+
+    # ------------------------------------- Event Listeners -----------------------------------------------------------
+
     def _setup_event_listeners(self):
         # Initialize with the default question (ID 0)

@@ -296,20 +329,26 @@ class BonusInterface:
             outputs=[self.pipeline_selector],
         )

-        pipeline_state = self.pipeline_interface.pipeline_state
         pipeline_change = self.pipeline_interface.pipeline_change
+
+        gr.on(
+            triggers=[self.app.load],
+            fn=self.load_presaved_pipeline_state,
+            inputs=[self.browser_state, pipeline_change],
+            outputs=[self.browser_state, pipeline_change, self.pipeline_state, self.output_state],
+        )
         self.load_btn.click(
             fn=self.load_pipeline,
             inputs=[self.pipeline_selector, pipeline_change],
-            outputs=[self.pipeline_selector, …
+            outputs=[self.pipeline_selector, pipeline_change, self.pipeline_state, self.import_error_display],
         )
-        self.pipeline_interface.add_triggers_for_pipeline_export([pipeline_state.change], pipeline_state)
+        self.pipeline_interface.add_triggers_for_pipeline_export([self.pipeline_state.change], self.pipeline_state)

         self.run_btn.click(
             self.single_run,
             inputs=[
                 self.qid_selector,
-                self.…
+                self.pipeline_state,
             ],
             outputs=[
                 self.question_display,
@@ -322,8 +361,8 @@ class BonusInterface:

         self.eval_btn.click(
             fn=self.evaluate,
-            inputs=[self.…
-            outputs=[self.results_table, self.error_display],
+            inputs=[self.pipeline_state],
+            outputs=[self.results_table, self.model_outputs_display, self.error_display],
         )

         self.submit_btn.click(
@@ -331,7 +370,7 @@ class BonusInterface:
             inputs=[
                 self.model_name_input,
                 self.description_input,
-                self.…
+                self.pipeline_state,
             ],
             outputs=[self.submit_status],
         )
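Taken together with `presave_pipeline_state` in app.py, the new `load_presaved_pipeline_state` is one half of a pydantic round trip through the browser's local storage. A condensed sketch of the mechanism (assuming `PipelineState` is a pydantic model, which its `model_validate`/`model_dump` calls imply):

    # On login click (app.py): pydantic state -> plain dict -> gr.BrowserState.
    browser_state["bonus"]["pipeline_state"] = pipeline_state.model_dump()

    # On app load (load_presaved_pipeline_state): stored dict -> validated model.
    # model_validate raises on a stale or malformed dict, and the except branch
    # then falls back to a fresh state built from defaults["init_workflow"].
    restored = PipelineState.model_validate(browser_state["bonus"]["pipeline_state"])

Returning `not pipeline_change` toggles a flag that downstream listeners watch, which is presumably what forces the pipeline UI to re-render with the restored state.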
src/components/quizbowl/plotting.py
CHANGED
@@ -1,9 +1,11 @@
+# %%
 import json
 import logging
 import re
 from collections import Counter

 import matplotlib.pyplot as plt
+import numpy as np
 import pandas as pd

@@ -25,7 +27,7 @@ def _get_token_classes(confidence, buzz, score) -> str:
     if confidence is None:
         return "token"
     elif not buzz:
-        return "token guess-point…
+        return f"token guess-point buzz-{score}"
     else:
         return f"token guess-point buzz-{score}"

@@ -44,12 +46,19 @@ def _create_token_tooltip_html(values) -> str:

     color = "#a3c9a3" if score else "#ebbec4"  # Light green for correct, light pink for incorrect

+    if values.get("logprob", None) is not None:
+        prob = np.exp(values["logprob"])
+        prob_str = f"<p style='margin: 0 0 4px; color: #000;'> 📈 <b style='color: #000;'>Output Probability:</b> {prob:.3f}</p>"
+    else:
+        prob_str = ""
+
     return f"""
     <div class="tooltip card" style="background-color: {color}; border-radius: 8px; padding: 12px; box-shadow: 2px 4px 8px rgba(0, 0, 0, 0.15);">
         <div class="tooltip-content" style="font-family: 'Arial', sans-serif; color: #000;">
             <h4 style="margin: 0 0 8px; color: #000;">💡 Answer</h4>
-            <p style="font-weight: bold; margin: 0 0 8px; color: #000;">{answer}</p>
-            <p style="margin: 0 0 4px; color: #000;"…
+            <p><code style="font-weight: bold; margin: 0 0 8px; color: #000;">{answer}</code></p>
+            <p style="margin: 0 0 4px; color: #000;">📈 <b style="color: #000;">Confidence:</b> {confidence:.2f}</p>
+            {prob_str}
             <p style="margin: 0; color: #000;">🔍 <b style="color: #000;">Status:</b> {"✅ Correct" if score else "❌ Incorrect" if buzz else "🚫 No Buzz"}</p>
         </div>
     </div>
@@ -145,86 +154,48 @@ def create_bonus_html(leadin: str, parts: list[dict]) -> str:
     return html_content


-def create_line_plot(eval_points: list[tuple[int, dict]], highlighted_index: int = -1) -> pd.DataFrame:
-    """Create a Gradio LinePlot of token values with optional highlighting using DataFrame."""
-    try:
-        # Create base confidence data
-        data = []
-
-        # Add buzz points to the plot
-        for i, (v, b) in eval_points:
-            color = "#ff4444" if b == 0 else "#228b22"
-            data.append(
-                {
-                    "position": i,
-                    "value": v,
-                    "type": "buzz",
-                    "highlight": True,
-                    "color": color,
-                }
-            )
-
-        if highlighted_index >= 0:
-            # Add vertical line for the highlighted token
-            data.extend(
-                [
-                    {
-                        "position": highlighted_index,
-                        "value": 0,
-                        "type": "hover-line",
-                        "color": "#000000",
-                        "highlight": True,
-                    },
-                    {
-                        "position": highlighted_index,
-                        "value": 1,
-                        "type": "hover-line",
-                        "color": "#000000",
-                        "highlight": True,
-                    },
-                ]
-            )
-
-        return pd.DataFrame(data)
-    except Exception as e:
-        logging.error(f"Error creating line plot: {e}", exc_info=True)
-        # Return an empty DataFrame with the expected columns
-        return pd.DataFrame(columns=["position", "value", "type", "highlight", "color"])
-
-
 def create_tossup_confidence_pyplot(
-    tokens: list[str],
-    …
+    tokens: list[str],
+    eval_points: list[tuple[int, dict]],
+    confidence_threshold: float = 0.5,
+    prob_threshold: float | None = None,
 ) -> plt.Figure:
     """Create a pyplot of token values with optional highlighting."""
     plt.style.use("ggplot")  # Set theme to grid paper
-    fig = plt.figure(figsize=(…
+    fig = plt.figure(figsize=(10, 4), dpi=300)  # Set figure size to 11x5
     ax = fig.add_subplot(111)
-    x = [0]…
-
-    for …
-        y.append(v["confidence"])
+    x = [0] + [int(i + 1) for i, _ in eval_points]
+    y_conf = [0] + [v["confidence"] for _, v in eval_points]
+    logprob_values = [v["logprob"] for _, v in eval_points if v["logprob"] is not None]
+    y_prob = [0] + [np.exp(v) for v in logprob_values]

-    ax.plot(x, …
+    ax.plot(x, y_prob, "o-", color="#f2b150", label="Probability")
+    ax.plot(x, y_conf, "o-", color="#4996de", label="Confidence")
     for i, v in eval_points:
         if not v["buzz"]:
             continue
-        confidence = v["confidence"]
         color = "green" if v["score"] else "red"
+        conf = v["confidence"]
+        ax.plot(i + 1, conf, "o", color=color, markerfacecolor="none", markersize=12, markeredgewidth=2.5)
+        if v["logprob"] is not None:
+            prob = np.exp(v["logprob"])
+            ax.plot(i + 1, prob, "o", color=color, markerfacecolor="none", markersize=12, markeredgewidth=2.5)
         if i >= len(tokens):
             print(f"Token index {i} is out of bounds for n_tokens: {len(tokens)}")
-        ax.annotate(f"{tokens[i]}", (i + 1, …
+        ax.annotate(f"{tokens[i]}", (i + 1, conf), textcoords="offset points", xytext=(0, 10), ha="center")

+    # Add horizontal dashed line for confidence threshold
+    ax.axhline(y=confidence_threshold, color="#9370DB", linestyle="--", xmin=0, xmax=1, label="Confidence Threshold")
+    # Add horizontal dashed line for probability threshold if provided
+    if prob_threshold is not None:
+        ax.axhline(y=prob_threshold, color="#cf5757", linestyle="--", xmin=0, xmax=1, label="Probability Threshold")

     ax.set_title("Buzz Confidence")
     ax.set_xlabel("Token Index")
     ax.set_ylabel("Confidence")
     ax.set_xticks(x)
     ax.set_xticklabels(x)
+    ax.legend()
     return fig

@@ -300,3 +271,303 @@ def update_tossup_plot(highlighted_index: int, state: str) -> pd.DataFrame:
     except Exception as e:
         logging.error(f"Error updating plot: {e}")
         return pd.DataFrame()
+
+
+# %%
+
+
+def create_df_entry(run_indices: list[int], run_outputs: list[dict]) -> dict:
+    """Create a dataframe entry from a list of model outputs."""
+    chosen_idx = None
+    earliest_ok_idx = None
+    is_correct = None
+    for i, o in enumerate(run_outputs):
+        if chosen_idx is None and o["buzz"]:
+            chosen_idx = run_indices[o["position"] - 1] + 1
+            is_correct = o["score"]
+        if earliest_ok_idx is None and o["score"]:
+            earliest_ok_idx = run_indices[o["position"] - 1] + 1
+    if is_correct is None:
+        is_correct = False
+
+    # if buzz is not the last index, correct scores 10, incorrect scores -5
+    # if buzz is the final index, correct scores 5, incorrect scores 0
+
+    if chosen_idx == -1:
+        tossup_score = 0
+    elif chosen_idx == run_indices[-1] + 1:
+        tossup_score = 5 if is_correct else 0
+    else:
+        tossup_score = 10 if is_correct else -5
+
+    gap = None if (chosen_idx is None or earliest_ok_idx is None) else chosen_idx - earliest_ok_idx
+    if earliest_ok_idx is None:
+        cls = "hopeless"
+    elif chosen_idx is None:
+        cls = "never-buzzed"  # Opportunity missed to score
+    elif chosen_idx == earliest_ok_idx:
+        cls = "best-buzz"  # Perfect timing
+    elif chosen_idx > earliest_ok_idx:
+        cls = "late-buzz"  # Opportunity missed to buzz earlier
+    elif chosen_idx < earliest_ok_idx:
+        cls = "premature"  # Opportunity missed to score
+
+    return {
+        "chosen_idx": chosen_idx,
+        "earliest_ok_idx": earliest_ok_idx,
+        "gap": gap,
+        "cls": cls,
+        "tossup_score": tossup_score,
+        "is_correct": int(is_correct),
+    }
+
+
+def prepare_tossup_results_df(run_indices: list[list[int]], model_outputs: list[list[dict]]) -> pd.DataFrame:
+    """Create a dataframe from a list of model outputs."""
+    records = []
+    for indices, outputs in zip(run_indices, model_outputs):
+        entry = create_df_entry(indices, outputs)
+        records.append(entry)
+    return pd.DataFrame.from_records(records)
+
+
+def create_tossup_eval_table(df: pd.DataFrame) -> pd.DataFrame:
+    """Create a table from a dataframe."""
+    # Prepare a dataframe of aggregated metrics:
+    # - Mean Tossup Score
+    # - Buzz Accuracy
+    # - Mean +ve Gap
+    # - Mean -ve Gap
+    # - Mean Buzz Position
+
+    positions = df["chosen_idx"].dropna()
+    gaps = df["gap"].dropna()
+    pos_gaps = gaps.loc[gaps >= 0]
+    neg_gaps = gaps.loc[gaps < 0]
+
+    mean_tossup_score = df["tossup_score"].sum() / len(df)
+
+    return pd.DataFrame(
+        [
+            {
+                "Tossup Score (10)": f"{mean_tossup_score:5.1f}",
+                "Buzz Accuracy": f"{df['is_correct'].mean():5.1%}",
+                "Buzz Position": f"{np.mean(positions):5.1f}",
+                "+ve Gap": f"{pos_gaps.mean():5.1f}",
+                "-ve Gap": f"{neg_gaps.mean():5.1f}",
+            }
+        ]
+    )
+
+
+def create_tossup_eval_dashboard(run_indices: list[list[int]], df: pd.DataFrame, *, figsize=(15, 8), title_prefix=""):
+    """
+    Visualise buzzing behaviour with three sub-plots:
+
+    1. Ceiling-accuracy vs. prefix length
+    2. Scatter of earliest-correct idx vs. chosen-buzz idx
+    3. Frequency distribution of narrative classes (vertical bars)
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        Output of `build_buzz_dataframe` – must contain
+        columns: earliest_ok_idx, chosen_idx, cls.
+    eval_indices : sequence[int]
+        Token positions at which the model was probed.
+    figsize : tuple, optional
+        Figure size passed to `plt.subplots`.
+    title_prefix : str, optional
+        Prepended to each subplot title (useful when comparing models).
+    """
+    # ------------------------------------------------------------------
+    # 0. Prep (variables reused throughout the function)
+    # ------------------------------------------------------------------
+    # Collect all evaluation indices across questions so we know the
+    # x-axis domain and the padding for NaNs.
+    eval_indices = np.asarray(sorted({idx for indices in run_indices for idx in indices}))
+
+    # Narrative classes and their colours
+    classes = [
+        "best-buzz",
+        "late-buzz",
+        "never-buzzed",
+        "premature",
+        "hopeless",
+    ]
+    colors = ["tab:green", "tab:olive", "tab:orange", "tab:red", "tab:gray"]
+    palette = dict(zip(classes, colors))
+
+    max_idx = eval_indices.max() * 1.25  # padding for NaN replacement / axis limits
+
+    # ------------------------------------------------------------------
+    # 1. Figure / axes layout
+    # ------------------------------------------------------------------
+    # GridSpec layout → 2 rows × 3 cols.
+    # ┌────────────┬────────────┬────────┐
+    # │  Ceiling   │  Scatter   │  Bars  │  (row 0)
+    # ├────────────┴────────────┴────────┤
+    # │ Descriptions (spans all 3 cols)  │  (row 1)
+    # └──────────────────────────────────┘
+    # Having a dedicated row for the narrative-class descriptions avoids
+    # overlapping with sub-plots and makes the whole figure more compact.
+
+    plt.style.use("ggplot")
+    fig = plt.figure(figsize=figsize)
+    gs = fig.add_gridspec(
+        nrows=2,
+        ncols=3,
+        height_ratios=[5, 1],  # extra space for plots vs. descriptions
+        width_ratios=[2.2, 2.2, 1],
+        hspace=0.2,  # reduced vertical spacing between plots
+        wspace=0.2,  # reduced horizontal spacing between plots
+        left=0.05,  # reduced left margin
+        right=0.95,  # reduced right margin
+        top=0.9,  # reduced top margin
+        bottom=0.05,  # reduced bottom margin
+    )
+
+    ax_ceiling = fig.add_subplot(gs[0, 0])  # Ceiling accuracy curve
+    ax_scatter = fig.add_subplot(gs[0, 1])  # Earliest vs. chosen scatter
+    ax_bars = fig.add_subplot(gs[0, 2])  # Outcome distribution bars
+    ax_desc = fig.add_subplot(gs[1, :])  # Textual descriptions
+    ax_desc.axis("off")
+
+    fig.suptitle("Buzzing behaviour", fontsize=16, fontweight="bold")
+
+    # ------------------------------------------------------------------
+    # 2. Ceiling accuracy curve
+    # ------------------------------------------------------------------
+    ceiling = [((df["earliest_ok_idx"].notna()) & (df["earliest_ok_idx"] <= idx)).mean() for idx in eval_indices]
+    ax_ceiling.plot(eval_indices, ceiling, marker="o", color="#4698cf")
+    ax_ceiling.set_xlabel("Token index shown")
+    ax_ceiling.set_ylabel("Proportion of questions correct")
+    ax_ceiling.set_ylim(0, 1.01)
+    ax_ceiling.set_title(f"{title_prefix}Ceiling accuracy vs. prefix")
+
+    # ------------------------------------------------------------------
+    # 3. Earliest-vs-Chosen scatter
+    # ------------------------------------------------------------------
+    for cls in classes:
+        sub = df[df["cls"] == cls]
+        if sub.empty:
+            continue
+        x = sub["earliest_ok_idx"].fillna(max_idx)
+        y = sub["chosen_idx"].fillna(max_idx)
+        ax_scatter.scatter(
+            x,
+            y,
+            label=cls,
+            alpha=0.7,
+            edgecolor="black",
+            linewidth=1,
+            marker="o",
+            s=90,
+            c=palette[cls],
+            facecolor="none",
+        )
+
+    lim = max_idx
+    ax_scatter.plot([0, lim], [0, lim], linestyle=":", linewidth=1)
+    ax_scatter.set_xlim(0, lim)
+    ax_scatter.set_ylim(0, lim)
+    ax_scatter.set_xlabel("Earliest index with correct answer")
+    ax_scatter.set_ylabel("Chosen buzz index")
+    ax_scatter.set_title(f"{title_prefix}Earliest vs. chosen index")
+    ax_scatter.legend(frameon=False, fontsize="small")
+
+    # ------------------------------------------------------------------
+    # 4. Outcome distribution (horizontal bars)
+    # ------------------------------------------------------------------
+    counts = df["cls"].value_counts().reindex(classes).fillna(0)
+    ax_bars.barh(
+        counts.index,
+        counts.values,
+        color=[palette[c] for c in counts.index],
+        alpha=0.7,
+        edgecolor="black",
+        linewidth=1,
+    )
+    ax_bars.set_xlabel("Number of questions")
+    ax_bars.set_title(f"{title_prefix}Outcome distribution")
+
+    # Ensure x-axis shows integer ticks only
+    from matplotlib.ticker import MaxNLocator
+
+    ax_bars.xaxis.set_major_locator(MaxNLocator(integer=True))
+
+    # ------------------------------------------------------------------
+    # 5. Narrative-class descriptions (bottom panel)
+    # ------------------------------------------------------------------
+    descriptions = {
+        "best-buzz": "Perfect timing. Buzzed at the earliest possible correct position",
+        "late-buzz": "Missed opportunity. Buzzed correctly but later than optimal",
+        "never-buzzed": "Missed opportunity. Never buzzed despite knowing the answer",
+        "premature": "Incorrect buzz. Buzzing at a later position could have been correct",
+        "hopeless": "Never knew the answer. No correct answer at any position",
+    }
+
+    y_pos = 1.0  # start at top of the description axis
+
+    for cls, color in zip(classes, colors):
+        ax_desc.text(
+            0.01,
+            y_pos,
+            f"■ {cls}: {descriptions[cls]}",
+            ha="left",
+            va="top",
+            color=color,
+            fontweight="bold",
+            fontsize=11,  # increased font size from 9 to 11
+            transform=ax_desc.transAxes,
+        )
+
+        y_pos -= 0.25  # increased vertical step inside the axis for more line height
+
+    # ------------------------------------------------------------------
+    # 6. Return the final figure
+    # ------------------------------------------------------------------
+    return fig
+
+
+# %%
+
+
+# Create dummy data for testing
+def create_dummy_model_outputs(n_entries=10, n_positions=5):
+    """Create dummy model outputs for testing."""
+    np.random.seed(42)
+    dummy_outputs = []
+
+    for _ in range(n_entries):
+        run_indices = sorted(np.random.choice(range(10, 50), n_positions, replace=False))
+        outputs = []
+
+        for i in range(n_positions):
+            # Randomly decide if model will buzz at this position
+            will_buzz = np.random.random() > 0.7
+            # Randomly decide if answer is correct
+            is_correct = np.random.random() > 0.4
+
+            outputs.append(
+                {
+                    "position": i + 1,
+                    "buzz": will_buzz,
+                    "score": 1 if is_correct else 0,
+                    "confidence": np.random.random(),
+                    "logprob": np.log(np.random.random()),
+                    "answer": f"Answer {i + 1}",
+                }
+            )
+
+        dummy_outputs.append({"run_indices": run_indices, "outputs": outputs})
+
+    return dummy_outputs
+
+
+# dummy_data = create_dummy_model_outputs()
+# dummy_df = pd.DataFrame([create_df_entry(entry["run_indices"], entry["outputs"]) for entry in dummy_data])
+# dummy_df
+# plot_buzz_dashboard(dummy_df, dummy_data[0]["run_indices"])
+
+# %%
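To make the scoring and classification rules in `create_df_entry` concrete: probing happens at a handful of token indices, `chosen_idx` is where the agent first buzzed, and `earliest_ok_idx` is where it first could have buzzed correctly. A hand-worked example with hypothetical outputs, following directly from the code above:

    # Probed at token indices 10, 20, 30; the model first answers correctly at
    # position 2 (token 20) but only buzzes at position 3 (token 30, the final run).
    run_indices = [10, 20, 30]
    run_outputs = [
        {"position": 1, "buzz": False, "score": 0},
        {"position": 2, "buzz": False, "score": 1},  # earliest correct: 20 + 1 = 21
        {"position": 3, "buzz": True, "score": 1},   # chosen buzz: 30 + 1 = 31 (final index)
    ]

    entry = create_df_entry(run_indices, run_outputs)
    # chosen_idx=31, earliest_ok_idx=21, gap=10, cls="late-buzz",
    # tossup_score=5 (a correct buzz on the final index scores 5, not 10), is_correct=1

    df = prepare_tossup_results_df([run_indices], [run_outputs])
    create_tossup_eval_table(df)  # one-row summary table

With no premature buzzes in the sample, the "-ve Gap" cell of the summary comes out as NaN, since `neg_gaps` is an empty series.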
src/components/quizbowl/tossup.py
CHANGED
@@ -9,20 +9,20 @@ from loguru import logger
|
|
9 |
|
10 |
from app_configs import CONFIGS, UNSELECTED_PIPELINE_NAME
|
11 |
from components import commons
|
12 |
-
from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState, PipelineUIState
|
13 |
from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
|
14 |
-
from components.typed_dicts import
|
15 |
from display.formatting import styled_error
|
16 |
from submission import submit
|
|
|
17 |
from workflows.qb_agents import QuizBowlTossupAgent, TossupResult
|
18 |
-
from workflows.structs import ModelStep, TossupWorkflow
|
19 |
|
20 |
from . import populate, validation
|
21 |
from .plotting import (
|
22 |
-
create_scatter_pyplot,
|
23 |
create_tossup_confidence_pyplot,
|
|
|
|
|
24 |
create_tossup_html,
|
25 |
-
|
26 |
)
|
27 |
from .utils import evaluate_prediction
|
28 |
|
@@ -53,13 +53,16 @@ def prepare_buzz_evals(
|
|
53 |
logger.warning("No run indices provided, returning empty results")
|
54 |
return [], []
|
55 |
eval_points = []
|
56 |
-
for
|
57 |
-
|
|
|
58 |
|
59 |
return eval_points
|
60 |
|
61 |
|
62 |
-
def initialize_eval_interface(
|
|
|
|
|
63 |
"""Initialize the interface with example text."""
|
64 |
try:
|
65 |
tokens = example["question"].split()
|
@@ -70,9 +73,8 @@ def initialize_eval_interface(example, model_outputs: list[dict]):
|
|
70 |
|
71 |
if not tokens:
|
72 |
return "<div>No tokens found in the provided text.</div>", pd.DataFrame(), "{}"
|
73 |
-
highlighted_index = next((int(i) for i, v in eval_points if v["buzz"] == 1), -1)
|
74 |
html_content = create_tossup_html(tokens, answer, clean_answers, run_indices, eval_points)
|
75 |
-
plot_data = create_tossup_confidence_pyplot(tokens, eval_points,
|
76 |
|
77 |
# Store tokens, values, and buzzes as JSON for later use
|
78 |
state = json.dumps({"tokens": tokens, "values": eval_points})
|
@@ -83,30 +85,36 @@ def initialize_eval_interface(example, model_outputs: list[dict]):
|
|
83 |
return f"<div>Error initializing interface: {str(e)}</div>", pd.DataFrame(), "{}"
|
84 |
|
85 |
|
86 |
-
def process_tossup_results(results: list[dict]
|
87 |
"""Process results from tossup mode and prepare visualization data."""
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
]
|
101 |
-
)
|
102 |
|
103 |
|
104 |
class TossupInterface:
|
105 |
"""Gradio interface for the Tossup mode."""
|
106 |
|
107 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
"""Initialize the Tossup interface."""
|
109 |
logger.info(f"Initializing Tossup interface with dataset size: {len(dataset)}")
|
|
|
110 |
self.ds = dataset
|
111 |
self.model_options = model_options
|
112 |
self.app = app
|
@@ -114,7 +122,25 @@ class TossupInterface:
|
|
114 |
self.output_state = gr.State(value="{}")
|
115 |
self.render()
|
116 |
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
"""Render the model interface."""
|
119 |
with gr.Row(elem_classes="bonus-header-row form-inline"):
|
120 |
self.pipeline_selector = commons.get_pipeline_selector([])
|
@@ -122,7 +148,8 @@ class TossupInterface:
|
|
122 |
self.import_error_display = gr.HTML(label="Import Error", elem_id="import-error-display", visible=False)
|
123 |
self.pipeline_interface = TossupPipelineInterface(
|
124 |
self.app,
|
125 |
-
workflow,
|
|
|
126 |
model_options=list(self.model_options.keys()),
|
127 |
config=self.defaults,
|
128 |
)
|
@@ -154,32 +181,29 @@ class TossupInterface:
|
|
154 |
with gr.Row():
|
155 |
self.eval_btn = gr.Button("Evaluate", variant="primary")
|
156 |
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
self.description_input = gr.Textbox(label="Description")
|
161 |
-
with gr.Row():
|
162 |
-
gr.LoginButton()
|
163 |
-
self.submit_btn = gr.Button("Submit", variant="primary")
|
164 |
-
self.submit_status = gr.HTML(label="Submission Status")
|
165 |
|
166 |
def render(self):
|
167 |
"""Create the Gradio interface."""
|
|
|
|
|
168 |
|
169 |
self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")
|
170 |
|
171 |
-
workflow = self.defaults["init_workflow"]
|
172 |
-
|
173 |
with gr.Row():
|
174 |
# Model Panel
|
175 |
with gr.Column(scale=1):
|
176 |
-
self._render_pipeline_interface(
|
177 |
|
178 |
with gr.Column(scale=1):
|
179 |
self._render_qb_interface()
|
180 |
|
181 |
self._setup_event_listeners()
|
182 |
|
|
|
|
|
183 |
def get_new_question_html(self, question_id: int) -> str:
|
184 |
"""Get the HTML for a new question."""
|
185 |
if question_id is None:
|
@@ -194,62 +218,89 @@ class TossupInterface:
|
|
194 |
except Exception as e:
|
195 |
return f"Error loading question: {str(e)}"
|
196 |
|
197 |
-
def get_model_outputs(
|
198 |
-
self, example: dict, pipeline_state: PipelineState, early_stop: bool
|
199 |
-
) -> list[ScoredTossupResult]:
|
200 |
-
"""Get the model outputs for a given question ID."""
|
201 |
-
question_runs = []
|
202 |
-
tokens = example["question"].split()
|
203 |
-
for run_idx in example["run_indices"]:
|
204 |
-
question_runs.append(" ".join(tokens[: run_idx + 1]))
|
205 |
-
agent = QuizBowlTossupAgent(pipeline_state.workflow)
|
206 |
-
outputs = list(agent.run(question_runs, early_stop=early_stop))
|
207 |
-
outputs = add_model_scores(outputs, example["clean_answers"], example["run_indices"])
|
208 |
-
return outputs
|
209 |
-
|
210 |
def get_pipeline_names(self, profile: gr.OAuthProfile | None) -> list[str]:
|
211 |
names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("tossup", profile)
|
212 |
return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)
|
213 |
|
214 |
def load_pipeline(
|
215 |
self, model_name: str, pipeline_change: bool, profile: gr.OAuthProfile | None
|
216 |
-
) -> tuple[str,
|
217 |
try:
|
218 |
workflow = populate.load_workflow("tossup", model_name, profile)
|
219 |
if workflow is None:
|
220 |
logger.warning(f"Could not load workflow for {model_name}")
|
221 |
return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=False)
|
222 |
pipeline_state_dict = TossupPipelineState.from_workflow(workflow).model_dump()
|
223 |
-
return UNSELECTED_PIPELINE_NAME,
|
224 |
except Exception as e:
|
225 |
logger.exception(e)
|
226 |
error_msg = styled_error(f"Error loading pipeline: {str(e)}")
|
227 |
return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)
|
228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
def single_run(
|
230 |
self,
|
231 |
question_id: int,
|
232 |
state_dict: TossupPipelineStateDict,
|
233 |
early_stop: bool = True,
|
234 |
) -> tuple[str, Any, Any]:
|
235 |
-
"""Run the agent in tossup mode with a system prompt.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
try:
|
237 |
             pipeline_state = validation.validate_tossup_workflow(state_dict)
             # Validate inputs
             question_id = int(question_id - 1)
             if not self.ds or question_id < 0 or question_id >= len(self.ds):
                 raise gr.Error("Invalid question ID or dataset not loaded")
             example = self.ds[question_id]
-            outputs = self.…
 
             # Process results and prepare visualization data
-            …
             df = process_tossup_results(outputs)
-            …
             return (
                 tokens_html,
                 gr.update(value=output_state),
-                gr.update(value=plot_data, label=f"Buzz Confidence on Question {question_id + 1}"),
                 gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}", visible=True),
                 gr.update(value=step_outputs, label=f"Step Outputs for Question {question_id + 1}", visible=True),
                 gr.update(visible=False),
@@ -274,32 +325,17 @@ class TossupInterface:
             if not self.ds or not self.ds.num_rows:
                 return "No dataset loaded", None, None
             pipeline_state = validation.validate_tossup_workflow(state_dict)
-            …
-            correct_buzzes = 0
-            token_positions = []
-            correctness = []
             for example in progress.tqdm(self.ds, desc="Evaluating tossup questions"):
-                …
-                token_positions.append(model_outputs[-1]["token_position"])
-                correctness.append(model_outputs[-1]["score"])
-            buzz_accuracy = correct_buzzes / buzz_counts
-            df = pd.DataFrame(
-                [
-                    {
-                        "Avg Buzz Position": f"{np.mean(token_positions):.2f}",
-                        "Buzz Accuracy": f"{buzz_accuracy:.2%}",
-                        "Total Score": f"{correct_buzzes}/{len(self.ds)}",
-                    }
-                ]
-            )
-            plot_data = create_scatter_pyplot(token_positions, correctness)
             return (
-                gr.update(value=plot_data, label="Buzz Positions on Sample Set"),
-                gr.update(value=…
                 gr.update(visible=False),
             )
         except Exception as e:
@@ -309,7 +345,8 @@ class TossupInterface:
             return (
                 gr.skip(),
                 gr.update(visible=False),
-                gr.update(visible=…
             )
 
     def submit_model(
@@ -327,6 +364,12 @@ class TossupInterface:
             logger.exception(f"Error submitting model: {e.args}")
             return styled_error(f"Error: {str(e)}")
 
     def _setup_event_listeners(self):
         gr.on(
             triggers=[self.app.load, self.qid_selector.change],
@@ -341,20 +384,27 @@ class TossupInterface:
             outputs=[self.pipeline_selector],
         )
 
-        pipeline_state = self.pipeline_interface.pipeline_state
         pipeline_change = self.pipeline_interface.pipeline_change
         self.load_btn.click(
             fn=self.load_pipeline,
             inputs=[self.pipeline_selector, pipeline_change],
-            outputs=[self.pipeline_selector, …
         )
-        self.pipeline_interface.add_triggers_for_pipeline_export([pipeline_state.change], pipeline_state)
 
         self.run_btn.click(
             self.single_run,
             inputs=[
                 self.qid_selector,
-                self.…
                 self.early_stop_checkbox,
             ],
             outputs=[
@@ -369,8 +419,8 @@ class TossupInterface:
 
         self.eval_btn.click(
             fn=self.evaluate,
-            inputs=[self.…
-            outputs=[self.confidence_plot, self.results_table, self.error_display],
         )
 
         self.submit_btn.click(
@@ -378,7 +428,7 @@ class TossupInterface:
             inputs=[
                 self.model_name_input,
                 self.description_input,
-                self.…
             ],
             outputs=[self.submit_status],
         )
@@ … +9 @@
 
 from app_configs import CONFIGS, UNSELECTED_PIPELINE_NAME
 from components import commons
 from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
+from components.typed_dicts import TossupInterfaceDefaults, TossupPipelineStateDict
 from display.formatting import styled_error
 from submission import submit
+from workflows import factory
 from workflows.qb_agents import QuizBowlTossupAgent, TossupResult
 
 from . import populate, validation
 from .plotting import (
     create_tossup_confidence_pyplot,
+    create_tossup_eval_dashboard,
+    create_tossup_eval_table,
     create_tossup_html,
+    prepare_tossup_results_df,
 )
 from .utils import evaluate_prediction
@@ … +53 @@
         logger.warning("No run indices provided, returning empty results")
         return [], []
     eval_points = []
+    for o in model_outputs:
+        token_position = run_indices[o["position"] - 1]
+        eval_points.append((token_position, o))
 
     return eval_points
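The `position` field on each model output is 1-based, so the loop above translates it into an absolute token index via `run_indices`. A self-contained sketch of that mapping with hypothetical data:

# Hypothetical run indices: 0-based token positions where each run ends.
run_indices = [9, 19, 29]
# The agent answered after the 1st and 3rd runs; "position" is 1-based.
model_outputs = [
    {"position": 1, "answer": "Paris"},
    {"position": 3, "answer": "Tokyo"},
]

eval_points = []
for o in model_outputs:
    token_position = run_indices[o["position"] - 1]
    eval_points.append((token_position, o))

print(eval_points)
# [(9, {'position': 1, 'answer': 'Paris'}), (29, {'position': 3, 'answer': 'Tokyo'})]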
@@ … +63 @@
+def initialize_eval_interface(
+    example: dict, model_outputs: list[dict], confidence_threshold: float, prob_threshold: float | None = None
+):
     """Initialize the interface with example text."""
     try:
         tokens = example["question"].split()
 …
 
         if not tokens:
             return "<div>No tokens found in the provided text.</div>", pd.DataFrame(), "{}"
         html_content = create_tossup_html(tokens, answer, clean_answers, run_indices, eval_points)
+        plot_data = create_tossup_confidence_pyplot(tokens, eval_points, confidence_threshold, prob_threshold)
 
         # Store tokens, values, and buzzes as JSON for later use
         state = json.dumps({"tokens": tokens, "values": eval_points})
 …
         return f"<div>Error initializing interface: {str(e)}</div>", pd.DataFrame(), "{}"
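`initialize_eval_interface` now forwards two buzzer thresholds to the confidence plot: a confidence score and, when the model reports log-probabilities, an answer probability. A minimal sketch of how such a dual gate could work; the threshold values are made up, the AND-of-gates semantics is an assumption, and the real decision lives in the workflow's buzzer, not here:

import numpy as np

CONFIDENCE_THRESHOLD = 0.85
PROB_THRESHOLD = 0.6  # None when the model exposes no logprobs

def should_buzz(confidence: float, logprob: float | None) -> bool:
    # Gate 1: the model's self-reported confidence must clear its threshold.
    if confidence < CONFIDENCE_THRESHOLD:
        return False
    # Gate 2: if logprobs are available, the answer probability must too.
    if PROB_THRESHOLD is not None and logprob is not None:
        return float(np.exp(logprob)) >= PROB_THRESHOLD
    return True

print(should_buzz(0.9, np.log(0.7)))  # True
print(should_buzz(0.9, np.log(0.4)))  # False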
@@ … +88 @@
+def process_tossup_results(results: list[dict]) -> pd.DataFrame:
     """Process results from tossup mode and prepare visualization data."""
+    data = []
+    for r in results:
+        entry = {
+            "Token Position": r["token_position"],
+            "Correct?": "✅" if r["score"] == 1 else "❌",
+            "Confidence": r["confidence"],
+        }
+        if r["logprob"] is not None:
+            entry["Probability"] = f"{np.exp(r['logprob']):.3f}"
+        entry["Prediction"] = r["answer"]
+        data.append(entry)
+    return pd.DataFrame(data)
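Fed results shaped like the agent's per-run outputs, the helper above yields one table row per run; the "Probability" column only gets a value for runs that carry a log-probability. A usage sketch with hypothetical values:

results = [
    {"token_position": 12, "score": 0, "confidence": 0.41, "logprob": None, "answer": "Rome"},
    {"token_position": 24, "score": 1, "confidence": 0.88, "logprob": -0.11, "answer": "Paris"},
]
# process_tossup_results(results) would then produce roughly:
#    Token Position Correct?  Confidence Probability Prediction
# 0              12       ❌        0.41         NaN       Rome
# 1              24       ✅        0.88       0.896      Paris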
@@ … +104 @@
 class TossupInterface:
     """Gradio interface for the Tossup mode."""
 
+    def __init__(
+        self,
+        app: gr.Blocks,
+        browser_state: gr.BrowserState,
+        dataset: Dataset,
+        model_options: dict,
+        defaults: TossupInterfaceDefaults,
+    ):
         """Initialize the Tossup interface."""
         logger.info(f"Initializing Tossup interface with dataset size: {len(dataset)}")
+        self.browser_state = browser_state
         self.ds = dataset
         self.model_options = model_options
         self.app = app
 …
         self.output_state = gr.State(value="{}")
         self.render()
 
+    # ------------------------------------- LOAD PIPELINE STATE FROM BROWSER STATE -------------------------------------
+
+    def load_presaved_pipeline_state(self, browser_state: dict, pipeline_change: bool):
+        logger.debug(f"Loading presaved pipeline state from browser state:\n{json.dumps(browser_state, indent=4)}")
+        try:
+            state_dict = browser_state["tossup"].get("pipeline_state", {})
+            pipeline_state = TossupPipelineState.model_validate(state_dict)
+            pipeline_state_dict = pipeline_state.model_dump()
+            output_state = browser_state["tossup"].get("output_state", "{}")
+        except Exception as e:
+            logger.warning(f"Error loading presaved pipeline state: {e}")
+            output_state = "{}"
+            workflow = self.defaults["init_workflow"]
+            pipeline_state_dict = TossupPipelineState.from_workflow(workflow).model_dump()
+        return browser_state, not pipeline_change, pipeline_state_dict, output_state
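`load_presaved_pipeline_state` reads the payload persisted before login; on any validation failure it falls back to `defaults["init_workflow"]` and flips `pipeline_change` so downstream listeners re-render. A sketch of the payload's assumed shape, with keys taken from the accessors above and values hypothetical:

import json

browser_state = {
    "tossup": {
        # A serialized TossupPipelineState (workflow plus UI state)...
        "pipeline_state": {"workflow": {"steps": {}}, "ui_state": {}},
        # ...and the last run's output state, stored as a JSON string.
        "output_state": "{}",
    },
    "bonus": {},
}
state_dict = browser_state["tossup"].get("pipeline_state", {})
output_state = browser_state["tossup"].get("output_state", "{}")
print(json.dumps(state_dict), output_state)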
@@ … +141 @@ class TossupInterface:
+    # ------------------------------------------ INTERFACE RENDER FUNCTIONS -------------------------------------------
+
+    def _render_pipeline_interface(self, pipeline_state: TossupPipelineState):
         """Render the model interface."""
         with gr.Row(elem_classes="bonus-header-row form-inline"):
             self.pipeline_selector = commons.get_pipeline_selector([])
 …
         self.import_error_display = gr.HTML(label="Import Error", elem_id="import-error-display", visible=False)
         self.pipeline_interface = TossupPipelineInterface(
             self.app,
+            pipeline_state.workflow,
+            ui_state=pipeline_state.ui_state,
             model_options=list(self.model_options.keys()),
             config=self.defaults,
         )
 …
         with gr.Row():
             self.eval_btn = gr.Button("Evaluate", variant="primary")
 
+        self.model_name_input, self.description_input, self.submit_btn, self.submit_status = (
+            commons.get_model_submission_accordion(self.app)
+        )
 
     def render(self):
         """Create the Gradio interface."""
+        workflow = factory.create_empty_tossup_workflow()
+        pipeline_state = TossupPipelineState.from_workflow(workflow)
 
         self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")
 
         with gr.Row():
             # Model Panel
             with gr.Column(scale=1):
+                self._render_pipeline_interface(pipeline_state)
 
             with gr.Column(scale=1):
                 self._render_qb_interface()
 
         self._setup_event_listeners()
 
+    # ------------------------------------- Component Updates Functions ---------------------------------------------
+
     def get_new_question_html(self, question_id: int) -> str:
         """Get the HTML for a new question."""
         if question_id is None:
 …
         except Exception as e:
             return f"Error loading question: {str(e)}"
 
     def get_pipeline_names(self, profile: gr.OAuthProfile | None) -> list[str]:
         names = [UNSELECTED_PIPELINE_NAME] + populate.get_pipeline_names("tossup", profile)
         return gr.update(choices=names, value=UNSELECTED_PIPELINE_NAME)
 
     def load_pipeline(
         self, model_name: str, pipeline_change: bool, profile: gr.OAuthProfile | None
+    ) -> tuple[str, bool, TossupPipelineStateDict, dict]:
         try:
             workflow = populate.load_workflow("tossup", model_name, profile)
             if workflow is None:
                 logger.warning(f"Could not load workflow for {model_name}")
                 return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=False)
             pipeline_state_dict = TossupPipelineState.from_workflow(workflow).model_dump()
+            return UNSELECTED_PIPELINE_NAME, not pipeline_change, pipeline_state_dict, gr.update(visible=True)
         except Exception as e:
             logger.exception(e)
             error_msg = styled_error(f"Error loading pipeline: {str(e)}")
             return UNSELECTED_PIPELINE_NAME, gr.skip(), gr.skip(), gr.update(visible=True, value=error_msg)
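`load_pipeline` returns `not pipeline_change` on success: the boolean state is a change signal, and flipping it fires every listener bound to its `.change` event even when no other output differs. A stripped-down sketch of the pattern, assuming a Gradio version where `gr.State` emits `.change` events, as this codebase does:

import gradio as gr

with gr.Blocks() as demo:
    pipeline_change = gr.State(False)  # boolean "something changed" signal
    status = gr.Textbox(label="status")
    load_btn = gr.Button("Import Pipeline")

    def load(change_signal: bool):
        # ...load and validate the workflow here...
        return not change_signal  # flip the bit so .change listeners fire

    load_btn.click(fn=load, inputs=[pipeline_change], outputs=[pipeline_change])
    pipeline_change.change(fn=lambda: "pipeline reloaded", outputs=[status])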
+    # ------------------------------------- Agent Functions -----------------------------------------------------------
+    def get_agent_outputs(
+        self, example: dict, pipeline_state: TossupPipelineState, early_stop: bool
+    ) -> list[ScoredTossupResult]:
+        """Get the model outputs for a given question ID."""
+        question_runs = []
+        tokens = example["question"].split()
+        for run_idx in example["run_indices"]:
+            question_runs.append(" ".join(tokens[: run_idx + 1]))
+        agent = QuizBowlTossupAgent(pipeline_state.workflow)
+        outputs = list(agent.run(question_runs, early_stop=early_stop))
+        outputs = add_model_scores(outputs, example["clean_answers"], example["run_indices"])
+        return outputs
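`get_agent_outputs` turns one question into incremental "runs": each run is the question prefix up to and including a run index, and the agent answers each prefix in turn. The prefix construction in isolation, with a hypothetical question and indices:

example = {
    "question": "This city on the Seine is home to the Louvre museum",
    "run_indices": [3, 7, 10],  # 0-based index of each run's final token
}
tokens = example["question"].split()
question_runs = [" ".join(tokens[: idx + 1]) for idx in example["run_indices"]]
print(question_runs)
# ['This city on the',
#  'This city on the Seine is home to',
#  'This city on the Seine is home to the Louvre museum']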
@@ … +254 @@ class TossupInterface:
     def single_run(
         self,
         question_id: int,
         state_dict: TossupPipelineStateDict,
         early_stop: bool = True,
     ) -> tuple[str, Any, Any]:
+        """Run the agent in tossup mode with a system prompt.
+
+        Returns:
+            tuple: A tuple containing:
+                - tokens_html (str): HTML representation of the tossup question with buzz indicators
+                - output_state (gr.update): Update for the output state component
+                - plot_data (gr.update): Update for the confidence plot with label and visibility
+                - df (gr.update): Update for the dataframe component showing model outputs
+                - step_outputs (gr.update): Update for the step outputs component
+                - error_msg (gr.update): Update for the error message component (hidden if no errors)
+        """
+
         try:
             pipeline_state = validation.validate_tossup_workflow(state_dict)
+            workflow = pipeline_state.workflow
             # Validate inputs
             question_id = int(question_id - 1)
             if not self.ds or question_id < 0 or question_id >= len(self.ds):
                 raise gr.Error("Invalid question ID or dataset not loaded")
             example = self.ds[question_id]
+            outputs = self.get_agent_outputs(example, pipeline_state, early_stop)
 
             # Process results and prepare visualization data
+            confidence_threshold = workflow.buzzer.confidence_threshold
+            prob_threshold = workflow.buzzer.prob_threshold
+            tokens_html, plot_data, output_state = initialize_eval_interface(
+                example, outputs, confidence_threshold, prob_threshold
+            )
             df = process_tossup_results(outputs)
+            tokens = example["question"].split()
+            step_outputs = {}
+            for output in outputs:
+                pos = output["token_position"]
+                token = tokens[pos - 1]
+                key = f"{pos}:{token}"
+                step_outputs[key] = {k: v for k, v in output["step_outputs"].items() if k not in workflow.inputs}
+                if output["logprob"] is not None:
+                    step_outputs[key]["logprob"] = output["logprob"]
+                    step_outputs[key]["prob"] = float(np.exp(output["logprob"]))
+
             return (
                 tokens_html,
                 gr.update(value=output_state),
+                gr.update(value=plot_data, label=f"Buzz Confidence on Question {question_id + 1}", show_label=True),
                 gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}", visible=True),
                 gr.update(value=step_outputs, label=f"Step Outputs for Question {question_id + 1}", visible=True),
                 gr.update(visible=False),
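The step-outputs JSON view is keyed by "<token position>:<token>" so each entry lines up with a buzz point in the rendered question. The keying logic in isolation; the output dict here is a hypothetical agent result:

import numpy as np

tokens = "This city on the Seine is home to the Louvre museum".split()
output = {
    "token_position": 11,  # 1-based position of the buzz token
    "logprob": -0.094,
    "step_outputs": {"guesser": {"answer": "Louvre", "confidence": 0.91}},
}

key = f"{output['token_position']}:{tokens[output['token_position'] - 1]}"
entry = dict(output["step_outputs"])
if output["logprob"] is not None:
    entry["logprob"] = output["logprob"]
    entry["prob"] = float(np.exp(output["logprob"]))
print({key: entry})  # {'11:museum': {...}}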
@@ -274,32 +325,17 @@ class TossupInterface:
             if not self.ds or not self.ds.num_rows:
                 return "No dataset loaded", None, None
             pipeline_state = validation.validate_tossup_workflow(state_dict)
+            model_outputs = []
             for example in progress.tqdm(self.ds, desc="Evaluating tossup questions"):
+                run_outputs = self.get_agent_outputs(example, pipeline_state, early_stop=True)
+                model_outputs.append(run_outputs)
+            eval_df = prepare_tossup_results_df(self.ds["run_indices"], model_outputs)
+            plot_data = create_tossup_eval_dashboard(self.ds["run_indices"], eval_df)
+            output_df = create_tossup_eval_table(eval_df)
             return (
+                gr.update(value=plot_data, label="Buzz Positions on Sample Set", show_label=False),
+                gr.update(value=output_df, label="(Mean) Metrics on Sample Set", visible=True),
+                gr.update(visible=False),
                 gr.update(visible=False),
             )
         except Exception as e:
@@ -309,7 +345,8 @@ class TossupInterface:
             return (
                 gr.skip(),
                 gr.update(visible=False),
+                gr.update(visible=False),
+                gr.update(visible=True, value=styled_error(f"Error: {str(e)}")),
             )
 
     def submit_model(
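The old inline evaluation removed earlier in this diff reduced each question to its final buzz and reported three aggregates; the new `prepare_tossup_results_df` / `create_tossup_eval_table` pair replaces it. The same numbers can still be computed directly; a sketch over hypothetical final outputs, not the helpers' actual implementation:

import numpy as np
import pandas as pd

# One final (buzz) output per question: where the model buzzed, and whether
# the buzz was correct.
final_outputs = [
    {"token_position": 45, "score": 1},
    {"token_position": 62, "score": 0},
    {"token_position": 38, "score": 1},
]

token_positions = [o["token_position"] for o in final_outputs]
correct_buzzes = sum(o["score"] for o in final_outputs)

metrics = pd.DataFrame([{
    "Avg Buzz Position": f"{np.mean(token_positions):.2f}",   # 48.33
    "Buzz Accuracy": f"{correct_buzzes / len(final_outputs):.2%}",  # 66.67%
    "Total Score": f"{correct_buzzes}/{len(final_outputs)}",  # 2/3
}])
print(metrics.to_string(index=False))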
@@ -327,6 +364,12 @@ class TossupInterface:
             logger.exception(f"Error submitting model: {e.args}")
             return styled_error(f"Error: {str(e)}")
 
+    @property
+    def pipeline_state(self):
+        return self.pipeline_interface.pipeline_state
+
+    # ------------------------------------- Event Listeners -----------------------------------------------------------
+
     def _setup_event_listeners(self):
         gr.on(
             triggers=[self.app.load, self.qid_selector.change],
@@ -341,20 +384,27 @@ class TossupInterface:
             outputs=[self.pipeline_selector],
         )
 
         pipeline_change = self.pipeline_interface.pipeline_change
+
+        gr.on(
+            triggers=[self.app.load],
+            fn=self.load_presaved_pipeline_state,
+            inputs=[self.browser_state, pipeline_change],
+            outputs=[self.browser_state, pipeline_change, self.pipeline_state, self.output_state],
+        )
+
         self.load_btn.click(
             fn=self.load_pipeline,
             inputs=[self.pipeline_selector, pipeline_change],
+            outputs=[self.pipeline_selector, pipeline_change, self.pipeline_state, self.import_error_display],
         )
+        self.pipeline_interface.add_triggers_for_pipeline_export([self.pipeline_state.change], self.pipeline_state)
 
         self.run_btn.click(
             self.single_run,
             inputs=[
                 self.qid_selector,
+                self.pipeline_state,
                 self.early_stop_checkbox,
             ],
             outputs=[
@@ -369,8 +419,8 @@ class TossupInterface:
 
         self.eval_btn.click(
             fn=self.evaluate,
+            inputs=[self.pipeline_state],
+            outputs=[self.confidence_plot, self.results_table, self.model_outputs_display, self.error_display],
         )
 
         self.submit_btn.click(
@@ -378,7 +428,7 @@ class TossupInterface:
             inputs=[
                 self.model_name_input,
                 self.description_input,
+                self.pipeline_state,
             ],
             outputs=[self.submit_status],
         )
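The new `gr.on(triggers=[self.app.load], ...)` block above is what restores a user's in-progress pipeline after a page reload, for example around login. A minimal sketch of the same restore-on-load wiring; component names are placeholders, and it assumes a Gradio version that ships `gr.BrowserState`, which persists across reloads where a plain `gr.State` does not:

import gradio as gr

with gr.Blocks() as app:
    browser_state = gr.BrowserState({"tossup": {}, "bonus": {}})
    status = gr.Textbox(label="status")

    def restore(saved: dict):
        # Runs once per page load with whatever the browser persisted.
        return f"restored keys: {sorted(saved)}"

    gr.on(triggers=[app.load], fn=restore, inputs=[browser_state], outputs=[status])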
src/components/quizbowl/utils.py
CHANGED
@@ -14,7 +14,6 @@ def evaluate_prediction(prediction: str, clean_answers: list[str] | str) -> int:
     for answer in clean_answers:
         answer = answer.strip().lower()
         if answer and answer in pred:
-            print(f"Found {answer} in {pred}")
             return 1
     return 0
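With the stray debug `print` dropped, the helper is a plain normalized-substring match. A self-contained version for reference; the diff only shows the loop, so the handling of `pred` and of a bare-string `clean_answers` before line 14 is an assumption:

def evaluate_prediction(prediction: str, clean_answers: list[str] | str) -> int:
    # Assumed preamble (not shown in the diff): normalize the inputs.
    if isinstance(clean_answers, str):
        clean_answers = [clean_answers]
    pred = prediction.strip().lower()
    # The loop below is the part shown in the diff.
    for answer in clean_answers:
        answer = answer.strip().lower()
        if answer and answer in pred:
            return 1
    return 0

print(evaluate_prediction("The answer is Paris.", ["paris", "city of light"]))  # 1
print(evaluate_prediction("Rome", ["paris"]))  # 0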
src/components/typed_dicts.py
CHANGED
@@ -1,5 +1,7 @@
 from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
 
+from workflows.structs import TossupWorkflow, Workflow
+
 
 # TypedDicts for workflows/structs.py
 class InputFieldDict(TypedDict):
@@ -62,3 +64,17 @@ class PipelineStateDict(TypedDict):
 
 class TossupPipelineStateDict(PipelineStateDict):
     workflow: TossupWorkflowDict
+
+
+class PipelineInterfaceDefaults(TypedDict):
+    init_workflow: Workflow
+    simple_workflow: bool
+    model: str
+    temperature: float
+    max_temperature: float
+
+
+class TossupInterfaceDefaults(PipelineInterfaceDefaults):
+    init_workflow: TossupWorkflow
+    confidence_threshold: float
+    early_stop: bool
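These TypedDicts document the `defaults` payload that app.py passes into each interface, so a type checker can flag missing or misspelled keys. A hypothetical construction; the field values are illustrative only:

from components.typed_dicts import TossupInterfaceDefaults
from workflows import factory  # provides the default tossup workflow, per the diff above

defaults: TossupInterfaceDefaults = {
    "init_workflow": factory.create_empty_tossup_workflow(),
    "simple_workflow": False,
    "model": "gpt-4o-mini",
    "temperature": 0.2,
    "max_temperature": 5.0,
    "confidence_threshold": 0.85,
    "early_stop": True,
}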