Maharshi Gor commited on
Commit
f9589f4
·
1 Parent(s): cbf7344

Refactored validation code in bonus/tossup interface.

Browse files
src/components/quizbowl/bonus.py CHANGED
@@ -6,7 +6,7 @@ import pandas as pd
6
  from datasets import Dataset
7
  from loguru import logger
8
 
9
- from app_configs import UNSELECTED_PIPELINE_NAME
10
  from components import commons
11
  from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState, PipelineUIState
12
  from components.typed_dicts import PipelineStateDict
@@ -15,13 +15,8 @@ from submission import submit
15
  from workflows.qb_agents import QuizBowlBonusAgent
16
  from workflows.structs import ModelStep, Workflow
17
 
18
- from . import populate
19
- from .plotting import (
20
- create_bonus_confidence_plot,
21
- create_bonus_html,
22
- create_scatter_pyplot,
23
- update_tossup_plot,
24
- )
25
  from .utils import evaluate_prediction
26
 
27
 
@@ -58,58 +53,6 @@ def initialize_eval_interface(example: dict, model_outputs: list[dict]):
58
  return f"<div>Error initializing interface: {str(e)}</div>", pd.DataFrame(), "{}"
59
 
60
 
61
- def validate_workflow(workflow: Workflow):
62
- """Validate that a workflow is properly configured for the bonus task."""
63
- if not workflow.steps:
64
- raise ValueError("Workflow must have at least one step")
65
-
66
- # Ensure all steps are properly configured
67
- for step_id, step in workflow.steps.items():
68
- validate_model_step(step)
69
-
70
- # Check that the workflow has the correct structure
71
- input_vars = set(workflow.inputs)
72
- if "leadin" not in input_vars or "part" not in input_vars:
73
- raise ValueError("Workflow must have 'leadin' and 'part' as inputs")
74
-
75
- output_vars = set(workflow.outputs)
76
- if not all(var in output_vars for var in ["answer", "confidence", "explanation"]):
77
- raise ValueError("Workflow must produce 'answer', 'confidence', and 'explanation' as outputs")
78
-
79
-
80
- def validate_model_step(model_step: ModelStep):
81
- """Validate that a model step is properly configured for the bonus task."""
82
- # Check required fields
83
- if not model_step.model or not model_step.provider:
84
- raise ValueError("Model step must have both model and provider specified")
85
-
86
- if model_step.call_type != "llm":
87
- raise ValueError("Model step must have call_type 'llm'")
88
-
89
- # Validate temperature for LLM steps
90
- if model_step.temperature is None:
91
- raise ValueError("Temperature must be specified for LLM model steps")
92
-
93
- if not (0.0 <= model_step.temperature <= 1.0):
94
- raise ValueError(f"Temperature must be between 0.0 and 1.0, got {model_step.temperature}")
95
-
96
- # Validate input fields
97
- input_field_names = {field.name for field in model_step.input_fields}
98
- if "leadin" not in input_field_names or "part" not in input_field_names:
99
- raise ValueError("Model step must have 'leadin' and 'part' input fields")
100
-
101
- # Validate output fields
102
- output_field_names = {field.name for field in model_step.output_fields}
103
- required_outputs = {"answer", "confidence", "explanation"}
104
- if not all(out in output_field_names for out in required_outputs):
105
- raise ValueError("Model step must have all required output fields: answer, confidence, explanation")
106
-
107
- # Validate confidence output field is of type float
108
- for field in model_step.output_fields:
109
- if field.name == "confidence" and field.type != "float":
110
- raise ValueError("The 'confidence' output field must be of type 'float'")
111
-
112
-
113
  class BonusInterface:
114
  """Gradio interface for the Bonus mode."""
115
 
@@ -128,11 +71,12 @@ class BonusInterface:
128
  with gr.Row(elem_classes="bonus-header-row form-inline"):
129
  self.pipeline_selector = commons.get_pipeline_selector([])
130
  self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
 
131
  self.pipeline_interface = PipelineInterface(
132
  self.app,
133
  workflow,
134
- simple=simple,
135
  model_options=list(self.model_options.keys()),
 
136
  )
137
 
138
  def _render_qb_interface(self):
@@ -177,6 +121,18 @@ class BonusInterface:
177
 
178
  self._setup_event_listeners()
179
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  def get_new_question_html(self, question_id: int):
181
  """Get the HTML for a new question."""
182
  if question_id is None:
@@ -237,10 +193,10 @@ class BonusInterface:
237
  ) -> tuple[str, Any, Any]:
238
  """Run the agent in bonus mode."""
239
  try:
240
- pipeline_state = PipelineState(**state_dict)
241
  question_id = int(question_id - 1)
242
  if not self.ds or question_id < 0 or question_id >= len(self.ds):
243
- return "Invalid question ID or dataset not loaded", None, None
244
 
245
  example = self.ds[question_id]
246
  outputs = self.get_model_outputs(example, pipeline_state)
@@ -272,7 +228,7 @@ class BonusInterface:
272
  def evaluate(self, state_dict: PipelineStateDict, progress: gr.Progress = gr.Progress()):
273
  """Evaluate the bonus questions."""
274
  try:
275
- pipeline_state = PipelineState(**state_dict)
276
  # Validate inputs
277
  if not self.ds or not self.ds.num_rows:
278
  return "No dataset loaded", None, None
@@ -345,15 +301,11 @@ class BonusInterface:
345
  self.load_btn.click(
346
  fn=self.load_pipeline,
347
  inputs=[self.pipeline_selector, pipeline_change],
348
- outputs=[self.pipeline_selector, pipeline_state, pipeline_change, self.error_display],
349
  )
350
  self.pipeline_interface.add_triggers_for_pipeline_export([pipeline_state.change], pipeline_state)
351
 
352
  self.run_btn.click(
353
- self.pipeline_interface.validate_workflow,
354
- inputs=[self.pipeline_interface.pipeline_state],
355
- outputs=[],
356
- ).success(
357
  self.single_run,
358
  inputs=[
359
  self.qid_selector,
 
6
  from datasets import Dataset
7
  from loguru import logger
8
 
9
+ from app_configs import CONFIGS, UNSELECTED_PIPELINE_NAME
10
  from components import commons
11
  from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState, PipelineUIState
12
  from components.typed_dicts import PipelineStateDict
 
15
  from workflows.qb_agents import QuizBowlBonusAgent
16
  from workflows.structs import ModelStep, Workflow
17
 
18
+ from . import populate, validation
19
+ from .plotting import create_bonus_confidence_plot, create_bonus_html
 
 
 
 
 
20
  from .utils import evaluate_prediction
21
 
22
 
 
53
  return f"<div>Error initializing interface: {str(e)}</div>", pd.DataFrame(), "{}"
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  class BonusInterface:
57
  """Gradio interface for the Bonus mode."""
58
 
 
71
  with gr.Row(elem_classes="bonus-header-row form-inline"):
72
  self.pipeline_selector = commons.get_pipeline_selector([])
73
  self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
74
+ self.import_error_display = gr.HTML(label="Import Error", elem_id="import-error-display", visible=False)
75
  self.pipeline_interface = PipelineInterface(
76
  self.app,
77
  workflow,
 
78
  model_options=list(self.model_options.keys()),
79
+ config=self.defaults,
80
  )
81
 
82
  def _render_qb_interface(self):
 
121
 
122
  self._setup_event_listeners()
123
 
124
+ def validate_workflow(self, state_dict: PipelineStateDict):
125
+ """Validate the workflow."""
126
+ try:
127
+ pipeline_state = PipelineState(**state_dict)
128
+ validation.validate_workflow(
129
+ pipeline_state.workflow,
130
+ required_input_vars=CONFIGS["bonus"]["required_input_vars"],
131
+ required_output_vars=CONFIGS["bonus"]["required_output_vars"],
132
+ )
133
+ except Exception as e:
134
+ raise gr.Error(f"Error validating workflow: {str(e)}")
135
+
136
  def get_new_question_html(self, question_id: int):
137
  """Get the HTML for a new question."""
138
  if question_id is None:
 
193
  ) -> tuple[str, Any, Any]:
194
  """Run the agent in bonus mode."""
195
  try:
196
+ pipeline_state = validation.validate_bonus_workflow(state_dict)
197
  question_id = int(question_id - 1)
198
  if not self.ds or question_id < 0 or question_id >= len(self.ds):
199
+ raise gr.Error("Invalid question ID or dataset not loaded")
200
 
201
  example = self.ds[question_id]
202
  outputs = self.get_model_outputs(example, pipeline_state)
 
228
  def evaluate(self, state_dict: PipelineStateDict, progress: gr.Progress = gr.Progress()):
229
  """Evaluate the bonus questions."""
230
  try:
231
+ pipeline_state = validation.validate_bonus_workflow(state_dict)
232
  # Validate inputs
233
  if not self.ds or not self.ds.num_rows:
234
  return "No dataset loaded", None, None
 
301
  self.load_btn.click(
302
  fn=self.load_pipeline,
303
  inputs=[self.pipeline_selector, pipeline_change],
304
+ outputs=[self.pipeline_selector, pipeline_state, pipeline_change, self.import_error_display],
305
  )
306
  self.pipeline_interface.add_triggers_for_pipeline_export([pipeline_state.change], pipeline_state)
307
 
308
  self.run_btn.click(
 
 
 
 
309
  self.single_run,
310
  inputs=[
311
  self.qid_selector,
src/components/quizbowl/tossup.py CHANGED
@@ -7,7 +7,7 @@ import pandas as pd
7
  from datasets import Dataset
8
  from loguru import logger
9
 
10
- from app_configs import UNSELECTED_PIPELINE_NAME
11
  from components import commons
12
  from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState, PipelineUIState
13
  from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
@@ -17,7 +17,7 @@ from submission import submit
17
  from workflows.qb_agents import QuizBowlTossupAgent, TossupResult
18
  from workflows.structs import ModelStep, TossupWorkflow
19
 
20
- from . import populate
21
  from .plotting import (
22
  create_scatter_pyplot,
23
  create_tossup_confidence_pyplot,
@@ -101,77 +101,6 @@ def process_tossup_results(results: list[dict], top_k_mode: bool = False) -> pd.
101
  )
102
 
103
 
104
- def validate_workflow(workflow: TossupWorkflow):
105
- """
106
- Validate that a workflow is properly configured for the tossup task.
107
-
108
- Args:
109
- workflow (TossupWorkflow): The workflow to validate
110
-
111
- Raises:
112
- ValueError: If the workflow is not properly configured
113
- """
114
- if not workflow.steps:
115
- raise ValueError("Workflow must have at least one step")
116
-
117
- # Ensure all steps are properly configured
118
- for step_id, step in workflow.steps.items():
119
- validate_model_step(step)
120
-
121
- # Check that the workflow has the correct structure
122
- input_vars = set(workflow.inputs)
123
- if "question" not in input_vars:
124
- raise ValueError("Workflow must have 'question' as an input")
125
-
126
- output_vars = set(workflow.outputs)
127
- if not any("answer" in out_var for out_var in output_vars):
128
- raise ValueError("Workflow must produce an 'answer' as output")
129
- if not any("confidence" in out_var for out_var in output_vars):
130
- raise ValueError("Workflow must produce a 'confidence' score as output")
131
-
132
-
133
- def validate_model_step(model_step: ModelStep):
134
- """
135
- Validate that a model step is properly configured for the tossup task.
136
-
137
- Args:
138
- model_step (ModelStep): The model step to validate
139
-
140
- Raises:
141
- ValueError: If the model step is not properly configured
142
- """
143
- # Check required fields
144
- if not model_step.model or not model_step.provider:
145
- raise ValueError("Model step must have both model and provider specified")
146
-
147
- if model_step.call_type != "llm":
148
- raise ValueError("Model step must have call_type 'llm'")
149
-
150
- # Validate temperature for LLM steps
151
- if model_step.temperature is None:
152
- raise ValueError("Temperature must be specified for LLM model steps")
153
-
154
- if not (0.0 <= model_step.temperature <= 1.0):
155
- raise ValueError(f"Temperature must be between 0.0 and 1.0, got {model_step.temperature}")
156
-
157
- # Validate input fields
158
- input_field_names = {field.name for field in model_step.input_fields}
159
- if "question" not in input_field_names:
160
- raise ValueError("Model step must have a 'question' input field")
161
-
162
- # Validate output fields
163
- output_field_names = {field.name for field in model_step.output_fields}
164
- if "answer" not in output_field_names:
165
- raise ValueError("Model step must have an 'answer' output field")
166
- if "confidence" not in output_field_names:
167
- raise ValueError("Model step must have a 'confidence' output field")
168
-
169
- # Validate confidence output field is of type float
170
- for field in model_step.output_fields:
171
- if field.name == "confidence" and field.type != "float":
172
- raise ValueError("The 'confidence' output field must be of type 'float'")
173
-
174
-
175
  class TossupInterface:
176
  """Gradio interface for the Tossup mode."""
177
 
@@ -190,12 +119,12 @@ class TossupInterface:
190
  with gr.Row(elem_classes="bonus-header-row form-inline"):
191
  self.pipeline_selector = commons.get_pipeline_selector([])
192
  self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
 
193
  self.pipeline_interface = TossupPipelineInterface(
194
  self.app,
195
  workflow,
196
- simple=simple,
197
  model_options=list(self.model_options.keys()),
198
- defaults=self.defaults,
199
  )
200
 
201
  def _render_qb_interface(self):
@@ -251,14 +180,6 @@ class TossupInterface:
251
 
252
  self._setup_event_listeners()
253
 
254
- def validate_workflow(self, state_dict: TossupPipelineStateDict):
255
- """Validate the workflow."""
256
- try:
257
- pipeline_state = TossupPipelineState(**state_dict)
258
- validate_workflow(pipeline_state.workflow)
259
- except Exception as e:
260
- raise gr.Error(f"Error validating workflow: {str(e)}")
261
-
262
  def get_new_question_html(self, question_id: int) -> str:
263
  """Get the HTML for a new question."""
264
  if question_id is None:
@@ -313,12 +234,12 @@ class TossupInterface:
313
  ) -> tuple[str, Any, Any]:
314
  """Run the agent in tossup mode with a system prompt."""
315
  try:
 
316
  # Validate inputs
317
  question_id = int(question_id - 1)
318
  if not self.ds or question_id < 0 or question_id >= len(self.ds):
319
- return "Invalid question ID or dataset not loaded", None, None
320
  example = self.ds[question_id]
321
- pipeline_state = TossupPipelineState(**state_dict)
322
  outputs = self.get_model_outputs(example, pipeline_state, early_stop)
323
 
324
  # Process results and prepare visualization data
@@ -352,7 +273,7 @@ class TossupInterface:
352
  # Validate inputs
353
  if not self.ds or not self.ds.num_rows:
354
  return "No dataset loaded", None, None
355
- pipeline_state = TossupPipelineState(**state_dict)
356
  buzz_counts = 0
357
  correct_buzzes = 0
358
  token_positions = []
@@ -397,10 +318,14 @@ class TossupInterface:
397
  description: str,
398
  state_dict: TossupPipelineStateDict,
399
  profile: gr.OAuthProfile = None,
400
- ):
401
  """Submit the model output."""
402
- pipeline_state = TossupPipelineState(**state_dict)
403
- return submit.submit_model(model_name, description, pipeline_state.workflow, "tossup", profile)
 
 
 
 
404
 
405
  def _setup_event_listeners(self):
406
  gr.on(
@@ -421,15 +346,11 @@ class TossupInterface:
421
  self.load_btn.click(
422
  fn=self.load_pipeline,
423
  inputs=[self.pipeline_selector, pipeline_change],
424
- outputs=[self.pipeline_selector, pipeline_state, pipeline_change, self.error_display],
425
  )
426
  self.pipeline_interface.add_triggers_for_pipeline_export([pipeline_state.change], pipeline_state)
427
 
428
  self.run_btn.click(
429
- self.pipeline_interface.validate_workflow,
430
- inputs=[self.pipeline_interface.pipeline_state],
431
- outputs=[],
432
- ).success(
433
  self.single_run,
434
  inputs=[
435
  self.qid_selector,
 
7
  from datasets import Dataset
8
  from loguru import logger
9
 
10
+ from app_configs import CONFIGS, UNSELECTED_PIPELINE_NAME
11
  from components import commons
12
  from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState, PipelineUIState
13
  from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
 
17
  from workflows.qb_agents import QuizBowlTossupAgent, TossupResult
18
  from workflows.structs import ModelStep, TossupWorkflow
19
 
20
+ from . import populate, validation
21
  from .plotting import (
22
  create_scatter_pyplot,
23
  create_tossup_confidence_pyplot,
 
101
  )
102
 
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  class TossupInterface:
105
  """Gradio interface for the Tossup mode."""
106
 
 
119
  with gr.Row(elem_classes="bonus-header-row form-inline"):
120
  self.pipeline_selector = commons.get_pipeline_selector([])
121
  self.load_btn = gr.Button("⬇️ Import Pipeline", variant="secondary")
122
+ self.import_error_display = gr.HTML(label="Import Error", elem_id="import-error-display", visible=False)
123
  self.pipeline_interface = TossupPipelineInterface(
124
  self.app,
125
  workflow,
 
126
  model_options=list(self.model_options.keys()),
127
+ config=self.defaults,
128
  )
129
 
130
  def _render_qb_interface(self):
 
180
 
181
  self._setup_event_listeners()
182
 
 
 
 
 
 
 
 
 
183
  def get_new_question_html(self, question_id: int) -> str:
184
  """Get the HTML for a new question."""
185
  if question_id is None:
 
234
  ) -> tuple[str, Any, Any]:
235
  """Run the agent in tossup mode with a system prompt."""
236
  try:
237
+ pipeline_state = validation.validate_tossup_workflow(state_dict)
238
  # Validate inputs
239
  question_id = int(question_id - 1)
240
  if not self.ds or question_id < 0 or question_id >= len(self.ds):
241
+ raise gr.Error("Invalid question ID or dataset not loaded")
242
  example = self.ds[question_id]
 
243
  outputs = self.get_model_outputs(example, pipeline_state, early_stop)
244
 
245
  # Process results and prepare visualization data
 
273
  # Validate inputs
274
  if not self.ds or not self.ds.num_rows:
275
  return "No dataset loaded", None, None
276
+ pipeline_state = validation.validate_tossup_workflow(state_dict)
277
  buzz_counts = 0
278
  correct_buzzes = 0
279
  token_positions = []
 
318
  description: str,
319
  state_dict: TossupPipelineStateDict,
320
  profile: gr.OAuthProfile = None,
321
+ ) -> str:
322
  """Submit the model output."""
323
+ try:
324
+ pipeline_state = validation.validate_tossup_workflow(state_dict)
325
+ return submit.submit_model(model_name, description, pipeline_state.workflow, "tossup", profile)
326
+ except Exception as e:
327
+ logger.exception(f"Error submitting model: {e.args}")
328
+ return styled_error(f"Error: {str(e)}")
329
 
330
  def _setup_event_listeners(self):
331
  gr.on(
 
346
  self.load_btn.click(
347
  fn=self.load_pipeline,
348
  inputs=[self.pipeline_selector, pipeline_change],
349
+ outputs=[self.pipeline_selector, pipeline_state, pipeline_change, self.import_error_display],
350
  )
351
  self.pipeline_interface.add_triggers_for_pipeline_export([pipeline_state.change], pipeline_state)
352
 
353
  self.run_btn.click(
 
 
 
 
354
  self.single_run,
355
  inputs=[
356
  self.qid_selector,
src/components/quizbowl/validation.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app_configs import CONFIGS
2
+ from components.structs import PipelineState, TossupPipelineState
3
+ from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
4
+ from workflows.structs import TossupWorkflow, Workflow
5
+ from workflows.validators import WorkflowValidator
6
+
7
+
8
+ def validate_workflow(
9
+ workflow: TossupWorkflow | Workflow, required_input_vars: list[str], required_output_vars: list[str]
10
+ ):
11
+ """
12
+ Validate that a workflow is properly configured for the tossup task.
13
+
14
+ Args:
15
+ workflow (TossupWorkflow): The workflow to validate
16
+
17
+ Raises:
18
+ ValueError: If the workflow is not properly configured
19
+ """
20
+ if not workflow.steps:
21
+ raise ValueError("Workflow must have at least one step")
22
+
23
+ # Check that the workflow has the correct structure
24
+ input_vars = set(workflow.inputs)
25
+ for req_var in required_input_vars:
26
+ if req_var not in input_vars:
27
+ raise ValueError(f"Workflow must have '{req_var}' as an input")
28
+
29
+ output_vars = set(workflow.outputs)
30
+ for req_var in required_output_vars:
31
+ if req_var not in output_vars:
32
+ raise ValueError(f"Workflow must produce '{req_var}' as an output")
33
+
34
+ # Ensure all steps are properly configured
35
+ WorkflowValidator().validate(workflow)
36
+
37
+
38
+ def validate_tossup_workflow(pipeline_state_dict: TossupPipelineStateDict) -> TossupPipelineState:
39
+ pipeline_state = TossupPipelineState(**pipeline_state_dict)
40
+ validate_workflow(
41
+ pipeline_state.workflow,
42
+ CONFIGS["tossup"]["required_input_vars"],
43
+ CONFIGS["tossup"]["required_output_vars"],
44
+ )
45
+ return pipeline_state
46
+
47
+
48
+ def validate_bonus_workflow(pipeline_state_dict: PipelineStateDict):
49
+ pipeline_state = PipelineState(**pipeline_state_dict)
50
+ validate_workflow(
51
+ pipeline_state.workflow,
52
+ CONFIGS["bonus"]["required_input_vars"],
53
+ CONFIGS["bonus"]["required_output_vars"],
54
+ )
55
+ return pipeline_state