Maharshi Gor commited on
Commit
193db9d
·
1 Parent(s): a562808

First Working commit

Browse files
.gitignore ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Purpose: Ignore certain files and directories in git
2
+ *.pyc
3
+ *.pyi
4
+ *.pyo
5
+ *.pyd
6
+ *.pyw
7
+ *.pyz
8
+ *.pywz
9
+ *.pywz
10
+
11
+ auto_evals/
12
+ venv/
13
+ __pycache__/
14
+ .env
15
+ .ipynb_checkpoints
16
+ *ipynb
17
+ .vscode/
18
+
19
+ eval-queue/
20
+ eval-results/
21
+ eval-queue-bk/
22
+ eval-results-bk/
23
+ logs/
Makefile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: style format
2
+
3
+
4
+ style:
5
+ python -m black --line-length 119 .
6
+ python -m isort .
7
+ ruff check --fix .
8
+
9
+
10
+ quality:
11
+ python -m black --check --line-length 119 .
12
+ python -m isort --check-only .
13
+ ruff check .
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ import gradio as gr
3
+
4
+ from components.quizbowl.bonus import BonusInterface
5
+ from components.quizbowl.tossup import TossupInterface
6
+ from display.custom_css import css_pipeline, css_tossup
7
+
8
+ # Constants
9
+ from envs import AVAILABLE_MODELS, DEFAULT_SELECTIONS, PLAYGROUND_DATASET_NAMES, THEME
10
+ from workflows import factory
11
+
12
+ js_preamble = """
13
+ <link href="https://fonts.cdnfonts.com/css/roboto-mono" rel="stylesheet">
14
+
15
+ <script>
16
+ const gradioApp = document.getElementsByTagName('gradio-app')[0];
17
+ console.log("Gradio app:", gradioApp);
18
+ console.log(gradioApp.querySelectorAll('.token'));
19
+ console.log(document.querySelectorAll('.token'));
20
+
21
+ // Function to trigger Python callback
22
+ const setHiddenIndex = (index) => {
23
+ console.log("Setting hidden index to:", index);
24
+ const hiddenIndex = gradioApp.querySelector("#hidden-index textarea");
25
+ if (hiddenIndex) {
26
+ hiddenIndex.value = index;
27
+ let event = new Event("input", { bubbles: true});
28
+ Object.defineProperty(event, "target", { value: hiddenIndex});
29
+ hiddenIndex.dispatchEvent(event);
30
+ }
31
+ };
32
+
33
+ // Add event listeners to all tokens
34
+ function setupTokenListeners() {
35
+ const tokens = gradioApp.querySelectorAll('.token');
36
+ console.log("Tokens:", tokens);
37
+ tokens.forEach(token => {
38
+ token.addEventListener('mouseover', function() {
39
+ const index = parseInt(this.getAttribute('data-index'));
40
+ console.log("Mouseover token index:", index);
41
+
42
+ // Reset all tokens
43
+ gradioApp.querySelectorAll('.token').forEach(el => {
44
+ el.classList.remove('highlighted');
45
+ });
46
+
47
+ // Highlight this token
48
+ this.classList.add('highlighted');
49
+
50
+ // Update the hidden index to trigger the Python callback
51
+ setHiddenIndex(index);
52
+ });
53
+ });
54
+ }
55
+ console.log("Preamble complete");
56
+
57
+ document.addEventListener("DOMContentLoaded", function() {
58
+ // Setup initial listeners
59
+ console.log("DOM fully loaded and parsed");
60
+ setupTokenListeners();
61
+
62
+ // Setup a mutation observer to handle dynamically added tokens
63
+ const observer = new MutationObserver(function(mutations) {
64
+ mutations.forEach(function(mutation) {
65
+ if (mutation.addedNodes.length) {
66
+ setupTokenListeners();
67
+ }
68
+ });
69
+ });
70
+
71
+ // Start observing the token container for changes
72
+ const tokenContainer = gradioApp.querySelector('.token-container');
73
+ console.log("Token container:", tokenContainer);
74
+ if (tokenContainer) {
75
+ observer.observe(tokenContainer.parentNode, { childList: true, subtree: true });
76
+ }
77
+ console.log("Listener setup complete");
78
+ });
79
+ </script>
80
+ """
81
+
82
+
83
def load_dataset(mode: str):
    """Load the playground evaluation dataset for a given question mode.

    Args:
        mode: Either "tossup" or "bonus"; selects which playground dataset
            name to load from ``PLAYGROUND_DATASET_NAMES``.

    Returns:
        The filtered ``eval`` split, restricted to a small fixed subset so the
        playground loads quickly: qid part 3 == "1" and question number
        (qid part 4) <= 10.

    Raises:
        ValueError: If ``mode`` is neither "tossup" nor "bonus".
    """
    # The two original branches were identical except for the dataset name;
    # consolidated here. The ValueError contract for unknown modes is kept.
    if mode not in ("tossup", "bonus"):
        raise ValueError(f"Invalid mode: {mode}")
    ds = datasets.load_dataset(PLAYGROUND_DATASET_NAMES[mode], split="eval")
    # qid format assumed to be dash-separated with packet at index 2 and
    # question number at index 3 — TODO confirm against dataset schema.
    ds = ds.filter(lambda x: x["qid"].split("-")[2] == "1" and int(x["qid"].split("-")[3]) <= 10)
    return ds
94
+
95
+
96
def main():
    """Assemble the Quizbowl playground Gradio app and launch it."""
    # Load both playground datasets up front, before any UI is built.
    playground_data = {mode: load_dataset(mode) for mode in ("tossup", "bonus")}

    app = gr.Blocks(
        css=css_pipeline + css_tossup,
        head=js_preamble,
        theme=THEME,
        title="Quizbowl Bot",
    )
    with app:
        with gr.Tabs():
            with gr.Tab("Tossup Agents"):
                # Tossup mode starts from the full (non-simple) workflow editor.
                tossup_defaults = DEFAULT_SELECTIONS["tossup"] | {
                    "init_workflow": factory.create_quizbowl_simple_workflow(),
                    "simple_workflow": False,
                }
                TossupInterface(app, playground_data["tossup"], AVAILABLE_MODELS, tossup_defaults)
            with gr.Tab("Bonus Round Agents"):
                # Bonus mode starts in the simplified single-step editor.
                bonus_defaults = DEFAULT_SELECTIONS["bonus"] | {
                    "init_workflow": factory.create_quizbowl_bonus_simple_workflow(),
                    "simple_workflow": True,
                }
                BonusInterface(app, playground_data["bonus"], AVAILABLE_MODELS, bonus_defaults)

    app.queue(api_open=True).launch()
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.ruff]
2
+ # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
3
+ select = ["E", "F"]
4
+ ignore = ["E501"] # line too long (black is taking care of this)
5
+ line-length = 119
6
+ fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
7
+
8
+ [tool.isort]
9
+ profile = "black"
10
+ line_length = 119
11
+
12
+ [tool.black]
13
+ line-length = 119
requirements.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ APScheduler
2
+ black
3
+ datasets
4
+ gradio
5
+ modelscope_studio
6
+ gradio[oauth]
7
+ gradio_leaderboard
8
+ gradio_client
9
+ huggingface-hub>=0.18.0
10
+ matplotlib
11
+ numpy<2.0.0
12
+ pandas
13
+ python-dateutil
14
+ tqdm
15
+ transformers
16
+ tokenizers>=0.15.0
17
+ sentencepiece
18
+ litellm
19
+ openai
20
+ anthropic
21
+ cohere
22
+ langchain
23
+ langchain-core
24
+ langchain-community
25
+ langchain-anthropic
26
+ langchain-openai
27
+ langchain-cohere
src/components/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Components package
src/components/model_pipeline/__init__.py ADDED
File without changes
src/components/model_pipeline/model_pipeline.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import gradio as gr
4
+ import yaml
5
+
6
+ from components.model_pipeline.state_manager import (
7
+ ModelStepUIState,
8
+ PipelineState,
9
+ PipelineStateManager,
10
+ PipelineUIState,
11
+ )
12
+ from components.model_step.model_step import ModelStepComponent
13
+ from components.utils import make_state
14
+ from workflows.structs import ModelStep, Workflow
15
+ from workflows.validators import WorkflowValidator
16
+
17
+
18
+ def validate_simple_workflow(workflow: Workflow, required_output_variables: list[str]) -> Workflow:
19
+ """Validate the workflow."""
20
+ step = next(iter(workflow.steps.values()))
21
+ if not step.output_fields:
22
+ raise ValueError("No output fields found in the workflow")
23
+ output_field_names = {output.name for output in step.output_fields}
24
+ if not set(required_output_variables) <= output_field_names:
25
+ missing_vars = required_output_variables - output_field_names
26
+ raise ValueError(f"Missing required output variables: {missing_vars}")
27
+ return workflow
28
+
29
+
30
+ def validate_complex_workflow(workflow: Workflow, required_output_variables: list[str]) -> Workflow:
31
+ """Validate the workflow."""
32
+ print("Validating complex workflow.")
33
+ return workflow
34
+ step = next(iter(workflow.steps.values()))
35
+ if not step.output_fields:
36
+ raise ValueError("No output fields found in the workflow")
37
+ output_field_names = {output.name for output in step.output_fields}
38
+ if not output_field_names <= set(required_output_variables):
39
+ missing_vars = output_field_names - set(required_output_variables)
40
+ raise ValueError(f"Missing required output variables: {missing_vars}")
41
+ return workflow
42
+
43
+
44
def parse_yaml_workflow(yaml_str: str) -> Workflow:
    """Deserialize a YAML document into a ``Workflow`` model.

    Uses ``yaml.safe_load`` (no arbitrary object construction) and hands the
    resulting mapping to the Workflow constructor for validation.
    """
    fields = yaml.safe_load(yaml_str)
    return Workflow(**fields)
48
+
49
+
50
def update_workflow_from_code(yaml_str: str, ui_state: PipelineUIState) -> PipelineState:
    """Rebuild the pipeline state from an edited YAML workflow definition.

    Args:
        yaml_str: YAML text from the workflow code editor.
        ui_state: Present only for callback-signature compatibility; it is
            ignored — the UI state is rebuilt from the parsed workflow so the
            step list stays in sync with the edited YAML. (The original
            silently shadowed this parameter; the intent is made explicit here.)

    Returns:
        A fresh ``PipelineState`` wrapping the parsed workflow and a matching
        newly built UI state.
    """
    workflow = parse_yaml_workflow(yaml_str)
    new_ui_state = PipelineUIState.from_workflow(workflow)
    return PipelineState(workflow=workflow, ui_state=new_ui_state)
55
+
56
+
57
class PipelineInterface:
    """UI for the pipeline.

    Builds and wires the Gradio UI for editing a multi-step workflow:
    per-step editors, reordering/removal controls, the final output-variable
    mapping dropdowns, and a YAML preview/export panel. All mutation goes
    through a ``PipelineStateManager`` (``self.sm``); the pipeline data and
    UI layout live in ``gr.State`` objects so Gradio re-renders on change.
    """

    def __init__(
        self,
        workflow: Workflow,
        ui_state: PipelineUIState | None = None,
        model_options: list[str] = None,
        simple: bool = False,
    ):
        """Create the interface and immediately render it.

        Args:
            workflow: Initial workflow to edit.
            ui_state: Optional pre-built UI state; derived from the workflow
                when omitted.
            model_options: Model names offered in each step's model dropdown.
            simple: When True, renders a single-step editor without
                add/move/remove controls.
        """
        self.model_options = model_options
        self.simple = simple
        if not ui_state:
            ui_state = PipelineUIState.from_workflow(workflow)
        # Gradio state holders: layout state, full pipeline state, and the
        # list of variables currently produced by the workflow.
        self.ui_state = make_state(ui_state)
        self.pipeline_state = make_state(PipelineState(workflow=workflow, ui_state=ui_state))
        self.variables_state = make_state(workflow.get_available_variables())

        self.sm = PipelineStateManager()
        self.input_variables = workflow.inputs
        self.required_output_variables = list(workflow.outputs.keys())

        # UI elements
        self.steps_container = None
        self.components = []

        # Render the pipeline UI
        self.render()

    def _render_step(
        self,
        model_step: ModelStep,
        step_ui_state: ModelStepUIState,
        available_variables: list[str],
        position: int = 0,
    ):
        """Render one step editor plus (in non-simple mode) its move/remove controls.

        Returns the step component alone in simple mode, otherwise a tuple of
        (step component, up button, down button, remove button).
        """
        with gr.Column(elem_classes="step-container"):
            # Create the step component
            step_interface = ModelStepComponent(
                value=model_step,
                ui_state=step_ui_state,
                model_options=self.model_options,
                input_variables=available_variables,
                pipeline_state_manager=self.sm,
            )

            # Propagate edits to the step's data back into the pipeline state
            # (which also refreshes the available-variables list).
            step_interface.on_model_step_change(
                self.sm.update_model_step_state,
                inputs=[self.pipeline_state, step_interface.model_step_state, step_interface.ui_state],
                outputs=[self.pipeline_state, self.ui_state, self.variables_state],
            )

            # UI-only changes (tab switches, expand/collapse) update layout state.
            step_interface.on_ui_change(
                self.sm.update_model_step_ui,
                inputs=[self.pipeline_state, step_interface.ui_state, gr.State(model_step.id)],
                outputs=[self.pipeline_state, self.ui_state],
            )

            if self.simple:
                return step_interface

            # Add step controls below
            with gr.Row(elem_classes="step-controls"):
                up_button = gr.Button("⬆️ Move Up", elem_classes="step-control-btn")
                down_button = gr.Button("⬇️ Move Down", elem_classes="step-control-btn")
                remove_button = gr.Button("🗑️ Remove", elem_classes="step-control-btn")

            buttons = (up_button, down_button, remove_button)
            self._assign_step_controls(buttons, position)

            return (step_interface, *buttons)

    def _assign_step_controls(self, buttons: tuple[gr.Button, gr.Button, gr.Button], position: int):
        """Wire the move-up/move-down/remove buttons for the step at `position`."""
        up_button, down_button, remove_button = buttons
        # Capture the step's position as state so the callbacks know which
        # step they act on after re-renders.
        position = gr.State(position)
        up_button.click(self.sm.move_up, inputs=[self.ui_state, position], outputs=self.ui_state)
        down_button.click(self.sm.move_down, inputs=[self.ui_state, position], outputs=self.ui_state)
        remove_button.click(
            self.sm.remove_step,
            inputs=[self.pipeline_state, position],
            outputs=[self.pipeline_state, self.ui_state, self.variables_state],
        )

    def _render_add_step_button(self, position: int):
        """Render an "Add Step" button; position 0 = header row, -1 = footer row."""
        if position not in {0, -1}:
            raise ValueError("Position must be 0 or -1")
        row_class = "pipeline-header" if position == 0 else "pipeline-footer"
        with gr.Row(elem_classes=row_class):
            add_step_btn = gr.Button("➕ Add Step", elem_classes="add-step-button")
            add_step_btn.click(
                self.sm.add_step,
                inputs=[self.pipeline_state, gr.State(position)],
                outputs=[self.pipeline_state, self.ui_state, self.variables_state],
            )
        return add_step_btn

    def _render_output_fields(self, available_variables: list[str], pipeline_state: PipelineState):
        """Render one dropdown per required output variable, mapping it to a produced variable.

        Returns the dict of dropdowns keyed by output field name.
        """
        dropdowns = {}
        UNSET_VALUE = "Choose variable..."
        # Only step-produced variables are valid targets; raw inputs are excluded.
        variable_options = [UNSET_VALUE] + [v for v in available_variables if v not in self.input_variables]
        with gr.Column(elem_classes="step-accordion"):
            with gr.Row(elem_classes="output-fields-header"):
                gr.Markdown("#### Final output variables mapping:")
            with gr.Row(elem_classes="output-fields-row"):
                for output_field in self.required_output_variables:
                    value = pipeline_state.workflow.outputs[output_field]
                    if not value:
                        value = UNSET_VALUE
                    dropdown = gr.Dropdown(
                        label=output_field,
                        value=value,
                        choices=variable_options,
                        interactive=True,
                        elem_classes="output-field-variable",
                        # show_label=False,
                    )
                    dropdown.change(
                        self.sm.update_output_variables,
                        inputs=[self.pipeline_state, gr.State(output_field), dropdown],
                        outputs=[self.pipeline_state],
                    )
                    dropdowns[output_field] = dropdown

        def update_choices(available_variables):
            """Update the choices for the dropdowns"""
            # NOTE(review): `gr.update(selected=...)` is not a documented
            # Dropdown property — verify against the Gradio version in use.
            return [
                gr.update(choices=available_variables, value=None, selected=None) for dropdown in dropdowns.values()
            ]

        # Refresh every dropdown's choices whenever the set of produced
        # variables changes.
        self.variables_state.change(
            update_choices,
            inputs=[self.variables_state],
            outputs=list(dropdowns.values()),
        )
        return dropdowns

    def validate_workflow(self, state: PipelineState) -> PipelineState:
        """Validate the workflow, surfacing validation failures as gr.Error.

        Dispatches to the simple or complex validator depending on mode.
        NOTE(review): `raise gr.Error(e)` passes the exception object rather
        than a message string — confirm gr.Error renders it as intended.
        """
        try:
            if self.simple:
                workflow = validate_simple_workflow(state.workflow, self.required_output_variables)
            else:
                workflow = validate_complex_workflow(state.workflow, self.required_output_variables)
            state.workflow = workflow
            return state
        except ValueError as e:
            raise gr.Error(e)

    def _render_pipeline_header(self):
        """Render the instruction header listing the pipeline's input/output variables."""
        # Add Step button at top
        input_variables_str = ", ".join([f"`{variable}`" for variable in self.input_variables])
        output_variables_str = ", ".join([f"`{variable}`" for variable in self.required_output_variables])
        if self.simple:
            instruction = "Create a simple single LLM call pipeline that takes in the following input variables and outputs the following output variables:"
        else:
            instruction = "Create a pipeline that takes in the following input variables and outputs the following output variables:"
        gr.Markdown(f"### {instruction}")
        gr.Markdown(f"Input Variables: {input_variables_str}")
        gr.Markdown(f"Output Variables: {output_variables_str}")

        # if not self.simple:
        #     self._render_add_step_button(0)

    def render(self):
        """Render the pipeline UI."""
        # Create a placeholder for all the step components
        self.all_components = []

        # self.pipeline_state.change(
        #     lambda x, y: print(f"Pipeline state changed! UI:\n{x}\n\n Data:\n{y}"),
        #     inputs=[self.ui_state, self.pipeline_state],
        #     outputs=[],
        # )

        self._render_pipeline_header()

        # Function to render all steps; @gr.render re-runs this whenever the
        # pipeline or UI state changes, rebuilding the step editors.
        @gr.render(inputs=[self.pipeline_state, self.ui_state])
        def render_steps(state, ui_state):
            """Render all steps in the pipeline"""
            workflow = state.workflow
            print(f"\nRerender triggered! Current UI State:{ui_state}")
            components = []

            step_objects = []  # Reset step objects list
            for i, step_id in enumerate(ui_state.step_ids):
                step_data = workflow.steps[step_id]
                step_ui_state = ui_state.steps[step_id]
                # Each step may only consume variables produced by other steps.
                available_variables = self.sm.get_all_variables(state, step_id)
                sub_components = self._render_step(step_data, step_ui_state, available_variables, i)
                step_objects.append(sub_components)

            components.append(step_objects)

            # Bottom buttons
            if not self.simple:
                self._render_add_step_button(-1)

        @gr.render(inputs=[self.variables_state, self.pipeline_state])
        def render_output_fields(available_variables, pipeline_state):
            return self._render_output_fields(available_variables, pipeline_state)

        export_btn = gr.Button("Export Pipeline", elem_classes="export-button")
        # components.append(export_btn)

        # Add a code box to display the workflow JSON
        # with gr.Column(elem_classes="workflow-json-container"):
        with gr.Accordion("Pipeline Preview", open=False, elem_classes="pipeline-preview") as config_accordion:
            config_output = gr.Code(
                label="Workflow Configuration",
                language="yaml",
                elem_classes="workflow-json",
                interactive=True,
                autocomplete=True,
            )
        # components.append(config_accordion)

        # Editing the YAML and clicking away re-parses it into pipeline state.
        config_output.blur(
            fn=update_workflow_from_code,
            inputs=[config_output, self.ui_state],
            outputs=[self.pipeline_state],
        )

        # Connect the export button to show the workflow JSON
        export_btn.click(self.validate_workflow, inputs=[self.pipeline_state], outputs=[self.pipeline_state]).success(
            fn=lambda: gr.update(visible=True, open=True), outputs=[config_accordion]
        )
        export_btn.click(
            fn=self.sm.get_formatted_config,
            inputs=[self.pipeline_state, gr.State("yaml")],
            outputs=[config_output],
            js="() => {document.querySelector('.pipeline-preview').scrollIntoView({behavior: 'smooth'})}",
        )

        # self.all_components = components
src/components/model_pipeline/state_manager.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from typing import Any, Literal
4
+
5
+ import gradio as gr
6
+ import yaml
7
+ from pydantic import BaseModel, Field
8
+
9
+ from components import utils
10
+ from workflows.factory import create_new_llm_step
11
+ from workflows.structs import ModelStep, Workflow
12
+
13
+
14
def make_step_id(step_id: int):
    """Return a spreadsheet-style letter id for a 0-based step index.

    Indices 0-25 map to "A"-"Z"; larger indices map to two-letter ids
    ("AA", "AB", ...), Excel-column style.
    """
    if step_id >= 26:
        # Two-letter id: leading letter from the quotient (1-based),
        # trailing letter from the remainder.
        lead, trail = divmod(step_id, 26)
        return chr(ord("A") + lead - 1) + chr(ord("A") + trail)
    return chr(ord("A") + step_id)
23
+
24
+
25
class ModelStepUIState(BaseModel):
    """Represents the UI state for a model step component."""

    # Whether the step's accordion is expanded in the UI.
    expanded: bool = True
    # Which tab of the step editor is currently active.
    active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"

    def update(self, key: str, value: Any) -> "ModelStepUIState":
        """Return a copy of this state with ``key`` set to ``value``.

        The instance itself is not mutated (pydantic ``model_copy``).
        """
        new_state = self.model_copy(update={key: value})
        # Bug fix: log the updated copy — the original logged the stale
        # pre-update `self` while claiming to report the updated state.
        # NOTE(review): warning level for a routine UI update looks too loud;
        # consider logging.debug.
        logging.warning("UI state updated: %s", new_state)
        return new_state
36
+
37
+
38
class PipelineUIState(BaseModel):
    """UI state for a whole pipeline: step ordering plus per-step UI state."""

    # Ordered step ids as displayed in the UI.
    step_ids: list[str] = Field(default_factory=list)
    # Per-step UI state, keyed by step id.
    steps: dict[str, ModelStepUIState] = Field(default_factory=dict)

    def model_post_init(self, __context: utils.Any) -> None:
        # Backfill default per-step UI states when only ids were provided.
        if self.step_ids and not self.steps:
            self.steps = {sid: ModelStepUIState() for sid in self.step_ids}
        return super().model_post_init(__context)

    def get_step_position(self, step_id: str):
        """Return the index of `step_id` in display order, or None if absent."""
        for position, sid in enumerate(self.step_ids):
            if sid == step_id:
                return position
        return None

    @classmethod
    def from_workflow(cls, workflow: Workflow):
        """Build a fresh UI state mirroring the steps of `workflow`."""
        ids = list(workflow.steps.keys())
        return PipelineUIState(
            step_ids=ids,
            steps={sid: ModelStepUIState() for sid in ids},
        )
60
+
61
+
62
class PipelineState(BaseModel):
    """Represents the state for a pipeline component (workflow data + UI layout)."""

    workflow: Workflow
    ui_state: PipelineUIState

    def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
        """Insert ``step`` at ``position`` (-1 appends). Returns self for chaining.

        Raises:
            ValueError: If the step id already exists, or the position is
                neither -1 nor within [0, n_steps].
        """
        if step.id in self.workflow.steps:
            raise ValueError(f"Step {step.id} already exists in pipeline")

        # Validate position (-1 means "append")
        if position != -1 and (position < 0 or position > self.n_steps):
            raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")

        self.workflow.steps[step.id] = step

        # Replace ui_state with a copy so downstream change-detection sees a
        # new object rather than an in-place mutation.
        self.ui_state = self.ui_state.model_copy()
        self.ui_state.steps[step.id] = ModelStepUIState()
        if position == -1:
            self.ui_state.step_ids.append(step.id)
        else:
            self.ui_state.step_ids.insert(position, step.id)
        return self

    def remove_step(self, position: int) -> "PipelineState":
        """Remove the step at ``position``. Returns self for chaining.

        Bug fix: the original returned None, which broke callers that rebind
        the state from the return value (PipelineStateManager.remove_step does
        ``state = state.remove_step(position)`` and then dereferences it).
        """
        step_id = self.ui_state.step_ids.pop(position)
        self.workflow.steps.pop(step_id)
        self.ui_state = self.ui_state.model_copy()
        self.ui_state.steps.pop(step_id)
        self.update_output_variables_mapping()
        return self

    def update_output_variables_mapping(self) -> "PipelineState":
        """Clear output mappings that refer to variables no longer produced."""
        available_variables = set(self.available_variables)
        for output_field in self.workflow.outputs:
            if self.workflow.outputs[output_field] not in available_variables:
                self.workflow.outputs[output_field] = None
        return self

    @property
    def available_variables(self):
        """Variables currently producible by the workflow's steps."""
        return self.workflow.get_available_variables()

    @property
    def n_steps(self):
        """Number of steps in the workflow."""
        return len(self.workflow.steps)
107
+
108
+
109
class PipelineStateManager:
    """Stateless collection of Gradio callbacks that manipulate a PipelineState.

    Each method takes the state it operates on as an argument and returns the
    updated state objects in the order the UI wiring expects.
    """

    def get_formatted_config(self, state: PipelineState, format: Literal["json", "yaml"] = "yaml"):
        """Serialize the workflow (non-default fields only) as YAML or JSON text."""
        config = state.workflow.model_dump(exclude_defaults=True)
        if format == "yaml":
            return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
        else:
            return json.dumps(config, indent=4, sort_keys=False)

    def count_state(self):
        # NOTE(review): broken — this class has no `self.steps`, so calling
        # this raises AttributeError. Kept byte-for-byte for interface
        # compatibility; appears unused and should be removed or rewritten
        # to take a PipelineState argument.
        return gr.State(len(self.steps))

    def add_step(self, state: PipelineState, position: int = -1, name=""):
        """Create a new LLM step and insert it at ``position`` (-1 appends)."""
        step_id = make_step_id(state.n_steps)
        step_name = name or f"Step {state.n_steps + 1}"
        new_step = create_new_llm_step(step_id=step_id, name=step_name)
        state = state.insert_step(position, new_step)
        return state, state.ui_state, state.available_variables

    def remove_step(self, state: PipelineState, position: int):
        """Remove the step at ``position``.

        Raises:
            ValueError: If the position is out of range.
        """
        if 0 <= position < state.n_steps:
            state = state.remove_step(position)
        else:
            raise ValueError(f"Invalid step position: {position}")
        return state, state.ui_state, state.available_variables

    def move_up(self, ui_state: PipelineUIState, position: int):
        """Move the step at ``position`` one slot earlier in display order."""
        utils.move_item(ui_state.step_ids, position, "up")
        return ui_state.model_copy()

    def move_down(self, ui_state: PipelineUIState, position: int):
        """Move the step at ``position`` one slot later in display order."""
        utils.move_item(ui_state.step_ids, position, "down")
        return ui_state.model_copy()

    def update_model_step_state(self, state: PipelineState, model_step: ModelStep, ui_state: ModelStepUIState):
        """Replace one step's data and UI state, then refresh derived mappings."""
        state.workflow.steps[model_step.id] = model_step.model_copy()
        state.ui_state.steps[model_step.id] = ui_state.model_copy()
        # Fresh copy so change-detection notices the update.
        state.ui_state = state.ui_state.model_copy()
        state.update_output_variables_mapping()
        return state, state.ui_state, state.available_variables

    def update_output_variables(self, state: PipelineState, target: str, produced_variable: str):
        """Map the required output ``target`` to ``produced_variable`` (or clear it).

        (The original had this docstring misplaced mid-function, where it was
        a no-op statement.)
        """
        # The dropdown placeholder means "unset".
        if produced_variable == "Choose variable...":
            produced_variable = None
        state.workflow.outputs.update({target: produced_variable})
        return state

    def update_model_step_ui(self, state: PipelineState, step_ui: ModelStepUIState, step_id: str):
        """Replace one step's UI state."""
        state.ui_state.steps[step_id] = step_ui.model_copy()
        return state, state.ui_state

    def get_all_variables(self, state: PipelineState, model_step_id: str | None = None) -> list[str]:
        """All available variables, excluding those produced by ``model_step_id``.

        A step must not consume its own outputs, so its "<id>."-prefixed
        variables are filtered out when an id is given.
        """
        available_variables = state.available_variables
        if model_step_id is None:
            return available_variables
        else:
            prefix = f"{model_step_id}."
            return [var for var in available_variables if not var.startswith(prefix)]

    def get_pipeline_config(self):
        # NOTE(review): broken — this class has no `self.workflow`, so calling
        # this raises AttributeError. Appears unused; prefer
        # get_formatted_config(state) or state.workflow directly.
        return self.workflow
src/components/model_step/__init__.py ADDED
File without changes
src/components/model_step/model_step.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Any
3
+
4
+ import gradio as gr
5
+ from gradio.components import FormComponent
6
+
7
+ from components.model_pipeline.state_manager import ModelStepUIState, PipelineState, PipelineStateManager
8
+ from utils import get_full_model_name
9
+ from workflows.structs import ModelStep
10
+
11
+ from .state_manager import ModelStepStateManager
12
+ from .ui_components import InputRowButtonGroup, OutputRowButtonGroup
13
+
14
+
15
+ def _make_accordion_label(model_step: ModelStep):
16
+ name = model_step.name if model_step.name else "Untitled"
17
+ input_field_names = [field.name for field in model_step.input_fields]
18
+ inputs_str = ", ".join(input_field_names)
19
+ output_field_names = [field.name for field in model_step.output_fields]
20
+ outputs_str = ", ".join(output_field_names)
21
+ return "{}: {} ({}) → ({})".format(model_step.id, name, inputs_str, outputs_str)
22
+
23
+
24
+ class ModelStepComponent(FormComponent):
25
+ """
26
+ A custom Gradio component representing a single Step in a pipeline.
27
+ It contains:
28
+ 1. Model Provider & System Prompt
29
+ 2. Inputs – fields with name, description, and variable used
30
+ 3. Outputs – fields with name, description, and variable used
31
+
32
+ Listens to events:
33
+ - on_model_step_change
34
+ - on_ui_change
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ value: ModelStep | gr.State,
40
+ ui_state: ModelStepUIState | gr.State | None = None,
41
+ model_options: list[str] | None = None,
42
+ input_variables: list[str] | None = None,
43
+ max_input_fields=5,
44
+ max_output_fields=5,
45
+ pipeline_state_manager: PipelineStateManager | None = None,
46
+ **kwargs,
47
+ ):
48
+ self.max_fields = {
49
+ "input": max_input_fields,
50
+ "output": max_output_fields,
51
+ }
52
+ self.model_options = model_options
53
+ self.input_variables = input_variables
54
+ self.sm = ModelStepStateManager(max_input_fields, max_output_fields)
55
+ self.pipeline_sm: PipelineStateManager = pipeline_state_manager
56
+
57
+ self.model_step_state = gr.State(value)
58
+ ui_state = ui_state or ModelStepUIState()
59
+ if not isinstance(ui_state, gr.State):
60
+ ui_state = gr.State(ui_state)
61
+ self.ui_state: gr.State = ui_state
62
+
63
+ self.inputs_count_state = gr.State(len(value.input_fields))
64
+ self.outputs_count_state = gr.State(len(value.output_fields))
65
+
66
+ # UI components that will be created in render
67
+ self.accordion = None
68
+ self.ui = None
69
+ self.step_name_input = None
70
+ self.model_selection = None
71
+ self.system_prompt = None
72
+ self.input_rows = []
73
+ self.output_rows = []
74
+
75
+ super().__init__(**kwargs)
76
+ # self.render()
77
+ self.setup_event_listeners()
78
+
79
+ @property
80
+ def model_step(self) -> ModelStep:
81
+ return self.model_step_state.value
82
+
83
+ @property
84
+ def step_id(self) -> str:
85
+ return self.model_step.id
86
+
87
+ def get_step_config(self) -> dict:
88
+ return self.model_step.model_dump()
89
+
90
+ # UI state accessors
91
+ def is_open(self) -> bool:
92
+ return self.ui_state.value.expanded
93
+
94
+ def get_active_tab(self) -> str:
95
+ """Get the current active tab."""
96
+ return self.ui_state.value.active_tab
97
+
98
+ def _render_input_row(self, i: int) -> tuple[gr.Row, tuple, tuple]:
99
+ """Render a single input row at index i."""
100
+ inputs = self.model_step.input_fields
101
+ is_visible = i < len(inputs)
102
+ label_visible = i == 0
103
+ initial_name = inputs[i].name if is_visible else ""
104
+ initial_desc = inputs[i].description if is_visible else ""
105
+ initial_var = inputs[i].variable if is_visible else "question_text"
106
+
107
+ with gr.Row(visible=is_visible, elem_classes="field-row form") as row:
108
+ button_group = InputRowButtonGroup()
109
+
110
+ inp_name = gr.Textbox(
111
+ label="Input Name",
112
+ placeholder="Field name",
113
+ value=initial_name,
114
+ elem_classes="field-name",
115
+ scale=1,
116
+ show_label=label_visible,
117
+ )
118
+
119
+ # Get variable choices safely
120
+ # variable_choices = []
121
+ # if self.pipeline_sm is not None:
122
+ # variable_choices = self.pipeline_sm.get_all_variables(self.step_id)
123
+
124
+ inp_var = gr.Dropdown(
125
+ choices=self.input_variables,
126
+ label="Variable Used",
127
+ value=initial_var,
128
+ elem_classes="field-variable",
129
+ scale=1,
130
+ show_label=label_visible,
131
+ )
132
+ inp_desc = gr.Textbox(
133
+ label="Description",
134
+ placeholder="Field description",
135
+ value=initial_desc,
136
+ elem_classes="field-description",
137
+ scale=3,
138
+ show_label=label_visible,
139
+ )
140
+ fields = (inp_name, inp_var, inp_desc)
141
+ # buttons = (delete_button, add_button)
142
+ return row, fields, button_group
143
+
144
    def _render_output_row(self, i: int) -> tuple[gr.Row, tuple, tuple]:
        """Render a single output row at index i.

        Mirrors `_render_input_row`, but with a type dropdown instead of a
        variable dropdown and a button group that also supports reordering.

        Returns:
            (row, (name_textbox, type_dropdown, description_textbox), button_group)
        """
        outputs = self.model_step.output_fields
        is_visible = i < len(outputs)
        label_visible = i == 0  # only the first row shows component labels
        initial_name = outputs[i].name if is_visible else ""
        initial_desc = outputs[i].description if is_visible else ""
        initial_type = outputs[i].type if is_visible else "str"
        with gr.Row(visible=is_visible, elem_classes="field-row") as row:
            button_group = OutputRowButtonGroup()

            out_name = gr.Textbox(
                label="Output Field",
                placeholder="Variable identifier",
                value=initial_name,
                elem_classes="field-name",
                scale=1,
                show_label=label_visible,
            )
            out_type = gr.Dropdown(
                choices=["str", "int", "float", "bool"],
                allow_custom_value=True,
                label="Type",
                value=initial_type,
                elem_classes="field-type",
                scale=0,
                show_label=label_visible,
                interactive=True,
            )
            out_desc = gr.Textbox(
                label="Description",
                placeholder="Field description",
                value=initial_desc,
                elem_classes="field-description",
                scale=3,
                show_label=label_visible,
            )

        fields = (out_name, out_type, out_desc)
        return row, fields, button_group
184
+
185
    def _render_prompt_tab_content(self):
        """Render the Model tab content: the step's system prompt editor."""
        self.system_prompt = gr.Textbox(
            label="System Prompt",
            placeholder="Enter the system prompt for this step",
            lines=5,
            value=self.model_step.system_prompt,
            elem_classes="system-prompt",
        )
193
+
194
    def _render_inputs_tab_content(self):
        """Render the Inputs tab: a fixed pool of input rows (extras hidden)."""
        with gr.Column(variant="panel", elem_classes="fields-panel") as self.inputs_column:
            # Render input rows using helper method
            for i in range(self.max_fields["input"]):
                row = self._render_input_row(i)  # (row, fields, button_group) triple
                self.input_rows.append(row)
200
+
201
    def _render_outputs_tab_content(self):
        """Render the Outputs tab: a fixed pool of output rows (extras hidden)."""
        with gr.Column(variant="panel", elem_classes="fields-panel") as self.outputs_column:
            # Render output rows using helper method
            for i in range(self.max_fields["output"]):
                row = self._render_output_row(i)  # (row, fields, button_group) triple
                self.output_rows.append(row)
207
+
208
+ def _render_tab_content(self, tab_id: str):
209
+ if tab_id == "model-tab":
210
+ self._render_prompt_tab_content()
211
+ elif tab_id == "inputs-tab":
212
+ self._render_inputs_tab_content()
213
+ elif tab_id == "outputs-tab":
214
+ self._render_outputs_tab_content()
215
+
216
+ def _render_header(self, model_options: tuple[str]):
217
+ # Header with step name
218
+ with gr.Row(elem_classes="step-header-row"):
219
+ self.step_name_input = gr.Textbox(
220
+ label="",
221
+ value=self.model_step.name,
222
+ elem_classes="step-name",
223
+ show_label=False,
224
+ placeholder="Model name...",
225
+ )
226
+ unselected_choice = "Select Model..."
227
+ current_value = (
228
+ get_full_model_name(self.model_step.model, self.model_step.provider)
229
+ if self.model_step.model
230
+ else unselected_choice
231
+ )
232
+ self.model_selection = gr.Dropdown(
233
+ choices=[unselected_choice] + model_options,
234
+ label="Model Provider",
235
+ show_label=False,
236
+ value=current_value,
237
+ elem_classes="model-dropdown",
238
+ scale=1,
239
+ )
240
+ self.temperature_slider = gr.Slider(
241
+ value=self.model_step.temperature,
242
+ minimum=0.0,
243
+ maximum=5,
244
+ step=0.05,
245
+ info="Temperature",
246
+ show_label=False,
247
+ )
248
+
249
    def render(self):
        """Render the component UI and return the step's accordion container."""
        # Reset UI component lists
        self.input_rows = []
        self.output_rows = []
        self.tabs = {}

        # Create the accordion for this step
        accordion_label = _make_accordion_label(self.model_step)
        self.accordion = gr.Accordion(label=accordion_label, open=self.is_open(), elem_classes="step-accordion")

        # Create the UI content inside the accordion
        with self.accordion:
            self._render_header(self.model_options)

            # Configuration tabs; restore the tab that was active before a re-render.
            selected_tab = self.get_active_tab()
            with gr.Tabs(elem_classes="step-tabs", selected=selected_tab):
                tab_ids = ("model-tab", "inputs-tab", "outputs-tab")
                tab_labels = ("Model", "Inputs", "Outputs")
                for tab_id, label in zip(tab_ids, tab_labels):
                    with gr.TabItem(label, elem_classes="tab-content", id=tab_id) as tab:
                        self._render_tab_content(tab_id)
                    self.tabs[tab_id] = tab

        return self.accordion
275
+
276
    def _setup_event_listeners_for_view_change(self):
        """Persist tab selection and accordion expand/collapse into the UI state."""
        for tab_id, tab in self.tabs.items():
            tab.select(
                fn=self.sm.update_ui_state,
                inputs=[self.ui_state, gr.State("active_tab"), gr.State(tab_id)],
                outputs=[self.ui_state],
            )
        self.accordion.collapse(
            fn=self.sm.update_ui_state,
            inputs=[self.ui_state, gr.State("expanded"), gr.State(False)],
            outputs=[self.ui_state],
        )
        self.accordion.expand(
            fn=self.sm.update_ui_state,
            inputs=[self.ui_state, gr.State("expanded"), gr.State(True)],
            outputs=[self.ui_state],
        )
293
+
294
    def _setup_event_listeners_model_tab(self):
        """Wire the header widgets and system prompt editor to their state updaters."""
        # Step name change also refreshes the accordion label.
        self.step_name_input.blur(
            fn=self._update_state_and_label,
            inputs=[self.model_step_state, self.step_name_input],
            outputs=[self.model_step_state, self.accordion],
        )

        # `release` (not `change`) so dragging the slider fires only once.
        self.temperature_slider.release(
            fn=self.sm.update_temperature,
            inputs=[self.model_step_state, self.temperature_slider],
            outputs=[self.model_step_state],
        )

        # Model and system prompt
        self.model_selection.change(
            fn=self.sm.update_model_and_provider,
            inputs=[self.model_step_state, self.model_selection],
            outputs=[self.model_step_state],
        )

        self.system_prompt.blur(
            fn=self.sm.update_system_prompt,
            inputs=[self.model_step_state, self.system_prompt],
            outputs=[self.model_step_state],
        )
320
+
321
    def _setup_event_listeners_inputs_tab(self):
        """Wire each input row's fields and add/delete buttons to state updaters."""
        # Setup input row events
        for i, (row, fields, button_group) in enumerate(self.input_rows):
            inp_name, inp_var, inp_desc = fields
            row_index = gr.State(i)  # captures this row's index for its handlers

            # Field change handlers
            inp_name.blur(
                fn=self.sm.update_input_field_name,
                inputs=[self.model_step_state, inp_name, row_index],
                outputs=[self.model_step_state],
            )

            inp_var.change(
                fn=self.sm.update_input_field_variable,
                inputs=[self.model_step_state, inp_var, row_index],
                outputs=[self.model_step_state],
            )

            inp_desc.blur(
                fn=self.sm.update_input_field_description,
                inputs=[self.model_step_state, inp_desc, row_index],
                outputs=[self.model_step_state],
            )

            # NOTE(review): these lists are loop-invariant but recomputed every
            # iteration; they enumerate ALL row containers and field components
            # so add/delete handlers can refresh visibility and values at once.
            rows = [row for (row, _, _) in self.input_rows]
            input_fields = [field for (_, fields, _) in self.input_rows for field in fields]

            # Button handlers (per-row: each row's buttons act on its own index)
            button_group.delete(
                fn=self.sm.delete_input_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.inputs_count_state] + rows + input_fields,
            )

            button_group.add(
                fn=self.sm.add_input_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.inputs_count_state] + rows + input_fields,
            )
361
+
362
    def _setup_event_listeners_outputs_tab(self):
        """Wire each output row's fields and add/delete/reorder buttons to state updaters."""
        # Setup output row events
        for i, (row, fields, button_group) in enumerate(self.output_rows):
            out_name, out_type, out_desc = fields

            row_index = gr.State(i)  # captures this row's index for its handlers

            # Field change handlers
            out_name.blur(
                fn=self.sm.update_output_field_name,
                inputs=[self.model_step_state, out_name, row_index],
                outputs=[self.model_step_state],
            )

            out_type.change(
                fn=self.sm.update_output_field_type,
                inputs=[self.model_step_state, out_type, row_index],
                outputs=[self.model_step_state],
            )

            out_desc.blur(
                fn=self.sm.update_output_field_description,
                inputs=[self.model_step_state, out_desc, row_index],
                outputs=[self.model_step_state],
            )

            # NOTE(review): loop-invariant lists, recomputed per iteration —
            # all row containers / field components, used as handler outputs.
            rows = [row for (row, _, _) in self.output_rows]
            output_fields = [field for (_, fields, _) in self.output_rows for field in fields]

            # Button handlers (per-row)
            button_group.delete(
                fn=self.sm.delete_output_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.outputs_count_state] + rows + output_fields,
            )

            button_group.add(
                fn=self.sm.add_output_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.outputs_count_state] + rows + output_fields,
            )

            # Reordering only rewrites field values; row visibility is unchanged.
            button_group.up(
                fn=self.sm.move_output_field,
                inputs=[self.model_step_state, row_index, gr.State("up")],
                outputs=[self.model_step_state] + output_fields,
            )

            button_group.down(
                fn=self.sm.move_output_field,
                inputs=[self.model_step_state, row_index, gr.State("down")],
                outputs=[self.model_step_state] + output_fields,
            )
415
+
416
    # Function to set up event listeners - call this separately after all components are rendered
    def setup_event_listeners(self):
        """Set up all event listeners for this component.

        Must be called after `render()` so every referenced component exists.
        """
        self._setup_event_listeners_for_view_change()
        self._setup_event_listeners_model_tab()
        self._setup_event_listeners_inputs_tab()
        self._setup_event_listeners_outputs_tab()

        # Debug helpers, referenced only by the commented-out listeners below.
        def state_str(x, limited: bool = False):
            # Serialize a pydantic state object; `limited` keeps only name/temperature.
            d = x.model_dump()
            if limited:
                d = {k: d[k] for k in {"name", "temperature"}}
            return json.dumps(d, indent=2)

        def log_step_states(x, y, src: str):
            print(f"{src} triggered! UI:\n{state_str(x)}\n\nData:\n{state_str(y, True)}")
            print("--------------------------------")
            print(f"self.model_step_state: \n{self.get_step_config()}")
            print("--------------------------------")

        # self.model_step_state.change(
        #     log_step_states,
        #     inputs=[self.ui_state, self.model_step_state, gr.State("Model Change")],
        # )
        # self.ui_state.change(
        #     log_step_states,
        #     inputs=[self.ui_state, self.model_step_state, gr.State("UI Change")],
        # )
444
+
445
    def on_model_step_change(self, fn, inputs, outputs):
        """Register `fn` to run whenever the model step state changes."""
        self.model_step_state.change(fn, inputs, outputs)
448
+
449
    def on_ui_change(self, fn, inputs, outputs):
        """Register `fn` to run whenever the UI state changes."""
        self.ui_state.change(fn, inputs, outputs)
452
+
453
    def _update_state_and_label(self, model_step: ModelStep, name: str):
        """Rename the step and refresh the accordion label in one event.

        Returns:
            (updated_model_step, gr.update payload for the accordion label)
        """
        new_model_step = self.sm.update_step_name(model_step, name)
        new_label = _make_accordion_label(new_model_step)
        return new_model_step, gr.update(label=new_label)
458
+
459
    def refresh_variable_dropdowns(self, pipeline_state: PipelineState):
        # TODO: Fix this. Not sure why this is needed.
        """Refresh the variable dropdown options in all input rows.

        Pulls the set of available variables from the pipeline state manager
        (empty when no pipeline manager is attached).
        """
        variable_choices = []
        if self.pipeline_sm is not None:
            variable_choices = self.pipeline_sm.get_all_variables(pipeline_state)

        # NOTE(review): mutates each dropdown directly via `.update(...)` rather
        # than returning gr.update payloads — verify this takes effect in the
        # Gradio version in use.
        for _, fields, _ in self.input_rows:
            _, inp_var, _ = fields
            inp_var.update(choices=variable_choices)
469
+
470
    def _update_model_and_refresh_ui(self, updated_model_step):
        """Update the model step state and refresh UI elements that depend on it."""
        self.model_step_state.value = updated_model_step
        # Update accordion label to reflect the (possibly renamed) step.
        new_label = _make_accordion_label(updated_model_step)
        if self.accordion:
            self.accordion.update(label=new_label)
        return updated_model_step
src/components/model_step/state_manager.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Literal, Union
2
+
3
+ import gradio as gr
4
+
5
+ from components.model_pipeline.state_manager import ModelStepUIState
6
+ from components.utils import DIRECTIONS, move_item
7
+ from utils import get_model_and_provider
8
+ from workflows.structs import FieldType, ModelStep
9
+
10
+
11
class ModelStepStateManager:
    """Stateless manager for `ModelStep` mutations driven by the step editor UI.

    Every method takes the current (immutable) `ModelStep` and returns an
    updated copy, plus — where the view must change — `gr.update(...)` payloads
    for the fixed pool of input/output row components.
    """

    def __init__(self, max_input_fields: int, max_output_fields: int):
        # Upper bound on the number of statically rendered rows per field type.
        self.max_fields = {
            "input": max_input_fields,
            "output": max_output_fields,
        }

    # UI state update functions
    def update_ui_state(self, ui_state: ModelStepUIState, key: str, value: Any) -> ModelStepUIState:
        """Return a copy of `ui_state` with `key` set to `value`."""
        return ui_state.update(key, value)

    # Property update functions
    def update_step_name(self, model_step: ModelStep, value: str) -> ModelStep:
        """Update the step name in state and accordion label."""
        return model_step.update_property("name", value)

    def update_temperature(self, model_step: ModelStep, value: float) -> ModelStep:
        """Update the step's sampling temperature."""
        return model_step.update_property("temperature", value)

    def update_model_and_provider(self, model_step: ModelStep, value: str) -> ModelStep:
        """Update the model provider in the state.

        `value` is the combined display string; it is split into its model and
        provider components before being stored.
        """
        model, provider = get_model_and_provider(value)
        return model_step.update({"model": model, "provider": provider})

    def update_system_prompt(self, model_step: ModelStep, value: str) -> ModelStep:
        """Update the system prompt in the state."""
        return model_step.update_property("system_prompt", value)

    # Field update functions
    def update_input_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Update the name of the input field at `index`."""
        return model_step.update_field("input", index, "name", value)

    def update_input_field_variable(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Update the bound variable of the input field at `index`."""
        return model_step.update_field("input", index, "variable", value)

    def update_input_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Update the description of the input field at `index`."""
        return model_step.update_field("input", index, "description", value)

    def update_output_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Update the name of the output field at `index`."""
        return model_step.update_field("output", index, "name", value)

    def update_output_field_type(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Update the declared type of the output field at `index`."""
        return model_step.update_field("output", index, "type", value)

    def update_output_field_variable(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Update the bound variable of the output field at `index`."""
        return model_step.update_field("output", index, "variable", value)

    def update_output_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Update the description of the output field at `index`."""
        return model_step.update_field("output", index, "description", value)

    def make_input_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Build value updates for every input-row component (3 per row).

        Rows beyond the step's current field count get `gr.skip()` so their
        (hidden) values are left untouched.
        """
        fields = model_step.input_fields
        updates = []
        for i in range(self.max_fields["input"]):
            if i < len(fields):
                updates.extend(
                    [
                        gr.update(value=fields[i].name),
                        gr.update(value=fields[i].variable),
                        gr.update(value=fields[i].description),
                    ]
                )
            else:
                updates.extend([gr.skip(), gr.skip(), gr.skip()])
        return updates

    def make_output_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Build value updates for every output-row component (3 per row).

        Same skip semantics as `make_input_field_updates`.
        """
        fields = model_step.output_fields
        updates = []
        for i in range(self.max_fields["output"]):
            if i < len(fields):
                updates.extend(
                    [
                        gr.update(value=fields[i].name),
                        gr.update(value=fields[i].type),
                        gr.update(value=fields[i].description),
                    ]
                )
            else:
                updates.extend([gr.skip(), gr.skip(), gr.skip()])
        return updates

    def _add_field(
        self, model_step: ModelStep, field_type: FieldType, index: int = -1, input_var: str | None = None
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Add a field and compute row-visibility updates for the new count."""
        new_step = model_step.add_field(field_type, index, input_var)
        fields = new_step.fields(field_type)
        row_updates = [gr.update(visible=i < len(fields)) for i in range(self.max_fields[field_type])]
        return new_step, len(fields), *row_updates

    def _delete_field(
        self, model_step: ModelStep, field_type: FieldType, index: int
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Delete a field and compute row-visibility updates for the new count."""
        new_step = model_step.delete_field(field_type, index)
        fields = new_step.fields(field_type)
        row_updates = [gr.update(visible=i < len(fields)) for i in range(self.max_fields[field_type])]
        return new_step, len(fields), *row_updates

    # Field add/delete functions.
    # Bug fix in all four: field-value updates are now computed from the *new*
    # step returned by _add_field/_delete_field. Previously they used the stale
    # pre-edit `model_step`, so rows shifted by an insert/delete kept showing
    # outdated values (cf. `move_output_field`, which already used `new_step`).
    def add_input_field(self, model_step: ModelStep, index: int = -1):
        """Insert an input field after `index` and refresh row visibility/values."""
        new_step, count, *row_updates = self._add_field(model_step, "input", index, input_var="question_text")
        return new_step, count, *row_updates, *self.make_input_field_updates(new_step)

    def add_output_field(self, model_step: ModelStep, index: int = -1):
        """Insert an output field after `index` and refresh row visibility/values."""
        new_step, count, *row_updates = self._add_field(model_step, "output", index)
        return new_step, count, *row_updates, *self.make_output_field_updates(new_step)

    def delete_input_field(self, model_step: ModelStep, index: int):
        """Remove the input field at `index` and refresh row visibility/values."""
        new_step, count, *row_updates = self._delete_field(model_step, "input", index)
        return new_step, count, *row_updates, *self.make_input_field_updates(new_step)

    def delete_output_field(self, model_step: ModelStep, index: int):
        """Remove the output field at `index` and refresh row visibility/values."""
        new_step, count, *row_updates = self._delete_field(model_step, "output", index)
        return new_step, count, *row_updates, *self.make_output_field_updates(new_step)

    def move_output_field(
        self, model_step: ModelStep, index: int, direction: DIRECTIONS
    ) -> list[gr.State | dict[str, Any]]:
        """
        Move an output field in the list either up or down.

        Args:
            index: Index of the output field to move
            direction: Direction to move the field ('up' or 'down')

        Returns:
            list: A list containing [updated_state, field_value_updates...]
        """
        new_step = model_step.model_copy()
        move_item(new_step.output_fields, index, direction)

        # Update all output fields to reflect the new order
        updates = self.make_output_field_updates(new_step)

        return new_step, *updates
src/components/model_step/ui_components.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio.components import FormComponent
3
+
4
+
5
class ButtonGroup:
    """Base class for button groups with common functionality.

    Subclasses define EVENTS, create their `gr.Button`s in `render()`, and
    store them in `self.buttons`. Handlers registered before the buttons exist
    are queued in `self.click_args` and bound when `render()` runs.
    """

    def __init__(self, events: list[str], *args, **kwargs):
        # One slot per event name; populated by the subclass's render().
        self.buttons = {event: None for event in events}
        # Queued (fn, inputs, outputs) triples, bound on render.
        self.click_args = {event: None for event in events}
        self.render()

    def render(self):
        """Render the buttons and set up their event handlers.

        Subclasses create the buttons first, then call super().render() so any
        handlers queued via `_setup_button` get attached.
        """
        for event, button in self.buttons.items():
            if self.click_args[event]:
                button.click(*self.click_args[event])

    def _setup_button(self, event, fn, inputs, outputs):
        """Set up a button's click event handler.

        Binds immediately if the button already exists; otherwise the args are
        queued and bound in `render()`.
        """
        self.click_args[event] = fn, inputs, outputs
        if self.buttons[event]:
            self.buttons[event].click(fn, inputs, outputs)

    def api_info(self):
        # NOTE(review): relies on the subclass-defined EVENTS class attribute;
        # the base class itself does not define EVENTS.
        return {
            "name": self.__class__.__name__,
            "events": self.EVENTS,
            "inputs": [],
            "outputs": [],
        }

    def example_payload(self):
        """Return None since this component doesn't have direct input values."""
        return None

    def example_value(self):
        """Return None since this component doesn't have direct output values."""
        return None
40
+
41
+
42
class InputRowButtonGroup(ButtonGroup):
    """Button group for input rows with delete and add buttons."""

    EVENTS = ["delete", "add"]

    def __init__(self, *args, **kwargs):
        super().__init__(self.EVENTS, *args, **kwargs)

    def render(self):
        """Create the delete/add buttons, then bind any queued handlers."""
        with gr.Column(scale=0, min_width=40, elem_classes="button-column"):
            self.buttons["delete"] = gr.Button("❌", elem_classes="icon-button delete-button", scale=0)
            self.buttons["add"] = gr.Button("➕", elem_classes="icon-button add-field-button", scale=0)
        super().render()

    def delete(self, fn, inputs, outputs):
        """Register the delete-button click handler."""
        self._setup_button("delete", fn, inputs, outputs)

    def add(self, fn, inputs, outputs):
        """Register the add-button click handler."""
        self._setup_button("add", fn, inputs, outputs)
61
+
62
+
63
class OutputRowButtonGroup(ButtonGroup):
    """Button group for output rows with delete, add, up, and down buttons."""

    EVENTS = ["delete", "add", "up", "down"]

    def __init__(self, *args, **kwargs):
        super().__init__(self.EVENTS, *args, **kwargs)

    def render(self):
        """Create the four buttons in two columns, then bind any queued handlers."""
        with gr.Column(scale=0, elem_classes="button-column", min_width=40):
            self.buttons["delete"] = gr.Button("❌", elem_classes="icon-button delete-button", scale=0)
            self.buttons["add"] = gr.Button("➕", elem_classes="icon-button add-field-button", scale=0)

        with gr.Column(scale=0, elem_classes="button-column", min_width=40):
            self.buttons["up"] = gr.Button("⬆️", elem_classes="icon-button up-button", scale=0)
            self.buttons["down"] = gr.Button("⬇️", elem_classes="icon-button down-button", scale=0)
        return super().render()

    def delete(self, fn, inputs, outputs):
        """Register the delete-button click handler."""
        self._setup_button("delete", fn, inputs, outputs)

    def add(self, fn, inputs, outputs):
        """Register the add-button click handler."""
        self._setup_button("add", fn, inputs, outputs)

    def up(self, fn, inputs, outputs):
        """Register the move-up button click handler."""
        self._setup_button("up", fn, inputs, outputs)

    def down(self, fn, inputs, outputs):
        """Register the move-down button click handler."""
        self._setup_button("down", fn, inputs, outputs)
src/components/quizbowl/__init__.py ADDED
File without changes
src/components/quizbowl/bonus.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from typing import Any
4
+
5
+ import gradio as gr
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import pandas as pd
9
+ from datasets import Dataset
10
+
11
+ from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
12
+ from submission import submit
13
+ from workflows import factory
14
+ from workflows.qb.simple_agent import SimpleBonusAgent
15
+ from workflows.structs import ModelStep, Workflow
16
+
17
+ from .plotting import (
18
+ create_pyplot,
19
+ create_scatter_pyplot,
20
+ evaluate_buzz,
21
+ update_plot,
22
+ )
23
+
24
+
25
def evaluate_bonus_part(prediction: str, clean_answers: list[str]) -> float:
    """Score a single bonus part's prediction against its acceptable answers.

    Delegates to the shared `evaluate_buzz` helper from the plotting module.
    """
    return evaluate_buzz(prediction, clean_answers)
28
+
29
+
30
def process_bonus_results(results: list[dict]) -> pd.DataFrame:
    """Convert raw per-part bonus results into a display-ready DataFrame.

    Args:
        results: One dict per part with keys 'part_number', 'score',
            'confidence', 'answer', and 'explanation'.

    Returns:
        DataFrame with columns Part / Correct? / Confidence / Prediction /
        Explanation, one row per input result.
    """
    rows = []
    for result in results:
        rows.append(
            {
                "Part": f"Part {result['part_number']}",
                "Correct?": "✅" if result["score"] == 1 else "❌",
                "Confidence": result["confidence"],
                "Prediction": result["answer"],
                "Explanation": result["explanation"],
            }
        )
    return pd.DataFrame(rows)
44
+
45
+
46
def initialize_eval_interface(example: dict, model_outputs: list[dict]):
    """Initialize the interface with example text.

    Args:
        example: Dataset row with 'leadin' and 'parts' (each part has 'part').
        model_outputs: One output dict per part (must carry 'confidence' and
            'score' for the plot).

    Returns:
        (question_html, confidence_figure, json_state_string); on failure an
        error HTML, an empty DataFrame, and the empty JSON object "{}".
    """
    try:
        # Create HTML for leadin and parts
        leadin_html = f"<div class='leadin'>{example['leadin']}</div>"
        parts_html = []
        for i, part in enumerate(example["parts"]):
            parts_html.append(f"<div class='part'><b>Part {i + 1}:</b> {part['part']}</div>")

        html_content = f"{leadin_html}<div class='parts-container'>{''.join(parts_html)}</div>"

        # Create confidence plot data
        plot_data = create_bonus_confidence_plot(example["parts"], model_outputs)

        # Store state so later callbacks can re-evaluate without re-running.
        state = json.dumps({"parts": example["parts"], "outputs": model_outputs})

        return html_content, plot_data, state
    except Exception as e:
        logging.error(f"Error initializing interface: {e}", exc_info=True)
        return f"<div>Error initializing interface: {str(e)}</div>", pd.DataFrame(), "{}"
67
+
68
+
69
def create_bonus_confidence_plot(parts: list[dict], model_outputs: list[dict]):
    """Create confidence plot for bonus parts.

    One bar per part: height is the model's confidence, color is green for a
    correct part (score == 1) and red otherwise.
    """
    plt.style.use("ggplot")
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    part_numbers = list(range(1, len(parts) + 1))
    confidences = [output["confidence"] for output in model_outputs]
    # Correct parts render green, incorrect ones red.
    colors = ["green" if output["score"] == 1 else "red" for output in model_outputs]

    bars = ax.bar(part_numbers, confidences, color="#4698cf")
    for bar, color in zip(bars, colors):
        bar.set_color(color)

    ax.set_title("Part Confidence")
    ax.set_xlabel("Part Number")
    ax.set_ylabel("Confidence")
    ax.set_xticks(part_numbers)
    ax.set_xticklabels([f"Part {n}" for n in part_numbers])

    return fig
94
+
95
+
96
def validate_workflow(workflow: Workflow):
    """Validate that a workflow is properly configured for the bonus task.

    Raises:
        ValueError: if the workflow has no steps, contains a misconfigured
            step, or is missing a required input/output variable.
    """
    if not workflow.steps:
        raise ValueError("Workflow must have at least one step")

    # Every step must individually pass the bonus-task checks.
    for step in workflow.steps.values():
        validate_model_step(step)

    # Check that the workflow has the correct structure
    if not {"leadin", "part"} <= set(workflow.inputs):
        raise ValueError("Workflow must have 'leadin' and 'part' as inputs")

    if not {"answer", "confidence", "explanation"} <= set(workflow.outputs):
        raise ValueError("Workflow must produce 'answer', 'confidence', and 'explanation' as outputs")
113
+
114
+
115
def validate_model_step(model_step: ModelStep):
    """Validate that a model step is properly configured for the bonus task.

    Checks, in order: model/provider presence, call type, temperature range,
    required input fields, required output fields, and the confidence type.

    Raises:
        ValueError: on the first failed check.
    """
    # Both a model and a provider must be set.
    if not (model_step.model and model_step.provider):
        raise ValueError("Model step must have both model and provider specified")

    if model_step.call_type != "llm":
        raise ValueError("Model step must have call_type 'llm'")

    # LLM steps require an explicit temperature within [0, 1].
    temperature = model_step.temperature
    if temperature is None:
        raise ValueError("Temperature must be specified for LLM model steps")

    if not (0.0 <= temperature <= 1.0):
        raise ValueError(f"Temperature must be between 0.0 and 1.0, got {model_step.temperature}")

    # Inputs must include both the bonus leadin and the part text.
    input_names = {field.name for field in model_step.input_fields}
    if not {"leadin", "part"} <= input_names:
        raise ValueError("Model step must have 'leadin' and 'part' input fields")

    # Outputs must cover the three values the interface consumes.
    output_names = {field.name for field in model_step.output_fields}
    if not {"answer", "confidence", "explanation"} <= output_names:
        raise ValueError("Model step must have all required output fields: answer, confidence, explanation")

    # Confidence must be declared as a float.
    for field in model_step.output_fields:
        if field.name == "confidence" and field.type != "float":
            raise ValueError("The 'confidence' output field must be of type 'float'")
146
+
147
+
148
+ class BonusInterface:
149
+ """Gradio interface for the Bonus mode."""
150
+
151
    def __init__(self, app: gr.Blocks, dataset: Dataset, model_options: dict, defaults: dict):
        """Initialize the Bonus interface.

        Args:
            app: Parent Blocks app the interface is rendered into.
            dataset: Bonus question dataset (one row per question).
            model_options: Mapping of model display names to their configs.
            defaults: Default selections, including 'init_workflow' and
                'simple_workflow'.
        """
        logging.info(f"Initializing Bonus interface with dataset size: {len(dataset)}")
        self.ds = dataset
        self.model_options = model_options
        self.app = app
        self.defaults = defaults
        # JSON-serialized evaluation state shared across callbacks.
        self.output_state = gr.State(value="{}")
        self.render()
160
+
161
    def _render_model_interface(self, workflow: Workflow, simple: bool = True):
        """Render the model interface (pipeline editor plus the Run button)."""
        self.pipeline_interface = PipelineInterface(
            workflow,
            simple=simple,
            model_options=list(self.model_options.keys()),
        )
        with gr.Row():
            self.run_btn = gr.Button("Run Bonus", variant="primary")
170
+
171
    def _render_qb_interface(self):
        """Render the quizbowl interface: question selector, displays, results, and submission."""
        with gr.Row():
            self.qid_selector = gr.Number(
                label="Question ID", value=1, precision=0, minimum=1, maximum=len(self.ds), show_label=True, scale=0
            )
            self.answer_display = gr.Textbox(
                label="Answers", elem_id="answer-display", elem_classes="answer-box", interactive=False, scale=1
            )
            self.clean_answer_display = gr.Textbox(
                label="Acceptable Answers",
                elem_id="answer-display-2",
                elem_classes="answer-box",
                interactive=False,
                scale=2,
            )

        self.question_display = gr.HTML(label="Question", elem_id="question-display")
        with gr.Row():
            self.confidence_plot = gr.Plot(
                label="Part Confidence",
                format="webp",
            )

        # Per-part model outputs shown alongside the confidence plot.
        self.results_table = gr.DataFrame(
            label="Model Outputs",
            value=pd.DataFrame(columns=["Part", "Correct?", "Confidence", "Prediction", "Explanation"]),
        )

        with gr.Row():
            self.eval_btn = gr.Button("Evaluate")

        # Leaderboard submission controls.
        with gr.Accordion("Model Submission", elem_classes="model-submission-accordion", open=True):
            with gr.Row():
                self.model_name_input = gr.Textbox(label="Model Name")
                self.description_input = gr.Textbox(label="Description")
            with gr.Row():
                gr.LoginButton()
                self.submit_btn = gr.Button("Submit")
            self.submit_status = gr.HTML(label="Submission Status")
211
+
212
    def render(self):
        """Create the Gradio interface: model panel on the left, quizbowl panel on the right."""
        # Hidden textbox used as a JS→Python bridge (see the page's <script>).
        self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")
        workflow = self.defaults["init_workflow"]

        with gr.Row():
            # Model Panel
            with gr.Column(scale=1):
                self._render_model_interface(workflow, simple=self.defaults["simple_workflow"])

            with gr.Column(scale=1):
                self._render_qb_interface()

        self._setup_event_listeners()
226
+
227
def get_new_question_html(self, question_id: int):
    """Build the question HTML plus primary/acceptable answer strings for display.

    Args:
        question_id: 1-based question ID selected in the UI.

    Returns:
        (question_html, primary_answers_text, acceptable_answers_text)
    """
    example = self.ds[question_id - 1]
    leadin = example["leadin"]
    parts = example["parts"]

    # One div per bonus part, joined with line breaks under the leadin.
    part_divs = [f"<div class='part'>{p['part']}</div>" for p in parts]
    body = "<br>".join(part_divs)
    html_content = (
        f"<div class='token-container'><div class='leadin'>{leadin}</div>"
        f"<div class='parts-container'><br>{body}</div></div>"
    )

    primary_lines = [f"{n}. {p['answer_primary']}" for n, p in enumerate(parts, start=1)]

    # Only list acceptable answers short enough to be readable (<= 6 words).
    clean_lines = []
    for n, p in enumerate(parts, start=1):
        short = [a for a in p["clean_answers"] if len(a.split()) <= 6]
        clean_lines.append(f"{n}. {', '.join(short)}")

    return html_content, "\n".join(primary_lines), "\n".join(clean_lines)
253
+
254
def get_model_outputs(self, example: dict, pipeline_state: PipelineState):
    """Run the workflow on every part of a bonus question and score each answer.

    Args:
        example: Dataset row with "leadin" and "parts".
        pipeline_state: Current pipeline state carrying the workflow to run.

    Returns:
        A list of per-part output dicts, each annotated with
        "part_number" (1-based) and "score".
    """
    leadin = example["leadin"]
    results = []
    for part_number, part in enumerate(example["parts"], start=1):
        # A fresh agent per part keeps runs independent.
        agent = SimpleBonusAgent(workflow=pipeline_state.workflow)
        result = agent.run(leadin, part["part"])

        result["part_number"] = part_number
        result["score"] = evaluate_bonus_part(result["answer"], part["clean_answers"])

        results.append(result)

    return results
271
+
272
def run_bonus(
    self,
    question_id: int,
    pipeline_state: PipelineState,
) -> tuple[str, Any, Any, Any]:
    """Run the agent on one bonus question.

    Args:
        question_id: 1-based question ID from the selector.
        pipeline_state: Current pipeline state carrying the workflow.

    Returns:
        A 4-tuple matching the bound outputs:
        (question HTML, confidence-plot update, output-state update,
        results-table update).
    """
    try:
        # Validate inputs (convert to 0-based index).
        question_id = int(question_id - 1)
        if not self.ds or question_id < 0 or question_id >= len(self.ds):
            # BUG FIX: error paths previously returned 3 values while the
            # click handler binds 4 outputs; pad to 4 so Gradio doesn't fail.
            return "Invalid question ID or dataset not loaded", None, None, None

        example = self.ds[question_id]
        outputs = self.get_model_outputs(example, pipeline_state)

        # Process results and prepare visualization data.
        html_content, plot_data, output_state = initialize_eval_interface(example, outputs)
        df = process_bonus_results(outputs)

        return (
            html_content,
            gr.update(value=plot_data, label=f"Part Confidence on Question {question_id + 1}"),
            gr.update(value=output_state),
            gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}"),
        )
    except Exception as e:
        import traceback

        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        # BUG FIX: pad the exception path to 4 values as well.
        return error_msg, None, None, None
302
+
303
def evaluate_bonus(self, pipeline_state: PipelineState, progress: gr.Progress = gr.Progress()):
    """Evaluate the bonus questions over the whole sample set.

    Args:
        pipeline_state: Current pipeline state carrying the workflow.
        progress: Gradio progress tracker for the evaluation loop.

    Returns:
        A 2-tuple of component updates: (scores table, part-scores plot),
        matching the two outputs bound to the Evaluate button.
    """
    try:
        # Validate inputs.
        if not self.ds or not self.ds.num_rows:
            # BUG FIX: was a 3-tuple; this handler is bound to 2 outputs.
            return "No dataset loaded", None

        total_correct = 0
        total_parts = 0
        part_scores = []
        part_numbers = []

        for example in progress.tqdm(self.ds, desc="Evaluating bonus questions"):
            model_outputs = self.get_model_outputs(example, pipeline_state)

            for output in model_outputs:
                total_parts += 1
                if output["score"] == 1:
                    total_correct += 1
                part_scores.append(output["score"])
                part_numbers.append(output["part_number"])

        # BUG FIX: guard against ZeroDivisionError when no parts were evaluated.
        accuracy = total_correct / total_parts if total_parts else 0.0
        df = pd.DataFrame(
            [
                {
                    "Part Accuracy": f"{accuracy:.2%}",
                    "Total Score": f"{total_correct}/{total_parts}",
                    "Questions Evaluated": len(self.ds),
                }
            ]
        )

        plot_data = create_scatter_pyplot(part_numbers, part_scores)
        return (
            gr.update(value=df, label="Scores on Sample Set"),
            gr.update(value=plot_data, label="Part Scores on Sample Set"),
        )
    except Exception:
        import traceback

        logging.error(f"Error evaluating bonus: {traceback.format_exc()}")
        # BUG FIX: 2 values to match the 2 bound outputs.
        return "Error evaluating bonus", None
346
+
347
def submit_model(
    self, model_name: str, description: str, pipeline_state: PipelineState, profile: gr.OAuthProfile = None
):
    """Forward the current workflow to the submission backend as a 'bonus' entry.

    Args:
        model_name: Display name for the submitted model.
        description: Free-text description of the submission.
        pipeline_state: Current pipeline state; only its workflow is submitted.
        profile: OAuth profile of the logged-in user (None when not logged in).

    Returns:
        Whatever submit.submit_model returns (rendered into the status HTML).
    """
    return submit.submit_model(model_name, description, pipeline_state.workflow, "bonus", profile)
352
+
353
def _setup_event_listeners(self):
    """Wire UI events: question loading, run, evaluate, submit, and hover-plot updates."""
    # Load/refresh the displayed question on app load and whenever the ID changes.
    gr.on(
        triggers=[self.app.load, self.qid_selector.change],
        fn=self.get_new_question_html,
        inputs=[self.qid_selector],
        outputs=[self.question_display, self.answer_display, self.clean_answer_display],
    )

    # Validate the workflow first; run the bonus round only on success.
    self.run_btn.click(
        self.pipeline_interface.validate_workflow,
        inputs=[self.pipeline_interface.pipeline_state],
        outputs=[self.pipeline_interface.pipeline_state],
    ).success(
        self.run_bonus,
        inputs=[
            self.qid_selector,
            self.pipeline_interface.pipeline_state,
        ],
        outputs=[
            self.question_display,
            self.confidence_plot,
            self.output_state,
            self.results_table,
        ],
    )

    self.eval_btn.click(
        fn=self.evaluate_bonus,
        inputs=[self.pipeline_interface.pipeline_state],
        outputs=[self.results_table, self.confidence_plot],
    )

    # BUG FIX: was `self.submit_model_output`, which does not exist on this
    # class (the method is named `submit_model`) and raised AttributeError.
    self.submit_btn.click(
        fn=self.submit_model,
        inputs=[
            self.model_name_input,
            self.description_input,
            self.pipeline_interface.pipeline_state,
        ],
        outputs=[self.submit_status],
    )

    # Hover index written by the page JS triggers a plot redraw.
    self.hidden_input.change(
        fn=update_plot,
        inputs=[self.hidden_input, self.output_state],
        outputs=[self.confidence_plot],
    )
src/components/quizbowl/plotting.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import re
4
+ from collections import Counter
5
+
6
+ import matplotlib.pyplot as plt
7
+ import pandas as pd
8
+
9
+
10
+ def evaluate_buzz(prediction: str, clean_answers: list[str] | str) -> int:
11
+ """Evaluate the buzz of a prediction against the clean answers."""
12
+ if isinstance(clean_answers, str):
13
+ print("clean_answers is a string")
14
+ clean_answers = [clean_answers]
15
+ pred = prediction.lower().strip()
16
+ if not pred:
17
+ return 0
18
+ for answer in clean_answers:
19
+ answer = answer.strip().lower()
20
+ if answer and answer in pred:
21
+ print(f"Found {answer} in {pred}")
22
+ return 1
23
+ return 0
24
+
25
+
26
def create_answer_html(answer: str):
    """Render the gold answer as a small HTML header block."""
    return "<div class='answer-header'>Answer:<br>{}</div>".format(answer)
29
+
30
+
31
def create_tokens_html(tokens: list[str], eval_points: list[tuple], answer: str, marker_indices: list[int] = None):
    """Create HTML for tokens with hover capability and a colored header for the answer.

    Args:
        tokens: Question tokens to render as hoverable spans.
        eval_points: (token_index, (confidence, buzz, score)) tuples.
        answer: Gold answer (header rendering currently disabled).
        marker_indices: Token indices after which a run-marker bar is drawn.

    Returns:
        An HTML string, or an error div if rendering fails.
    """
    try:
        point_lookup = dict(eval_points)
        markers = set(marker_indices) if isinstance(marker_indices, list) else set()

        # The answer header is intentionally disabled for now;
        # create_answer_html(answer) would render it.

        spans = []
        for idx, token in enumerate(tokens):
            # (confidence, buzz, score) for this token, if it was evaluated.
            confidence, is_buzz, score = point_lookup.get(idx, (None, 0, 0))

            # Non-word tokens need explicit nbsp to survive HTML rendering.
            shown = token if re.match(r"\w+", token) else token.replace(" ", "&nbsp;")

            if confidence is None:
                extra_cls = ""
            elif is_buzz:
                extra_cls = f" guess-point buzz-{score}"
            else:
                extra_cls = " guess-point no-buzz"

            piece = f'<span id="token-{idx}" class="token{extra_cls}" data-index="{idx}">{shown}</span>'
            if idx in markers:
                piece += "<span style='color: rgba(0,0,255,0.3);'>|</span>"
            spans.append(piece)

        return f"<div class='token-container'>{''.join(spans)}</div>"
    except Exception as e:
        logging.error(f"Error creating token HTML: {e}", exc_info=True)
        return f"<div class='token-container'>Error creating tokens: {str(e)}</div>"
68
+
69
+
70
def create_line_plot(eval_points, highlighted_index=-1):
    """Build a DataFrame of buzz points (plus an optional hover line) for a Gradio LinePlot.

    Args:
        eval_points: (position, (value, buzz)) tuples.
        highlighted_index: If >= 0, add a vertical hover-line at this position.

    Returns:
        A DataFrame with position/value/type/highlight/color columns
        (empty on error).
    """
    try:
        rows = []

        # One row per buzz evaluation point; red when the buzz flag is 0.
        for pos, (value, buzz) in eval_points:
            rows.append(
                {
                    "position": pos,
                    "value": value,
                    "type": "buzz",
                    "highlight": True,
                    "color": "#ff4444" if buzz == 0 else "#228b22",
                }
            )

        # Represent the hovered token as a vertical segment from 0 to 1.
        if highlighted_index >= 0:
            for y in (0, 1):
                rows.append(
                    {
                        "position": highlighted_index,
                        "value": y,
                        "type": "hover-line",
                        "color": "#000000",
                        "highlight": True,
                    }
                )

        return pd.DataFrame(rows)
    except Exception as e:
        logging.error(f"Error creating line plot: {e}", exc_info=True)
        # Empty frame with the expected columns keeps downstream plots happy.
        return pd.DataFrame(columns=["position", "value", "type", "highlight", "color"])
115
+
116
+
117
def create_pyplot(tokens, eval_points, highlighted_index=-1):
    """Create a matplotlib figure of buzz confidences over token positions.

    Args:
        tokens: Question tokens; used to annotate buzz points.
        eval_points: (token_index, (confidence, buzz, score)) tuples.
        highlighted_index: If >= 0, draw a vertical marker at this token.

    Returns:
        A matplotlib Figure.
    """
    plt.style.use("ggplot")  # grid-paper look
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    # Confidence curve, anchored at (0, 0).
    x = [0]
    y = [0]
    for i, (v, b, s) in eval_points:
        x.append(i + 1)
        y.append(v)
    ax.plot(x, y, "o--", color="#4698cf")

    # Overlay buzz points: green if the buzz scored, red otherwise.
    for i, (v, b, s) in eval_points:
        if not b:
            continue
        color = "green" if s else "red"
        ax.plot(i + 1, v, "o", color=color)
        if i >= len(tokens):
            # BUG FIX: previously this only printed a warning and then still
            # indexed tokens[i], raising IndexError; skip the annotation.
            logging.warning(f"Token index {i} is out of bounds for n_tokens: {len(tokens)}")
            continue
        ax.annotate(f"{tokens[i]}", (i + 1, v), textcoords="offset points", xytext=(0, 10), ha="center")

    if highlighted_index >= 0:
        # Light vertical line for the hovered token, spanning the full axis.
        ax.axvline(x=highlighted_index + 1, color="#ff9900", linestyle="--", ymin=0, ymax=1)

    ax.set_title("Buzz Confidence")
    ax.set_xlabel("Token Index")
    ax.set_ylabel("Confidence")
    ax.set_xticks(x)
    ax.set_xticklabels(x)
    return fig
148
+
149
+
150
def create_scatter_pyplot(token_positions, scores):
    """Scatter plot of (token position, score) pairs, marker size scaled by frequency.

    Args:
        token_positions: Buzz positions, one per question.
        scores: Matching 0/1 scores.

    Returns:
        A matplotlib Figure.
    """
    plt.style.use("ggplot")
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    # Identical (position, score) pairs collapse into one larger marker.
    tally = Counter(zip(token_positions, scores))
    xs, ys, sizes = [], [], []
    for (position, score), count in tally.items():
        xs.append(position)
        ys.append(score)
        sizes.append(count * 20)

    ax.scatter(xs, ys, color="#4698cf", s=sizes)

    return fig
168
+
169
+
170
def update_plot(highlighted_index, state):
    """Redraw the confidence plot with a vertical marker at the hovered token.

    Args:
        highlighted_index: Token index from the hidden textbox (string or
            int); falsy values mean "no highlight".
        state: JSON string with "tokens" and "values" from the last run.

    Returns:
        A matplotlib Figure, or an empty DataFrame when there is nothing
        to plot.
    """
    try:
        if not state or state == "{}":
            logging.warning("Empty state provided to update_plot")
            return pd.DataFrame()

        # BUG FIX: falsy input used to become None, and create_pyplot's
        # `highlighted_index >= 0` comparison then raised TypeError (caught
        # below, so the plot silently never refreshed). Use the sentinel -1.
        highlighted_index = int(highlighted_index) if highlighted_index else -1
        logging.info(f"Update plot triggered with token index: {highlighted_index}")

        data = json.loads(state)
        tokens = data.get("tokens", [])
        values = data.get("values", [])

        if not tokens or not values:
            logging.warning("No tokens or values found in state")
            return pd.DataFrame()

        # create_line_plot(values, highlighted_index) is the lighter
        # (DataFrame-based) alternative, currently unused.
        return create_pyplot(tokens, values, highlighted_index)
    except Exception as e:
        logging.error(f"Error updating plot: {e}")
        return pd.DataFrame()
src/components/quizbowl/tossup.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from typing import Any
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import pandas as pd
8
+ from datasets import Dataset
9
+
10
+ from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
11
+ from submission import submit
12
+ from workflows.qb.simple_agent import SimpleTossupAgent
13
+ from workflows.structs import ModelStep, Workflow
14
+
15
+ from .plotting import (
16
+ create_answer_html,
17
+ create_pyplot,
18
+ create_scatter_pyplot,
19
+ create_tokens_html,
20
+ evaluate_buzz,
21
+ update_plot,
22
+ )
23
+
24
+
25
def add_model_scores(model_outputs: list[dict], clean_answers: list[str], run_indices: list[int]) -> list[dict]:
    """Annotate each model output in place with its score and 1-based token position."""
    for run_idx, output in zip(run_indices, model_outputs):
        output["score"] = evaluate_buzz(output["answer"], clean_answers)
        output["token_position"] = run_idx + 1
    return model_outputs
31
+
32
+
33
def prepare_buzz_evals(
    run_indices: list[int], model_outputs: list[dict]
) -> list[tuple[int, tuple[float, bool, int]]]:
    """Pair each run index with its model output's (confidence, buzz, score).

    Args:
        run_indices: Token indices at which the model was run.
        model_outputs: One output dict per run index, each containing
            "confidence", "buzz" and "score".

    Returns:
        A list of (token_index, (confidence, buzz, score)) tuples; empty
        when no run indices are provided.
    """
    # Docstring fixed: it previously described assigning "random values for
    # demonstration", a copy-paste leftover.
    if not run_indices:
        logging.warning("No run indices provided, returning empty results")
        # BUG FIX: previously returned the 2-tuple ([], []) while the normal
        # path (and the sole caller) expects a single list; iterating that
        # tuple downstream raised ValueError on unpacking.
        return []
    eval_points = []
    for idx, output in zip(run_indices, model_outputs):
        eval_points.append((int(idx), (output["confidence"], output["buzz"], output["score"])))

    return eval_points
46
+
47
+
48
def initialize_eval_interface(example, model_outputs: list[dict]):
    """Build the token HTML, confidence figure, and serialized state for one example.

    Args:
        example: Dataset row with "question", "run_indices", "answer_primary".
        model_outputs: Per-run model outputs with confidence/buzz/score.

    Returns:
        (html_content, figure, state_json) — error placeholders on failure.
    """
    tokens = example["question"].split()
    run_indices = example["run_indices"]
    answer = example["answer_primary"]

    try:
        eval_points = prepare_buzz_evals(run_indices, model_outputs)

        if not tokens:
            return "<div>No tokens found in the provided text.</div>", pd.DataFrame(), "{}"

        # First point where the model actually buzzed (-1 if it never did).
        first_buzz = -1
        for idx, (_, buzz, _) in eval_points:
            if buzz == 1:
                first_buzz = int(idx)
                break

        html_content = create_tokens_html(tokens, eval_points, answer)
        plot_data = create_pyplot(tokens, eval_points, first_buzz)

        # Serialize tokens and eval points so hover callbacks can redraw later.
        state = json.dumps({"tokens": tokens, "values": eval_points})

        return html_content, plot_data, state
    except Exception as e:
        logging.error(f"Error initializing interface: {e}", exc_info=True)
        return f"<div>Error initializing interface: {str(e)}</div>", pd.DataFrame(), "{}"
70
+
71
+
72
def process_tossup_results(results: list[dict], top_k_mode: bool = False) -> pd.DataFrame:
    """Convert per-run tossup outputs into a display DataFrame.

    Args:
        results: Per-run dicts with token_position/score/confidence/answer.
        top_k_mode: Unsupported for tossups; raises if True.

    Returns:
        A DataFrame with one row per run.
    """
    if top_k_mode:
        raise ValueError("Top-k mode not supported for tossup mode")
    rows = []
    for r in results:
        rows.append(
            {
                "Token Position": r["token_position"],
                "Correct?": "✅" if r["score"] == 1 else "❌",
                "Confidence": r["confidence"],
                "Prediction": r["answer"],
            }
        )
    return pd.DataFrame(rows)
88
+
89
+
90
def validate_workflow(workflow: Workflow):
    """
    Validate that a workflow is properly configured for the tossup task.

    Args:
        workflow (Workflow): The workflow to validate

    Raises:
        ValueError: If the workflow is not properly configured
    """
    if not workflow.steps:
        raise ValueError("Workflow must have at least one step")

    # Every step must individually be a valid tossup model step.
    for step in workflow.steps.values():
        validate_model_step(step)

    # The workflow as a whole must consume the question text ...
    if "question" not in set(workflow.inputs):
        raise ValueError("Workflow must have 'question' as an input")

    # ... and expose both an answer and a confidence score.
    output_vars = set(workflow.outputs)
    if not any("answer" in name for name in output_vars):
        raise ValueError("Workflow must produce an 'answer' as output")
    if not any("confidence" in name for name in output_vars):
        raise ValueError("Workflow must produce a 'confidence' score as output")
118
+
119
def validate_model_step(model_step: ModelStep):
    """
    Validate that a model step is properly configured for the tossup task.

    Args:
        model_step (ModelStep): The model step to validate

    Raises:
        ValueError: If the model step is not properly configured
    """
    # Both the model name and its provider are mandatory.
    if not model_step.model or not model_step.provider:
        raise ValueError("Model step must have both model and provider specified")

    if model_step.call_type != "llm":
        raise ValueError("Model step must have call_type 'llm'")

    # LLM steps need an explicit, in-range temperature.
    if model_step.temperature is None:
        raise ValueError("Temperature must be specified for LLM model steps")
    if not (0.0 <= model_step.temperature <= 1.0):
        raise ValueError(f"Temperature must be between 0.0 and 1.0, got {model_step.temperature}")

    # Required input field.
    input_names = {field.name for field in model_step.input_fields}
    if "question" not in input_names:
        raise ValueError("Model step must have a 'question' input field")

    # Required output fields.
    output_names = {field.name for field in model_step.output_fields}
    if "answer" not in output_names:
        raise ValueError("Model step must have an 'answer' output field")
    if "confidence" not in output_names:
        raise ValueError("Model step must have a 'confidence' output field")

    # The confidence score must be numeric so it can be thresholded.
    for field in model_step.output_fields:
        if field.name == "confidence" and field.type != "float":
            raise ValueError("The 'confidence' output field must be of type 'float'")
159
+
160
+
161
class TossupInterface:
    """Gradio interface for the Tossup mode."""

    def __init__(self, app: gr.Blocks, dataset: Dataset, model_options: dict, defaults: dict):
        """Initialize the Tossup interface.

        Args:
            app: Enclosing gr.Blocks app (needed to hook the load event).
            dataset: Tossup questions with `question`, `run_indices`,
                `answer_primary` and `clean_answers` columns.
            model_options: Mapping of selectable model names to configs.
            defaults: Default UI values (buzz_threshold, early_stop,
                init_workflow, simple_workflow).
        """
        logging.info(f"Initializing Tossup interface with dataset size: {len(dataset)}")
        self.ds = dataset
        self.model_options = model_options
        self.app = app
        self.defaults = defaults
        # JSON blob of the last run (tokens + eval points), consumed by update_plot.
        self.output_state = gr.State(value="{}")
        self.render()

    def _render_model_interface(self, workflow: Workflow, simple: bool = True):
        """Render the model/pipeline configuration panel."""
        self.pipeline_interface = PipelineInterface(
            workflow,
            simple=simple,
            model_options=list(self.model_options.keys()),
        )
        with gr.Row():
            self.buzz_t_slider = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                value=self.defaults["buzz_threshold"],
                step=0.01,
                label="Buzz Threshold",
            )
            self.early_stop_checkbox = gr.Checkbox(
                value=self.defaults["early_stop"],
                label="Early Stop",
                info="Stop early if already buzzed",
            )
            self.run_btn = gr.Button("Run Tossup", variant="primary")

    def _render_qb_interface(self):
        """Render the quizbowl (question/answers/results) panel."""
        with gr.Row():
            self.qid_selector = gr.Number(
                label="Question ID", value=1, precision=0, minimum=1, maximum=len(self.ds), show_label=True, scale=0
            )
            self.answer_display = gr.Textbox(
                # FIX: label typo ("PrimaryAnswer" -> "Primary Answer").
                label="Primary Answer", elem_id="answer-display", elem_classes="answer-box", interactive=False, scale=1
            )
            self.clean_answer_display = gr.Textbox(
                label="Acceptable Answers",
                elem_id="answer-display-2",
                elem_classes="answer-box",
                interactive=False,
                scale=2,
            )
        self.question_display = gr.HTML(label="Question", elem_id="question-display")
        with gr.Row():
            self.confidence_plot = gr.Plot(
                label="Buzz Confidence",
                format="webp",
            )
            self.results_table = gr.DataFrame(
                label="Model Outputs",
                value=pd.DataFrame(columns=["Token Position", "Correct?", "Confidence", "Prediction"]),
            )
        with gr.Row():
            self.eval_btn = gr.Button("Evaluate")

        with gr.Accordion("Model Submission", elem_classes="model-submission-accordion", open=True):
            with gr.Row():
                self.model_name_input = gr.Textbox(label="Model Name")
                self.description_input = gr.Textbox(label="Description")
            with gr.Row():
                gr.LoginButton()
                self.submit_btn = gr.Button("Submit")
            self.submit_status = gr.HTML(label="Submission Status")

    def render(self):
        """Create the Gradio interface."""
        # Hidden textbox bridged from the page-level JS: hovering a token
        # writes its index here, which triggers the plot-refresh listener.
        self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")

        workflow = self.defaults["init_workflow"]

        with gr.Row():
            # Model panel on the left, question panel on the right.
            with gr.Column(scale=1):
                self._render_model_interface(workflow, simple=self.defaults["simple_workflow"])

            with gr.Column(scale=1):
                self._render_qb_interface()

        self._setup_event_listeners()

    def get_full_question(self, question_id: int) -> str:
        """Return the full question text and gold answer for a 1-based question ID."""
        try:
            question_id = int(question_id - 1)
            if not self.ds or question_id < 0 or question_id >= len(self.ds):
                return "Invalid question ID or dataset not loaded"

            question_data = self.ds[question_id]
            full_question = question_data["question"]
            gold_label = question_data["answer_primary"]

            return f"Question: {full_question}\n\nCorrect Answer: {gold_label}"
        except Exception as e:
            return f"Error loading question: {str(e)}"

    def validate_workflow(self, pipeline_state: PipelineState):
        """Validate the workflow, surfacing failures as gr.Error popups."""
        try:
            # Delegates to the module-level validate_workflow function.
            validate_workflow(pipeline_state.workflow)
        except Exception as e:
            raise gr.Error(f"Error validating workflow: {str(e)}")

    def get_new_question_html(self, question_id: int):
        """Build the token HTML, gold label, and acceptable answers for a question."""
        example = self.ds[question_id - 1]
        question = example["question"]
        gold_label = example["answer_primary"]
        marker_indices = example["run_indices"]
        tokens = question.split()
        # No eval points yet; markers show where the model would be run.
        question_html = create_tokens_html(tokens, [], gold_label, marker_indices)
        # Only show acceptable answers short enough to be readable.
        clean_answers = [a for a in example["clean_answers"] if len(a.split()) <= 6]
        clean_answers = ", ".join(clean_answers)
        return question_html, gold_label, clean_answers

    def get_model_outputs(self, example: dict, pipeline_state: PipelineState, buzz_threshold: float, early_stop: bool):
        """Run the agent over the question's progressive prefixes and score each run."""
        question_runs = []
        tokens = example["question"].split()
        for run_idx in example["run_indices"]:
            question_runs.append(" ".join(tokens[: run_idx + 1]))

        agent = SimpleTossupAgent(workflow=pipeline_state.workflow, buzz_threshold=buzz_threshold)
        outputs = list(agent.run(question_runs, early_stop=early_stop))
        outputs = add_model_scores(outputs, example["clean_answers"], example["run_indices"])
        return outputs

    def run_tossup(
        self,
        question_id: int,
        pipeline_state: PipelineState,
        buzz_threshold: float,
        early_stop: bool = True,
    ) -> tuple[str, Any, Any, Any]:
        """Run the agent on one tossup question.

        Returns a 4-tuple matching the bound outputs: (question HTML,
        confidence-plot update, output-state update, results-table update).
        """
        try:
            # Validate inputs (convert to 0-based index).
            question_id = int(question_id - 1)
            if not self.ds or question_id < 0 or question_id >= len(self.ds):
                # BUG FIX: error paths previously returned 3 values while the
                # click handler binds 4 outputs; pad so Gradio doesn't fail.
                return "Invalid question ID or dataset not loaded", None, None, None
            example = self.ds[question_id]
            outputs = self.get_model_outputs(example, pipeline_state, buzz_threshold, early_stop)

            # Process results and prepare visualization data.
            tokens_html, plot_data, output_state = initialize_eval_interface(example, outputs)
            df = process_tossup_results(outputs)
            return (
                tokens_html,
                gr.update(value=plot_data, label=f"Buzz Confidence on Question {question_id + 1}"),
                gr.update(value=output_state),
                gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}"),
            )
        except Exception as e:
            import traceback

            error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
            # BUG FIX: pad the exception path to 4 values as well.
            return error_msg, None, None, None

    def evaluate_tossups(
        self, pipeline_state: PipelineState, buzz_threshold: float, progress: gr.Progress = gr.Progress()
    ):
        """Evaluate the model over the whole sample set.

        Returns a 2-tuple of component updates: (scores table, buzz-position
        plot), matching the two outputs bound to the Evaluate button.
        """
        try:
            # Validate inputs.
            if not self.ds or not self.ds.num_rows:
                # BUG FIX: was a 3-tuple; this handler is bound to 2 outputs.
                return "No dataset loaded", None

            buzz_counts = 0
            correct_buzzes = 0
            token_positions = []
            correctness = []
            for example in progress.tqdm(self.ds, desc="Evaluating tossup questions"):
                model_outputs = self.get_model_outputs(example, pipeline_state, buzz_threshold, early_stop=True)
                if model_outputs[-1]["buzz"]:
                    buzz_counts += 1
                    if model_outputs[-1]["score"] == 1:
                        correct_buzzes += 1
                token_positions.append(model_outputs[-1]["token_position"])
                correctness.append(model_outputs[-1]["score"])
            # BUG FIX: guard against ZeroDivisionError when the model never buzzed.
            buzz_accuracy = correct_buzzes / buzz_counts if buzz_counts else 0.0
            df = pd.DataFrame(
                [
                    {
                        "Avg Buzz Position": f"{np.mean(token_positions):.2f}",
                        "Buzz Accuracy": f"{buzz_accuracy:.2%}",
                        "Total Score": f"{correct_buzzes}/{len(self.ds)}",
                    }
                ]
            )
            plot_data = create_scatter_pyplot(token_positions, correctness)
            return (
                gr.update(value=df, label="Scores on Sample Set"),
                gr.update(value=plot_data, label="Buzz Positions on Sample Set"),
            )
        except Exception:
            import traceback

            logging.error(f"Error evaluating tossups: {traceback.format_exc()}")
            # BUG FIX: 2 values to match the 2 bound outputs.
            return "Error evaluating tossups", None

    def submit_model(
        self, model_name: str, description: str, pipeline_state: PipelineState, profile: gr.OAuthProfile = None
    ):
        """Forward the current workflow to the submission backend as a 'tossup' entry."""
        return submit.submit_model(model_name, description, pipeline_state.workflow, "tossup", profile)

    def _setup_event_listeners(self):
        """Wire UI events: question loading, run, evaluate, submit, hover-plot."""
        # Load/refresh the question on app load and whenever the ID changes.
        gr.on(
            triggers=[self.app.load, self.qid_selector.change],
            fn=self.get_new_question_html,
            inputs=[self.qid_selector],
            outputs=[self.question_display, self.answer_display, self.clean_answer_display],
        )

        # Validate the workflow first; run the tossup only on success.
        self.run_btn.click(
            self.pipeline_interface.validate_workflow,
            inputs=[self.pipeline_interface.pipeline_state],
            outputs=[self.pipeline_interface.pipeline_state],
        ).success(
            self.run_tossup,
            inputs=[
                self.qid_selector,
                self.pipeline_interface.pipeline_state,
                self.buzz_t_slider,
                self.early_stop_checkbox,
            ],
            outputs=[
                self.question_display,
                self.confidence_plot,
                self.output_state,
                self.results_table,
            ],
        )

        self.eval_btn.click(
            fn=self.evaluate_tossups,
            inputs=[self.pipeline_interface.pipeline_state, self.buzz_t_slider],
            outputs=[self.results_table, self.confidence_plot],
        )

        self.submit_btn.click(
            fn=self.submit_model,
            inputs=[
                self.model_name_input,
                self.description_input,
                self.pipeline_interface.pipeline_state,
            ],
            outputs=[self.submit_status],
        )

        # Hover index written by the page JS triggers a plot redraw.
        self.hidden_input.change(
            fn=update_plot,
            inputs=[self.hidden_input, self.output_state],
            outputs=[self.confidence_plot],
        )
src/components/quizbowl/utils.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ import pandas as pd
4
+
5
+
6
+ def _create_confidence_plot_data(results: List[Dict], top_k_mode: bool = False) -> pd.DataFrame:
7
+ """Create a DataFrame for the confidence plot."""
8
+ if not top_k_mode:
9
+ return pd.DataFrame(
10
+ {
11
+ "position": [r["position"] for r in results],
12
+ "confidence": [r["confidence"] for r in results],
13
+ "answer": [r["answer"] for r in results],
14
+ }
15
+ )
16
+
17
+ # For top-k mode, extract and plot top answers
18
+ return _create_top_k_plot_data(results)
19
+
20
+
21
+ def _create_top_k_plot_data(results: List[Dict]) -> pd.DataFrame:
22
+ """Create plot data for top-k mode."""
23
+ # Find top answers across all positions (limited to top 5)
24
+ top_answers = set()
25
+ for r in results:
26
+ for g in r.get("guesses", [])[:3]: # Get top 3 from each position
27
+ if g.get("answer"):
28
+ top_answers.add(g.get("answer"))
29
+
30
+ top_answers = list(top_answers)[:5] # Limit to 5 total answers
31
+
32
+ # Create plot data for each answer
33
+ all_data = []
34
+ for position_idx, result in enumerate(results):
35
+ position = result["position"]
36
+ for answer in top_answers:
37
+ confidence = 0
38
+ for guess in result.get("guesses", []):
39
+ if guess.get("answer") == answer:
40
+ confidence = guess.get("confidence", 0)
41
+ break
42
+ all_data.append({"position": position, "confidence": confidence, "answer": answer})
43
+
44
+ return pd.DataFrame(all_data)
45
+
46
+
47
+ def _create_top_k_dataframe(results: List[Dict]) -> pd.DataFrame:
48
+ """Create a DataFrame for top-k results."""
49
+ df_rows = []
50
+ for result in results:
51
+ position = result["position"]
52
+ for i, guess in enumerate(result.get("guesses", [])):
53
+ df_rows.append(
54
+ {
55
+ "position": position,
56
+ "answer": guess.get("answer", ""),
57
+ "confidence": guess.get("confidence", 0),
58
+ "rank": i + 1,
59
+ }
60
+ )
61
+ return pd.DataFrame(df_rows)
62
+
63
+
64
+ def _format_buzz_result(buzzed: bool, results: List[Dict], gold_label: str, top_k_mode: bool) -> tuple[str, str, bool]:
65
+ """Format the result text based on whether the agent buzzed."""
66
+ if not buzzed:
67
+ return f"Did not buzz. Correct answer was: {gold_label}", "No buzz", False
68
+
69
+ buzz_position = next(i for i, r in enumerate(results) if r.get("buzz", False))
70
+ buzz_result = results[buzz_position]
71
+
72
+ if top_k_mode:
73
+ # For top-k, check if any of the top guesses match
74
+ top_answers = [g.get("answer", "").lower() for g in buzz_result.get("guesses", [])]
75
+ correct = gold_label.lower() in [a.lower() for a in top_answers]
76
+ final_answer = top_answers[0] if top_answers else "No answer"
77
+ else:
78
+ # For regular mode
79
+ final_answer = buzz_result["answer"]
80
+ correct = final_answer.lower() == gold_label.lower()
81
+
82
+ result_text = f"BUZZED at position {buzz_position + 1} with answer: {final_answer}\n"
83
+ result_text += f"Correct answer: {gold_label}\n"
84
+ result_text += f"Result: {'CORRECT' if correct else 'INCORRECT'}"
85
+
86
+ return result_text, final_answer, correct
src/components/utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Literal
2
+
3
+ import gradio as gr
4
+
5
+ DIRECTIONS = Literal["up", "down"]
6
+
7
+
8
def make_state(value: Any) -> gr.State:
    """Wrap *value* in a ``gr.State`` unless it already is one."""
    return value if isinstance(value, gr.State) else gr.State(value)
14
+
15
+
16
# List utilities
def move_item(items: list, position: int, direction: DIRECTIONS):
    """Swap the element at *position* with its neighbor in *direction*.

    Mutates *items* in place; a move past either end of the list is
    silently ignored.

    Raises:
        ValueError: If *items* is not a list, *position* is out of range,
            or *direction* is not "up"/"down".
    """
    if not isinstance(items, list):
        raise ValueError("items must be a list")
    if not isinstance(position, int) or not (0 <= position < len(items)):
        raise ValueError("position must be a valid index in the list")
    if direction not in ("up", "down"):
        raise ValueError("direction must be 'up' or 'down'")

    neighbor = position - 1 if direction == "up" else position + 1
    if 0 <= neighbor < len(items):
        items[position], items[neighbor] = items[neighbor], items[position]
src/display/__init__.py ADDED
File without changes
src/display/css_html_js.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ custom_css = """
2
+
3
+ .markdown-text {
4
+ font-size: 16px !important;
5
+ }
6
+
7
+ #models-to-add-text {
8
+ font-size: 18px !important;
9
+ }
10
+
11
+ #citation-button span {
12
+ font-size: 16px !important;
13
+ }
14
+
15
+ #citation-button textarea {
16
+ font-size: 16px !important;
17
+ }
18
+
19
+ #citation-button > label > button {
20
+ margin: 6px;
21
+ transform: scale(1.3);
22
+ }
23
+
24
+ #leaderboard-table {
25
+ margin-top: 15px
26
+ }
27
+
28
+ #leaderboard-table-lite {
29
+ margin-top: 15px
30
+ }
31
+
32
+ #search-bar-table-box > div:first-child {
33
+ background: none;
34
+ border: none;
35
+ }
36
+
37
+ #search-bar {
38
+ padding: 0px;
39
+ }
40
+
41
+ /* Limit the width of the first AutoEvalColumn so that names don't expand too much */
42
+ #leaderboard-table td:nth-child(2),
43
+ #leaderboard-table th:nth-child(2) {
44
+ max-width: 400px;
45
+ overflow: auto;
46
+ white-space: nowrap;
47
+ }
48
+
49
+ /* Workflow JSON styling */
50
+ .workflow-json-container {
51
+ margin-top: 20px;
52
+ margin-bottom: 30px;
53
+ }
54
+
55
+ .workflow-json {
56
+ border: 1px solid #ddd;
57
+ border-radius: 8px;
58
+ box-shadow: 0 2px 5px rgba(0,0,0,0.1);
59
+ }
60
+
61
+ .workflow-json pre {
62
+ max-height: 500px;
63
+ overflow-y: auto;
64
+ }
65
+
66
+ .tab-buttons button {
67
+ font-size: 20px;
68
+ }
69
+
70
+ #scale-logo {
71
+ border-style: none !important;
72
+ box-shadow: none;
73
+ display: block;
74
+ margin-left: auto;
75
+ margin-right: auto;
76
+ max-width: 600px;
77
+ }
78
+
79
+ #scale-logo .download {
80
+ display: none;
81
+ }
82
+ #filter_type{
83
+ border: 0;
84
+ padding-left: 0;
85
+ padding-top: 0;
86
+ }
87
+ #filter_type label {
88
+ display: flex;
89
+ }
90
+ #filter_type label > span{
91
+ margin-top: var(--spacing-lg);
92
+ margin-right: 0.5em;
93
+ }
94
+ #filter_type label > .wrap{
95
+ width: 103px;
96
+ }
97
+ #filter_type label > .wrap .wrap-inner{
98
+ padding: 2px;
99
+ }
100
+ #filter_type label > .wrap .wrap-inner input{
101
+ width: 1px
102
+ }
103
+ #filter-columns-type{
104
+ border:0;
105
+ padding:0.5;
106
+ }
107
+ #filter-columns-size{
108
+ border:0;
109
+ padding:0.5;
110
+ }
111
+ #box-filter > .form{
112
+ border: 0
113
+ }
114
+ """
115
+
116
+ get_window_url_params = """
117
+ function(url_params) {
118
+ const params = new URLSearchParams(window.location.search);
119
+ url_params = Object.fromEntries(params);
120
+ return url_params;
121
+ }
122
+ """
src/display/custom_css.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ css_pipeline = """
2
+ :root {
3
+ color-scheme: light !important;
4
+ --block-border-width: 0;
5
+ --section-header-text-weight: 600;
6
+ --section-header-text-size: 14px;
7
+ --mono-font-family: "Roboto Mono", monospace;
8
+ --body-text-size: 14px !important;
9
+
10
+ --card-bg-color: #fcecd4;
11
+ --card-btn-color: #D4E4FC;
12
+ --card-btn-color-hover: #7DAEF6;
13
+ --answer-bg-color: #f0f8ff;
14
+ --hover-border-color: #121212;
15
+ }
16
+
17
+ .dark {
18
+ --block-border-width: 0;
19
+ --card-bg-color: #383127;
20
+ --answer-bg-color: #1a2b3c;
21
+ --hover-border-color: #ffffff;
22
+ }
23
+
24
+ .gradio-app {
25
+ // font-family: Arial, sans-serif;
26
+ }
27
+
28
+ .form {
29
+ box-shadow: 0 0 0 0 !important;
30
+ }
31
+
32
+ .head {
33
+ margin-bottom: 0px;
34
+ }
35
+
36
+ .gradio-container {
37
+ max-width: 1500px;
38
+ margin: 0 auto;
39
+ padding: 0 8px;
40
+ }
41
+
42
+ .html-container {
43
+ padding: 0px 0px;
44
+ margin: 4px 0px;
45
+ border-radius: 12px;
46
+ gap: 0px
47
+ }
48
+
49
+ .step-container {
50
+ background-color: var(--card-bg-color);
51
+ padding: 0px 0px;
52
+ margin: 4px 0px;
53
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
54
+ border-radius: 12px;
55
+ gap: 0px
56
+ }
57
+
58
+ .step-container:hover {
59
+ border-color: var(--hover-border-color);
60
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.2);
61
+ }
62
+
63
+ .step-accordion {
64
+ background-color: var(--card-bg-color);
65
+ border: 0px solid #e0e0e0 !important;
66
+ border-radius: 12px;
67
+ overflow: hidden;
68
+ // transition: box-shadow 0.3s ease, border-color 0.3s ease;
69
+ padding: 8px 8px;
70
+ font-size: 12px;
71
+ }
72
+
73
+ .output-fields-panel {
74
+ background-color: var(--card-bg-color);
75
+ border: 0px solid #e0e0e0 !important;
76
+ border-radius: 12px;
77
+ overflow: hidden;
78
+ transition: box-shadow 0.3s ease, border-color 0.3s ease;
79
+ padding: 8px 8px;
80
+ font-size: 12px;
81
+ }
82
+
83
+ .model-submission-accordion {
84
+ background-color: var(--card-bg-color);
85
+ border: 0px solid #e0e0e0 !important;
86
+ border-radius: 12px;
87
+ overflow: hidden;
88
+ transition: box-shadow 0.3s ease, border-color 0.3s ease;
89
+ font-size: 14px;
90
+ }
91
+
92
+ .model-submission-accordion > label-wrap {
93
+ font-size: 16px;
94
+ font-weight: bold !important;
95
+ }
96
+
97
+ .step-accordion:hover .step-name-input input {
98
+ font-weight: bold;
99
+ }
100
+
101
+ .step-accordion > label-wrap {
102
+ font-size: 14px;
103
+ font-weight: bold !important;
104
+ }
105
+
106
+ .step-header-row {
107
+ margin: 0px 0px;
108
+ padding: 0px 0px;
109
+ border: 0px !important;
110
+ }
111
+
112
+ .step-header-row form {
113
+ margin: 0px 0px;
114
+ padding: 0px 0px;
115
+ border: 0px !important;
116
+ }
117
+
118
+ .step-name {
119
+ margin: 0px
120
+ padding: 0px 0px;
121
+ // border-radius: 8px;
122
+ border: 0px !important
123
+ }
124
+
125
+ .model-dropdown {
126
+ margin: 0px
127
+ padding: 0px 8px;
128
+ }
129
+
130
+ .model-dropdown input {
131
+ font-size: 14px;
132
+ padding-bottom: 2px;
133
+ padding-top: 2px;
134
+ }
135
+
136
+ .step-name input {
137
+ font-size: 14px;
138
+ font-weight: bold;
139
+ padding-bottom: 8px;
140
+ margin-bottom: 4px;
141
+ border-radius: 12px !important;
142
+ }
143
+
144
+ .step-controls {
145
+ display: flex;
146
+ justify-content: flex-end;
147
+ gap: 12px;
148
+ background-color: var(--card-bg-color);
149
+ border-radius: 12px;
150
+ padding: 0px
151
+ border: 1px solid black;
152
+ }
153
+
154
+ .step-control-btn {
155
+ background-color: var(--card-btn-color);
156
+ font-size: 12px !important;
157
+ color: var(--body-text-color);
158
+ min-width: 36px !important;
159
+ min-height: 24px !important;
160
+ padding: 4px !important;
161
+ margin: 8px !important;
162
+ border-radius: 12px;
163
+ }
164
+
165
+ .step-control-btn:hover {
166
+ background-color: var(--card-btn-color-hover);
167
+ }
168
+
169
+ .step-tabs {
170
+ margin-top: 0px;
171
+ padding: 0px 0px;
172
+ border-radius: 0px;
173
+ border: 0px
174
+ background-color: transparent;
175
+ }
176
+
177
+ .tab-content {
178
+ padding: 0px 0px;
179
+ margin-bottom: 0px;
180
+ border-radius: 4px;
181
+ border: 0px solid #eee;
182
+ background-color: transparent !important;
183
+ }
184
+
185
+ .fields-panel {
186
+ background-color: transparent !important;
187
+ gap: 5px !important;
188
+ border-radius: 4px;
189
+ padding: 2px;
190
+ }
191
+
192
+ .field-row {
193
+ margin-bottom: 1px;
194
+ margin-top: 1px;
195
+ padding: 2px;
196
+ border-radius: 8px;
197
+ background-color: var(--block-background-fill) !important;
198
+ border: 0px solid #eee;
199
+ gap: 0px !important;
200
+ }
201
+
202
+ .output-field-row {
203
+ margin-bottom: 1px;
204
+ margin-top: 1px;
205
+ padding: 2px;
206
+ border-radius: 4px;
207
+ background-color: var(--block-background-fill) !important;
208
+ border: 0px solid #eee;
209
+ gap: 0px !important;
210
+ }
211
+
212
+ .output-fields-header {
213
+ padding: 0px 8px;
214
+ }
215
+
216
+ .output-fields-panel {
217
+ background-color: var(--block-background-fill) !important;
218
+ padding: 8px 8px;
219
+ }
220
+
221
+ .output-field-variable {
222
+ font-family: var(--mono-font-family) !important;
223
+ font-weight: 300 !important;
224
+ font-size: 12px !important;
225
+ padding: 8px 8px;
226
+ border-radius: 4px;
227
+ border: 0px solid #eee !important;
228
+ }
229
+
230
+ .output-field-variable span {
231
+ font-size: 12px !important;
232
+ }
233
+
234
+ .field-type {
235
+ min-width: 100px !important;
236
+ }
237
+
238
+ .field-name > label, .field-variable > label, .field-description > label, .field-type > label {
239
+ font-size: 12px !important;
240
+ }
241
+
242
+ .field-name input, .field-description input, .field-type input {
243
+ font-family: var(--mono-font-family) !important;
244
+ font-size: 12px !important;
245
+ }
246
+
247
+ .field-variable input, .field-type input, .output-field-variable input {
248
+ font-family: var(--mono-font-family) !important;
249
+ font-size: 12px !important;
250
+ padding-top: 3px;
251
+ padding-bottom: 3px;
252
+ }
253
+
254
+ .field-name listbox, .field-variable listbox, .field-type listbox {
255
+ font-family: var(--mono-font-family) !important;
256
+ padding-top: 2px;
257
+ padding-bottom: 2px;
258
+ font-size: 12px !important;
259
+ }
260
+
261
+ .field-description {
262
+ font-size: 12px !important;
263
+ }
264
+
265
+ /* Accordion button labels */
266
+ .step-accordion button.label-wrap {
267
+ font-size: 14px;
268
+ font-weight: bold !important;
269
+ font-family: var(--mono-font-family) !important;
270
+ }
271
+
272
+ .step-accordion button.label-wrap.open {
273
+ font-size: 14px;
274
+ font-weight: bold !important;
275
+ font-family: var(--mono-font-family) !important;
276
+ }
277
+
278
+ .button-column {
279
+ margin-top: 2px;
280
+ margin-bottom: 2px;
281
+ padding-top: 2px;
282
+ padding-bottom: 2px;
283
+ display: flex;
284
+ flex-direction: column;
285
+ justify-content: center;
286
+ align-items: center;
287
+ gap: 4x !important;
288
+ height: 100%;
289
+ }
290
+
291
+ .icon-button {
292
+ min-width: 28px !important;
293
+ max-width: 42px !important;
294
+ height: 28px !important;
295
+ max-height: 42px !important;
296
+ padding: 0 !important;
297
+ border-radius: 4px !important;
298
+ transition: background-color 0.2s ease, color 0.2s ease;
299
+ }
300
+
301
+ .delete-button {
302
+ background-color: #ffebee !important;
303
+ color: #d32f2f !important;
304
+ }
305
+
306
+ .delete-button:hover {
307
+ background-color: #ffcdd2 !important;
308
+ }
309
+
310
+ .up-button, .down-button {
311
+ background-color: #e3f2fd !important;
312
+ color: #1976d2 !important;
313
+ }
314
+
315
+ .up-button:hover, .down-button:hover {
316
+ background-color: #bbdefb !important;
317
+ }
318
+
319
+ .add-field-button {
320
+ background-color: #e8f5e9 !important;
321
+ color: #2e7d32 !important;
322
+ }
323
+
324
+ .add-field-button:hover, .add-step-button:hover {
325
+ background-color: #c8e6c9 !important;
326
+ }
327
+
328
+ .pipeline-controls {
329
+ border-top: 1px solid #eee;
330
+ padding-top: 8px;
331
+ }
332
+
333
+ .pipeline-header {
334
+ border-bottom: 1px solid #eee;
335
+ padding: 8px 0px;
336
+ }
337
+
338
+ .pipeline-footer {
339
+ border-top: 1px solid #eee;
340
+ padding: 8px 0px;
341
+ }
342
+
343
+ .add-step-button {
344
+ background-color: #e8f5e9 !important;
345
+ color: #2e7d32 !important;
346
+ border-radius: 12px;
347
+ }
348
+
349
+ .export-button {
350
+ background-color: #e0f7f5 !important;
351
+ color: #00796b !important;
352
+ border-radius: 12px;
353
+ }
354
+
355
+ .export-button:hover {
356
+ background-color: #b2dfdb !important;
357
+ }
358
+
359
+ .pipeline-preview {
360
+ background-color: var(--card-bg-color);
361
+ border-radius: 12px;
362
+ box-shadow: 0 0 0 0 !important;
363
+ }
364
+ """
365
+
366
+
367
+ css_tossup = """
368
+ .token {
369
+ display: inline-block;
370
+ padding: 1px 3px;
371
+ margin: 1px;
372
+ border-radius: 4px;
373
+ cursor: pointer;
374
+ transition: background-color 0.2s;
375
+ }
376
+ .answer-header {
377
+ font-weight: 900;
378
+ font-size: 16px;
379
+ padding: 8px;
380
+ border-radius: 8px;
381
+ background-color: var(--answer-bg-color) !important;
382
+ }
383
+ .answer-box textarea {
384
+ font-size: 16px;
385
+ padding: 8px;
386
+ border-radius: 8px;
387
+ background-color: var(--answer-bg-color) !important;
388
+ }
389
+ .token:hover, .token.highlighted {
390
+ background-color: #ffcc00;
391
+ }
392
+ .token.guess-point {
393
+ border-bottom: 3px solid;
394
+ }
395
+ .token.no-buzz {
396
+ border-color: #6b96b3;
397
+ }
398
+ .token.buzz-0 {
399
+ border-color: #ff4444;
400
+ }
401
+ .token.buzz-1 {
402
+ border-color: #228b22; /* Darker and slightly muted green */
403
+ }
404
+ .token-container {
405
+ line-height: 1.7;
406
+ padding: 5px;
407
+ margin-left: 4px;
408
+ margin-right: 4px;
409
+ background-color: var(--answer-bg-color) !important;
410
+ border-radius: 8px;
411
+ margin-bottom: 10px;
412
+ }
413
+ """
src/display/formatting.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def model_hyperlink(link, model_name):
    """Render *model_name* as a dotted-underline HTML link opening in a new tab."""
    style = "color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;"
    return f'<a target="_blank" href="{link}" style="{style}">{model_name}</a>'
3
+
4
+
5
def make_clickable_model(model_name):
    """Return an HTML link pointing at the model's Hugging Face page."""
    return model_hyperlink(f"https://huggingface.co/{model_name}", model_name)
8
+
9
+
10
def styled_error(error):
    """Wrap *error* in a centered red paragraph for UI display."""
    return "<p style='color: red; font-size: 20px; text-align: center;'>{}</p>".format(error)
12
+
13
+
14
def styled_warning(warn):
    """Wrap *warn* in a centered orange paragraph for UI display."""
    return "<p style='color: orange; font-size: 20px; text-align: center;'>{}</p>".format(warn)
16
+
17
+
18
def styled_message(message):
    """Wrap *message* in a centered green paragraph for UI display."""
    return "<p style='color: green; font-size: 20px; text-align: center;'>{}</p>".format(message)
20
+
21
+
22
def has_no_nan_values(df, columns):
    """Boolean Series: True for rows with no NaN in any of *columns*."""
    return ~df[columns].isna().any(axis=1)
24
+
25
+
26
def has_nan_values(df, columns):
    """Boolean Series: True for rows with at least one NaN in *columns*."""
    return ~df[columns].notna().all(axis=1)
src/display/utils.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, make_dataclass
2
+ from enum import Enum
3
+
4
+ import pandas as pd
5
+
6
+ from src.about import Tasks
7
+
8
def fields(raw_class):
    """Return class-attribute values whose names neither start nor end with '__'."""
    values = []
    for attr_name, attr_value in raw_class.__dict__.items():
        if attr_name.startswith("__") or attr_name.endswith("__"):
            continue
        values.append(attr_value)
    return values
10
+
11
+
12
+ # These classes are for user facing column names,
13
+ # to avoid having to change them all around the code
14
+ # when a modif is needed
15
@dataclass
class ColumnContent:
    """Display metadata for a single leaderboard column."""

    name: str  # user-facing column header
    type: str  # column type tag (e.g. "str", "number", "markdown", "bool")
    displayed_by_default: bool  # shown before the user customizes columns
    hidden: bool = False  # excluded from the selectable column list
    never_hidden: bool = False  # pinned; cannot be deselected
22
+
23
## Leaderboard columns
# Each entry is [field_name, field_type, default] as consumed by make_dataclass.
auto_eval_column_dict = [
    ["model_type_symbol", ColumnContent, ColumnContent("T", "str", True, never_hidden=True)],
    ["model", ColumnContent, ColumnContent("Model", "markdown", True, never_hidden=True)],
    # Scores
    ["average", ColumnContent, ColumnContent("Average ⬆️", "number", True)],
]
# One score column per benchmark task.
auto_eval_column_dict.extend(
    [task.name, ColumnContent, ColumnContent(task.value.col_name, "number", True)] for task in Tasks
)
# Model information
auto_eval_column_dict.extend(
    [
        ["model_type", ColumnContent, ColumnContent("Type", "str", False)],
        ["architecture", ColumnContent, ColumnContent("Architecture", "str", False)],
        ["weight_type", ColumnContent, ColumnContent("Weight type", "str", False, True)],
        ["precision", ColumnContent, ColumnContent("Precision", "str", False)],
        ["license", ColumnContent, ColumnContent("Hub License", "str", False)],
        ["params", ColumnContent, ColumnContent("#Params (B)", "number", False)],
        ["likes", ColumnContent, ColumnContent("Hub ❤️", "number", False)],
        ["still_on_hub", ColumnContent, ColumnContent("Available on the hub", "bool", False)],
        ["revision", ColumnContent, ColumnContent("Model sha", "str", False, False)],
    ]
)

# Build the frozen AutoEvalColumn dataclass dynamically so the score columns
# always follow the Tasks enum.
AutoEvalColumn = make_dataclass("AutoEvalColumn", auto_eval_column_dict, frozen=True)
45
+
46
## For the queue columns in the submission tab
@dataclass(frozen=True)
class EvalQueueColumn:  # Queue column
    """Columns shown for queued evaluation requests in the submission tab."""

    model = ColumnContent("model", "markdown", True)
    revision = ColumnContent("revision", "str", True)
    private = ColumnContent("private", "bool", True)
    precision = ColumnContent("precision", "str", True)
    # Fix: the third ColumnContent argument is `displayed_by_default` (a bool);
    # the previous value "Original" was a misplaced weight-type default that
    # only worked by being truthy.
    weight_type = ColumnContent("weight_type", "str", True)
    status = ColumnContent("status", "str", True)
55
+
56
## All the model information that we might need
@dataclass
class ModelDetails:
    """Presentation info for a model category: name plus an emoji symbol."""

    name: str
    display_name: str = ""
    symbol: str = ""  # emoji
62
+
63
+
64
class ModelType(Enum):
    """Training-regime categories, each carrying display details."""

    PT = ModelDetails(name="pretrained", symbol="🟢")
    FT = ModelDetails(name="fine-tuned", symbol="🔶")
    IFT = ModelDetails(name="instruction-tuned", symbol="⭕")
    RL = ModelDetails(name="RL-tuned", symbol="🟦")
    Unknown = ModelDetails(name="", symbol="?")

    def to_str(self, separator=" "):
        """Return "<symbol><separator><name>" for display."""
        details = self.value
        return f"{details.symbol}{separator}{details.name}"

    @staticmethod
    def from_str(type):
        """Map a free-form type string (name or emoji) back to a ModelType."""
        matchers = (
            (("fine-tuned", "🔶"), ModelType.FT),
            (("pretrained", "🟢"), ModelType.PT),
            (("RL-tuned", "🟦"), ModelType.RL),
            (("instruction-tuned", "⭕"), ModelType.IFT),
        )
        for needles, model_type in matchers:
            if any(needle in type for needle in needles):
                return model_type
        return ModelType.Unknown
85
+
86
class WeightType(Enum):
    """How submitted weights relate to a base model."""

    Adapter = ModelDetails("Adapter")
    Original = ModelDetails("Original")
    Delta = ModelDetails("Delta")
90
+
91
class Precision(Enum):
    """Numeric precision a model was evaluated in."""

    float16 = ModelDetails("float16")
    bfloat16 = ModelDetails("bfloat16")
    Unknown = ModelDetails("?")

    @staticmethod
    def from_str(precision):
        """Map a precision string (e.g. "torch.float16") to a Precision member.

        Marked @staticmethod for consistency with ModelType.from_str; the
        previous bare function only worked when accessed on the class, not
        on a member.
        """
        if precision in ("torch.float16", "float16"):
            return Precision.float16
        if precision in ("torch.bfloat16", "bfloat16"):
            return Precision.bfloat16
        return Precision.Unknown
102
+
103
# Column selection
COLS = [col.name for col in fields(AutoEvalColumn) if not col.hidden]

EVAL_COLS = [col.name for col in fields(EvalQueueColumn)]
EVAL_TYPES = [col.type for col in fields(EvalQueueColumn)]

BENCHMARK_COLS = [task.value.col_name for task in Tasks]
110
+
src/envs.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from huggingface_hub import HfApi
4
+
5
+ # Info to change for your repository
6
+ # ----------------------------------
7
+ TOKEN = os.environ.get("HF_TOKEN") # A read/write token for your org
8
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
9
+ ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY")
10
+ COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
11
+
12
+ OWNER = (
13
+ "umdclip" # Change to your org - don't forget to create a results and request dataset, with the correct format!
14
+ )
15
+ # ----------------------------------
16
+
17
+ REPO_ID = f"{OWNER}/advcal-leaderboard"
18
+ QUEUE_REPO = f"{OWNER}/advcal-requests"
19
+ RESULTS_REPO = f"{OWNER}/advcal-results"
20
+
21
+ PLAYGROUND_DATASET_NAMES = {
22
+ "tossup": "umdclip/acf-co24-tossups",
23
+ "bonus": "umdclip/acf-co24-bonuses",
24
+ }
25
+
26
+ # If you setup a cache later, just change HF_HOME
27
+ CACHE_PATH = os.getenv("HF_HOME", ".")
28
+
29
+ # Local caches
30
+ EVAL_REQUESTS_PATH = os.path.join(CACHE_PATH, "eval-queue")
31
+ EVAL_RESULTS_PATH = os.path.join(CACHE_PATH, "eval-results")
32
+ EVAL_REQUESTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-queue-bk")
33
+ EVAL_RESULTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-results-bk")
34
+
35
+ THEME = "gstaff/xkcd"
36
+ UNSELECTED_VAR_NAME = "Select Variable..."
37
+ UNSELECTED_MODEL_NAME = "Select Model..."
38
+ AVAILABLE_MODELS = {
39
+ "OpenAI/gpt-4o": {
40
+ "model": "gpt-4o-2024-11-20",
41
+ },
42
+ "OpenAI/gpt-4o-mini": {
43
+ "model": "gpt-4o-mini-2024-07-18",
44
+ },
45
+ "OpenAI/gpt-3.5-turbo": {
46
+ "model": "gpt-3.5-turbo-0125",
47
+ },
48
+ "Anthropic/claude-3-7-sonnet": {
49
+ "model": "claude-3-7-sonnet-20250219",
50
+ },
51
+ "Anthropic/claude-3-5-sonnet": {
52
+ "model": "claude-3-5-sonnet-20241022",
53
+ },
54
+ "Anthropic/claude-3-5-haiku": {
55
+ "model": "claude-3-5-haiku-20241022",
56
+ },
57
+ "Cohere/command-r": {
58
+ "model": "command-r-08-2024",
59
+ },
60
+ "Cohere/command-r-plus": {
61
+ "model": "command-r-plus-08-2024",
62
+ },
63
+ "Cohere/command-r7b": {
64
+ "model": "command-r7b-12-2024",
65
+ },
66
+ }
67
+
68
+ DEFAULT_SELECTIONS = {
69
+ "tossup": {
70
+ "simple_workflow": False,
71
+ "model": "OpenAI/gpt-4o-mini",
72
+ "temperature": 0.2,
73
+ "buzz_threshold": 0.85,
74
+ "early_stop": True,
75
+ },
76
+ "bonus": {
77
+ "simple_workflow": False,
78
+ "model": "OpenAI/gpt-4o-mini",
79
+ "temperature": 0.2,
80
+ "buzz_threshold": 0.85,
81
+ "early_stop": True,
82
+ },
83
+ }
84
+
85
+ DAILY_SUBMISSION_LIMIT_PER_USER = 5
86
+ API = HfApi(token=TOKEN)
src/submission/structs.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from typing import Dict, List, Literal, Optional
3
+
4
+ from pydantic import BaseModel, Field
5
+
6
+ from workflows.structs import Workflow
7
+
8
+ CompetitionType = Literal["tossup", "bonus"]
9
+ SubmissionType = Literal["python_file", "simple_workflow", "complex_workflow"]
10
+ SubmissionStatus = Literal["submitted", "in_progress", "completed", "failed"]
11
+
12
+
13
class Submission(BaseModel):
    """A competition submission, shaped for HuggingFace dataset storage.

    Serializes cleanly to/from the HF dataset format while keeping type
    safety and validation through Pydantic.

    Attributes:
        id: Unique identifier for the submission
        model_name: Display name of the submission
        username: HuggingFace username of the submitter
        description: Detailed description of what the submission does
        competition_type: Type of competition ("tossup" or "bonus")
        submission_type: Format of the submission (python file or workflow)
        workflow: Optional workflow definition for workflow submissions
        code: Optional code content for python-file submissions
        status: Current status of the submission
        created_at: ISO-format timestamp of creation
        updated_at: ISO-format timestamp of last update
    """

    id: str = Field(description="Unique identifier for the submission")
    model_name: str = Field(description="Display name of the submission")
    username: str = Field(description="HuggingFace username of the user who created the submission")
    description: str = Field(description="Detailed description of what the submission does")
    competition_type: CompetitionType = Field(description="Type of competition (tossup or bonus)")
    submission_type: SubmissionType = Field(description="Format of the submission (python file or workflow)")
    workflow: Optional[Workflow] = Field(default=None, description="Optional workflow definition stored as JSON dict")
    code: Optional[str] = Field(default=None, description="Optional code content for python file submissions")
    status: SubmissionStatus = Field(description="Current status of the submission")
    created_at: str = Field(description="ISO format timestamp of creation")
    updated_at: str = Field(description="ISO format timestamp of last update")

    def to_dict(self) -> Dict:
        """Dump to a plain dict, expanding the workflow for HF datasets."""
        payload = self.model_dump()
        if self.workflow:
            payload["workflow"] = self.workflow.model_dump(exclude_defaults=True)
        return payload

    @classmethod
    def from_dict(cls, data: Dict) -> "Submission":
        """Rebuild a Submission from the dict format used in HF datasets."""
        workflow_data = data.get("workflow")
        if workflow_data:
            data["workflow"] = Workflow.model_validate(workflow_data)
        return cls.model_validate(data)
src/submission/submit.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import traceback
4
+ from datetime import datetime, timedelta, timezone
5
+ from typing import Optional
6
+
7
+ import gradio as gr
8
+ import yaml
9
+
10
+ from src.display.formatting import styled_error, styled_message
11
+ from src.envs import API, DAILY_SUBMISSION_LIMIT_PER_USER, EVAL_REQUESTS_PATH, QUEUE_REPO
12
+ from src.submission.structs import CompetitionType, Submission, SubmissionStatus
13
+ from workflows.structs import Workflow
14
+
15
+
16
def get_user_submissions_today(username: str, competition_type: str) -> list[Submission]:
    """Load this user's *competition_type* submissions created today (UTC).

    Reads JSON files from the local eval-queue cache; files are matched by
    the "<competition_type>_" filename prefix and filtered on created_at.

    Raises:
        gr.Error: If *username* is None (user not logged in).
    """
    if username is None:
        raise gr.Error("Authentication required. Please log in to view your submissions.")
    today = datetime.now(timezone.utc).strftime("%Y%m%d")
    user_dir = f"{EVAL_REQUESTS_PATH}/{username}"
    if not os.path.exists(user_dir):
        return []
    todays_submissions = []
    prefix = f"{competition_type}_"
    for filename in os.listdir(user_dir):
        if not filename.startswith(prefix):
            continue
        with open(os.path.join(user_dir, filename), "r") as fp:
            submission = Submission.from_dict(json.load(fp))
        if submission.created_at.startswith(today):
            todays_submissions.append(submission)
    return todays_submissions
32
+
33
+
34
def get_time_until_next_submission(tz: timezone = timezone.utc) -> str:
    """Return a countdown string until the next midnight submission reset.

    Args:
        tz: Timezone used both for "now" and for the midnight boundary.

    Returns:
        Human-readable remaining time, e.g. "5 hours 12 mins".
    """
    # Capture "now" exactly once. The previous version called datetime.now()
    # twice; if the second call crossed midnight, the subtraction produced a
    # negative (normalized) timedelta and a bogus countdown.
    now = datetime.now(tz)
    next_midnight = (now + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
    remaining = next_midnight - now
    hours, rem = divmod(remaining.seconds, 3600)
    minutes = rem // 60
    return f"{hours} hours {minutes} mins"
42
+
43
+
44
def create_submission(
    username: str,
    model_name: str,
    description: str,
    workflow: Workflow,
    competition_type: CompetitionType,
) -> Submission:
    """Build a Submission record for *workflow* without uploading it.

    Args:
        username: HuggingFace username of the submitter.
        model_name: Display name of the submission.
        description: Detailed description of what the submission does.
        workflow: The workflow configuration being submitted.
        competition_type: "tossup" or "bonus".

    Returns:
        A Submission with a timestamped id and status "submitted".
    """
    now = datetime.now(timezone.utc)
    slug = model_name.lower().replace(" ", "_")
    return Submission(
        id=f"{competition_type}_{now.strftime('%Y%m%d_%H%M%S')}_{slug}",
        model_name=model_name,
        username=username,
        description=description,
        competition_type=competition_type,
        submission_type="simple_workflow",
        workflow=workflow,
        status="submitted",
        created_at=now.isoformat(),
        updated_at=now.isoformat(),
    )
79
+
80
+
81
def submit_model(
    model_name: str,
    description: str,
    workflow: Workflow,
    competition_type: CompetitionType,
    profile: gr.OAuthProfile | None,
) -> str:
    """Validate, rate-limit, and upload a workflow submission.

    Args:
        model_name: Display name of the submission.
        description: Detailed description of what the submission does.
        workflow: The workflow configuration being submitted.
        competition_type: "tossup" or "bonus".
        profile: OAuth profile of the logged-in user, or None.

    Returns:
        A styled HTML status message (success or error).
    """

    if profile is None:
        return styled_error("Authentication required. Please log in first to submit your model.")

    username = profile.username

    # Bug fix: competition_type was previously omitted from this call, which
    # raised a TypeError (get_user_submissions_today requires it) on every
    # submission attempt.
    if len(get_user_submissions_today(username, competition_type)) >= DAILY_SUBMISSION_LIMIT_PER_USER:
        time_str = get_time_until_next_submission()
        return styled_error(
            f"Daily submission limit of {DAILY_SUBMISSION_LIMIT_PER_USER} reached. Please try again in \n {time_str}."
        )
    try:
        submission = create_submission(
            username=username,
            model_name=model_name,
            description=description,
            workflow=workflow,
            competition_type=competition_type,
        )
        # Convert to dictionary format
        submission_dict = submission.to_dict()

        # Create output directory path
        out_dir = f"{EVAL_REQUESTS_PATH}/{username}"
        out_path = f"{out_dir}/{submission.id}.json"

        # Upload to HuggingFace dataset. The messages below previously
        # hard-coded "tossup" even for bonus submissions.
        API.upload_file(
            path_or_fileobj=json.dumps(submission_dict, indent=2).encode(),
            path_in_repo=out_path.split("eval-queue/")[1],
            repo_id=QUEUE_REPO,
            repo_type="dataset",
            commit_message=f"Add {competition_type} submission {submission.id}",
        )

        return styled_message(
            f"Successfully submitted {competition_type} model!\n"
            f"Submission ID: {submission.id}\n"
            f"Name: {username}/{model_name}\n"
            f"Please wait for up to an hour for the model to show in the PENDING list."
        )

    except Exception as e:
        traceback.print_exc()
        return styled_error(f"Error submitting model: {str(e)}")
145
+
146
+
147
if __name__ == "__main__":
    # Example usage
    from workflows.factory import create_quizbowl_simple_step_initial_setup

    # Create workflow
    model_step = create_quizbowl_simple_step_initial_setup()
    model_step.model = "gpt-4"
    model_step.provider = "openai"
    model_step.temperature = 0.7

    workflow = Workflow(
        inputs=["question_text"],
        outputs={"answer": "A.answer", "confidence": "A.confidence"},
        steps={"A": model_step},
    )

    # Submit model. Bug fix: `profile` is a required parameter of
    # submit_model; the original example omitted it and raised a TypeError
    # before submit_model even ran. Passing None exercises the
    # unauthenticated path.
    result = submit_model(
        model_name="GPT-4 Tossup",
        description="A simple GPT-4 model for tossup questions",
        workflow=workflow,
        competition_type="tossup",
        profile=None,
    )
    print(result)
src/utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Description: Utility functions for the model_step component.
2
+
3
+ from envs import AVAILABLE_MODELS, UNSELECTED_MODEL_NAME
4
+
5
+
6
def guess_model_provider(model_name: str):
    """Infer the LLM provider from a model name.

    Returns "OpenAI" for gpt-* models and "Anthropic" for Claude/Sonnet
    models; raises ValueError for anything else.
    """
    lowered = model_name.lower()
    if lowered.startswith("gpt-"):
        return "OpenAI"
    for marker in ("sonnet", "claude"):
        if marker in lowered:
            return "Anthropic"
    raise ValueError(f"Model `{lowered}` not yet supported")
14
+
15
+
16
def get_model_and_provider(model_name: str):
    """Resolve a display name to a (full_model_name, provider) pair.

    Args:
        model_name: Either a bare model name (provider is guessed from the
            resolved full name) or a "provider/model" string. The sentinel
            UNSELECTED_MODEL_NAME maps to ("", "").

    Returns:
        tuple[str, str]: (full model name from AVAILABLE_MODELS, provider).

    Raises:
        ValueError: If the provider cannot be guessed for a bare model name.
    """
    if model_name == UNSELECTED_MODEL_NAME:
        return "", ""
    # With maxsplit=1 the split yields exactly 1 or 2 parts, so the original
    # trailing `raise ValueError` branch was unreachable and has been removed.
    splits = model_name.split("/", maxsplit=1)
    if len(splits) == 2:
        provider, model_name = splits
        full_model_name = AVAILABLE_MODELS.get(model_name, model_name)
        return full_model_name, provider
    full_model_name = AVAILABLE_MODELS.get(model_name, model_name)
    provider = guess_model_provider(full_model_name)
    return full_model_name, provider
30
+
31
+
32
def get_full_model_name(model_name: str, provider: str = ""):
    """Build the "provider/model" display name, guessing the provider if absent."""
    if model_name == "":
        return UNSELECTED_MODEL_NAME
    resolved_provider = provider if provider else guess_model_provider(model_name)
    return f"{resolved_provider}/{model_name}"
src/workflows/README.md ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Workflows Subpackage
2
+
3
+ This subpackage provides a framework for defining, validating, and executing workflows composed of interconnected model steps with dependency management.
4
+
5
+ ## Overview
6
+
7
+ The workflows subpackage enables the creation and execution of workflows where multiple model steps can be combined, with outputs from earlier steps feeding into inputs of later steps. The package handles dependency resolution, execution order, and error handling.
8
+
9
+ ## Components
10
+
11
+ ### `structs.py`
12
+
13
+ Contains the core data structures used throughout the workflow system:
14
+
15
+ - `Field`: Represents an input or output field with name and type information
16
+ - `ModelStep`: Represents a single step in a workflow with input fields, output fields, and model details
17
+ - `Workflow`: A collection of ModelSteps with their identifiers
18
+
19
+ ### `utils.py`
20
+
21
+ Provides utility functions for workflow operations:
22
+
23
+ - `_create_variable_step_mapping`: Maps variables to the steps that produce them
24
+ - `create_dependency_graph`: Builds a dependency graph representing the execution order constraints
25
+ - `topological_sort`: Sorts steps in execution order based on their dependencies
26
+
27
### `executors.py`
28
+
29
+ Handles the execution of workflows:
30
+
31
+ - Processes inputs and outputs between steps
32
+ - Coordinates the execution of model steps in the correct order
33
+ - Integrates with external model providers (e.g., via litellm)
34
+
35
+ ### `errors.py`
36
+
37
+ Defines custom exceptions for workflow-related errors:
38
+
39
+ - `WorkflowError`: Base class for workflow errors
40
+ - `CyclicDependencyError`: Raised when detecting cycles in the workflow graph
41
+ - `UnknownVariableError`: Raised when a step requires a variable that's not provided or produced
42
+
43
+ ## Usage Example
44
+
45
+ ```python
46
+ from workflows.structs import Field, ModelStep, Workflow
47
+
48
+ # Define a workflow with two steps
49
+ step1 = ModelStep(
50
+ input_fields=[Field(name="query", type="string")],
51
+ output_fields=[Field(name="summary", type="string")],
52
+ model="gpt-3.5-turbo",
53
+ system_prompt="Summarize the following text"
54
+ )
55
+
56
+ step2 = ModelStep(
57
+ input_fields=[Field(name="summary", type="string", variable="step1.summary")],
58
+ output_fields=[Field(name="key_points", type="array")],
59
+ model="gpt-4",
60
+ system_prompt="Extract key points from the summary"
61
+ )
62
+
63
+ workflow = Workflow(steps={"step1": step1, "step2": step2})
64
+
65
+ # Execute the workflow
66
from workflows.executors import execute_workflow
67
+
68
+ result = execute_workflow(
69
+ workflow=workflow,
70
+ input_values={"query": "Long text to summarize..."}
71
+ )
72
+
73
+ # Access results
74
+ summary = result["step1.summary"]
75
+ key_points = result["step2.key_points"]
76
+ ```
77
+
78
+ ## Error Handling
79
+
80
+ The workflows system provides robust error handling:
81
+
82
+ - Detects cyclic dependencies in workflow definitions
83
+ - Validates input/output variable references
84
+ - Ensures all required inputs are provided
85
+
86
+ ## Extending the Workflows System
87
+
88
+ To extend the workflows system:
89
+
90
+ 1. Add new model step types by extending the `ModelStep` class
91
+ 2. Create custom field types by extending validation in the execution logic
92
+ 3. Implement additional error types in `errors.py` for specialized error handling
src/workflows/errors.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom exceptions for workflow validation and execution errors.
3
+
4
+ This module defines the exception hierarchy for the workflows package, enabling
5
+ specific error types to be raised and caught during workflow validation and execution.
6
+ Each exception provides detailed error messages to help diagnose and fix issues in
7
+ workflow definitions or execution.
8
+
9
+ Exception hierarchy:
10
+ - WorkflowError (base class)
11
+ - UnknownVariableError (missing variable reference)
12
+ - CyclicDependencyError (circular dependencies)
13
+ - FunctionNotFoundError (missing function reference)
14
+ """
15
+
16
+
17
# Custom exceptions for workflow validation and execution errors.
class WorkflowError(Exception):
    """
    Root of the workflows exception hierarchy.

    Catch this type to handle any error raised by the workflows package,
    whether from validation (unknown variables, cycles) or from execution
    (missing functions).
    """
27
+
28
+
29
class UnknownVariableError(WorkflowError):
    """
    Error for input fields that reference a variable nobody provides.

    Raised when a step's input variable is neither an external workflow
    input nor the output of any previous step.
    """

    def __init__(self, var: str):
        message = f"Unknown variable referenced: {var}"
        super().__init__(message)
39
+
40
+
41
class CyclicDependencyError(WorkflowError):
    """
    Error for workflows whose step graph contains a cycle.

    When steps depend on each other circularly (A needs B's output while B
    needs A's), no valid execution order exists, so the workflow is rejected.
    """

    def __init__(self):
        message = "Cyclic dependency detected in workflow"
        super().__init__(message)
52
+
53
+
54
class FunctionNotFoundError(WorkflowError):
    """
    Error for a step that names a transformation function that doesn't exist.

    Raised during execution when an input/output field's `func` is not found
    in the available function registry or namespace.
    """

    def __init__(self, func_name: str):
        message = f"Function not found: {func_name}"
        super().__init__(message)
src/workflows/executors.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import json
3
+ from typing import Any
4
+
5
+ import pydantic
6
+
7
+ from llms import completion
8
+ from workflows.errors import WorkflowError
9
+ from workflows.structs import InputField, ModelStep, OutputField, Workflow
10
+ from workflows.utils import create_dependency_graph, topological_sort
11
+
12
+ """
13
+ Core workflow execution functionality.
14
+
15
+ This module handles the execution of defined workflows, including input processing,
16
+ dependency-based execution order, model calling, and output collection. It integrates
17
+ with the litellm library to handle model interactions.
18
+
19
+ Key components:
20
+ - Utility functions for input/output transformation
21
+ - Input processing and validation
22
+ - Model step execution
23
+ - Complete workflow execution with dependency resolution
24
+
25
+ The module orchestrates the execution of steps in the correct order based on their
26
+ dependencies and manages the flow of data between steps.
27
+ """
28
+
29
+
30
def upper(x):
    """Uppercase *x* when it is a string; return anything else unchanged."""
    return x.upper() if isinstance(x, str) else x
34
+
35
+
36
def lower(x):
    """Lowercase *x* when it is a string; return anything else unchanged."""
    return x.lower() if isinstance(x, str) else x
40
+
41
+
42
# Maps type-name strings (from OutputField.type) to Python types.
# Names not listed here fall back to eval() inside get_type().
TYPE_MAP = {
    "str": str,
    "int": int,
    "float": float,
    "bool": bool,
}

# Registry of named transformation functions usable as Input/OutputField.func.
# Names not listed here fall back to eval() inside create_processed_inputs().
FUNCTION_MAP = {
    "upper": upper,
    "lower": lower,
    "len": len,
    "split": str.split,
}
55
+
56
+
57
def get_type(type_str: str) -> type:
    """Resolve a type-name string (e.g. "str", "list[str]") to a Python type.

    Known scalar names resolve through TYPE_MAP; anything else falls back to
    eval() so that generic aliases such as "list[str]" still work.
    """
    # BUG FIX: the original `TYPE_MAP.get(type_str, eval(type_str))` evaluated
    # the eval() fallback eagerly even when the key was present in TYPE_MAP
    # (function-argument expressions are always evaluated before the call).
    if type_str in TYPE_MAP:
        return TYPE_MAP[type_str]
    # SECURITY NOTE: eval() executes arbitrary expressions — type strings must
    # come from trusted workflow definitions only.
    return eval(type_str)
59
+
60
+
61
def create_processed_inputs(model_step: ModelStep, available_vars: dict[str, Any]) -> dict[str, Any]:
    """
    Creates processed inputs for a model step.

    Extracts the value each input field references from ``available_vars``
    and applies the field's transformation function, if any.

    Args:
        model_step (ModelStep): The model step for which to create processed inputs.
        available_vars (dict[str, Any]): Dictionary of variables available for use as inputs.
            Keys are variable names, values are the variable values.

    Returns:
        dict[str, Any]: Mapping of input field name -> processed value.

    Raises:
        WorkflowError: If a required variable is not found in available_vars.

    Example:
        >>> available_vars = {"step1.output": "Hello World"}
        >>> create_processed_inputs(model_step, available_vars)
        {"input_field_name": "HELLO WORLD"}  # If upper transformation was specified
    """
    processed_inputs: dict[str, Any] = {}
    for input_field in model_step.input_fields:
        var = input_field.variable
        # BUG FIX: the docstring promised WorkflowError on a missing variable,
        # but the original raised a bare KeyError via available_vars[var].
        if var not in available_vars:
            raise WorkflowError(f"Missing required variable: {var}")
        value = available_vars[var]
        if input_field.func is not None:
            func = FUNCTION_MAP.get(input_field.func)
            if func is None:
                # SECURITY NOTE: eval() of an unregistered func name executes
                # arbitrary code; func strings must come from trusted
                # workflow definitions only.
                func = eval(input_field.func)
            value = func(value)
        processed_inputs[input_field.name] = value
    return processed_inputs
97
+
98
+
99
# %%
def execute_model_step(
    model_step: ModelStep, available_vars: dict[str, Any], return_full_content: bool = False
) -> dict[str, Any] | tuple[dict[str, Any], str]:
    """
    Executes a model step using the provided available variables.

    This function handles the complete execution of a model step, including:
    1. Processing inputs using variable references and transformations
    2. Constructing the prompt for the model
    3. Calling the model via the llms.completion wrapper with structured output
    4. Mapping the parsed response onto the declared output fields

    Args:
        model_step (ModelStep): The model step to execute, containing model details,
            input/output specifications, and system prompt.
        available_vars (dict[str, Any]): A dictionary of all variables available to this step,
            including outputs from previous steps and external inputs.
        return_full_content (bool): If True, also return the raw model content
            alongside the parsed outputs.

    Returns:
        dict[str, Any]: Parsed outputs keyed by output field name, or
        tuple[dict[str, Any], str]: (outputs, raw content) when
        return_full_content is True.

    Raises:
        WorkflowError: If there's an error in input processing, model execution,
            or output validation.

    Example:
        >>> step = ModelStep(
        ...     id="summarize",
        ...     model="gpt-3.5-turbo",
        ...     provider="openai",
        ...     call_type="llm",
        ...     system_prompt="Summarize the text",
        ...     input_fields=[InputField(name="text", variable="input_text", description="Text to summarize")],
        ...     output_fields=[OutputField(name="summary", type="str", description="Summary of the text")]
        ... )
        >>> execute_model_step(step, {"input_text": "Long text to be summarized..."})
        {"summary": "A concise summary of the text."}
    """
    # Ensure inputs are processed using the specified functions in input_fields.
    processed_inputs = create_processed_inputs(model_step, available_vars)

    # Construct the input prompt for the model.
    input_str = ", ".join(f"{k}={v}" for k, v in processed_inputs.items())
    step_result = f"{model_step.system_prompt} | Inputs: {input_str}"

    # Build a dynamic pydantic model describing the expected structured output.
    fields = {
        field.name: (get_type(field.type), pydantic.Field(..., description=field.description))
        for field in model_step.output_fields
    }
    ModelResponse = pydantic.create_model("ModelResponse", **fields)

    # Execute the model step via the llms.completion wrapper (litellm-backed).
    api_response = completion(
        model=f"{model_step.provider}/{model_step.model}",
        system=model_step.system_prompt,
        prompt=step_result,
        response_format=ModelResponse,
    )

    # "output" holds the parsed structured response; "content" the raw text.
    model_response = api_response["output"]
    # Map the parsed response to the output fields.
    outputs = {field.name: model_response[field.name] for field in model_step.output_fields}
    if return_full_content:
        return outputs, api_response["content"]
    return outputs
179
+
180
+
181
# Example usage (performs a live LLM call when this module is run directly).
# NOTE(review): a second `if __name__ == "__main__"` guard exists further down
# in this module; both blocks execute on direct run.
if __name__ == "__main__":
    # Define a simple model step
    model_step = ModelStep(
        id="step1",
        model="gpt-4o-mini",
        provider="OpenAI",
        call_type="llm",
        system_prompt="You are a simple NLP tool that takes a string, and a number N, and return the first N entities in the string, and the total count of entities in the string.",
        input_fields=[
            InputField(name="sentence", description="The sentence to process", variable="sentence", func=None),
            InputField(name="n", description="The number of entities to return", variable="n", func=None),
        ],
        output_fields=[
            OutputField(
                name="entities",
                description="The first N entities in the string as a list of strings",
                type="list[str]",
                func=None,
            ),
            OutputField(name="count", description="The total count of entities in the string", type="int", func=None),
        ],
    )

    # Define processed inputs
    processed_inputs = {"sentence": "Abdul Akbar is a good person, but Jesus is the son of God.", "n": 3}

    # Execute the model step
    outputs = execute_model_step(model_step, processed_inputs)
    print(outputs)
211
+
212
+
213
# %%
def execute_workflow(
    workflow: Workflow, input_values: dict[str, Any], return_full_content: bool = False
) -> dict[str, Any] | tuple[dict[str, Any], str]:
    """
    Execute the given workflow as a computational graph.

    This function orchestrates the complete execution of a workflow by:
    1. Validating and populating initial values using the provided external inputs
    2. Building a dependency graph between workflow steps
    3. Determining a valid execution order using topological sorting
    4. Executing each step in the correct order, with inputs from previous steps
    5. Collecting and returning the final outputs

    Args:
        workflow (Workflow): The workflow to execute, containing steps, their
            dependencies, and input/output specifications.
        input_values (dict[str, Any]): External input values to be used by the workflow.
            Keys should match the required workflow.inputs.
        return_full_content (bool): If True, also return the raw model content
            produced by the steps (newline-joined, in execution order).

    Returns:
        dict[str, Any]: A dictionary of the workflow's outputs, with keys matching
            the variables defined in workflow.outputs; or
        tuple[dict[str, Any], str]: (outputs, joined raw step contents) when
            return_full_content is True.

    Raises:
        UnknownVariableError: If an input_field references a variable that is not
            provided externally nor produced by any step.
        CyclicDependencyError: If the workflow contains a circular dependency that
            prevents a valid execution order.
        FunctionNotFoundError: If a transformation function specified in input_fields.func
            or output_fields.func is not available.
        WorkflowError: For any other workflow-related errors, such as missing required inputs.
    """
    # Step 1: Pre-populate computed values with external workflow inputs.
    computed_values: dict[str, Any] = {}
    for var in workflow.inputs:
        if var not in input_values:
            raise WorkflowError(f"Missing required workflow input: {var}")
        computed_values[var] = input_values[var]

    # Step 2: Build dependency graph among model steps. Any input not provided
    # externally must be produced by some step; otherwise this raises.
    dependencies = create_dependency_graph(workflow, input_values)

    # Step 3: Determine execution order (raises CyclicDependencyError on cycles).
    execution_order = topological_sort(dependencies)

    # Step 4: Execute steps in topological order.
    # BUG FIX: return_full_content was previously accepted but never used, so
    # callers unpacking (outputs, content) silently mis-unpacked the dict.
    step_contents: list[str] = []
    for step_id in execution_order:
        step = workflow.steps[step_id]
        if return_full_content:
            outputs, content = execute_model_step(step, computed_values, return_full_content=True)
            step_contents.append(content)
        else:
            outputs = execute_model_step(step, computed_values)
        # Namespace each output under its producing step id.
        computed_values.update({f"{step_id}.{k}": v for k, v in outputs.items()})

    # Step 5: Gather and return workflow outputs.
    final_outputs: dict[str, Any] = {}
    for target, var in workflow.outputs.items():
        if var not in computed_values:
            raise WorkflowError(f"Workflow output variable {var} was not produced")
        final_outputs[target] = computed_values[var]

    if return_full_content:
        return final_outputs, "\n".join(step_contents)
    return final_outputs
299
+
300
+
301
def run_examples():
    """
    Runs three example workflows demonstrating:
    1. A successful (linear) workflow execution.
    2. A cyclic dependency error.
    3. An unknown variable dependency error.

    NOTE: Example 1 performs live LLM calls when run.
    """
    print("Example 1: Successful Workflow Execution")
    # Example 1: Simple linear workflow.
    # External input "input.value" is provided. Two steps:
    # - step1 takes "input.value" and produces "step1.result".
    # - step2 uses "step1.result" and produces "step2.final".
    from workflows.structs import ModelStep, Workflow

    workflow_success = Workflow(
        steps={
            "step1": ModelStep(
                id="step1",
                model="gpt-4o-mini",
                provider="OpenAI",
                call_type="llm",
                system_prompt="Step1 processing",
                input_fields=[InputField(name="value", description="Input value", variable="input.value")],
                output_fields=[OutputField(name="result", description="Processed result", type="str", func="upper")],
            ),
            "step2": ModelStep(
                id="step2",
                model="gpt-4o-mini",
                provider="OpenAI",
                call_type="llm",
                system_prompt="Step2 processing",
                input_fields=[InputField(name="result", description="Result from step1", variable="step1.result")],
                output_fields=[OutputField(name="final", description="Final output", type="str", func="lower")],
            ),
        },
        inputs=["input.value"],
        outputs={"final": "step2.final"},
    )
    input_values_success = {"input.value": "Hello, World!"}
    try:
        outputs = execute_workflow(workflow_success, input_values_success)
        print("Workflow outputs:", outputs)
    except WorkflowError as e:
        print("Workflow failed with error:", e)

    print("\nExample 2: Cyclic Dependency Workflow")
    # Example 2: Cyclic dependency.
    # stepA depends on an output from stepB and vice versa — expected to fail.
    workflow_cycle = Workflow(
        steps={
            "stepA": ModelStep(
                id="stepA",
                model="gpt-4o-mini",
                provider="OpenAI",
                call_type="llm",
                system_prompt="StepA processing",
                input_fields=[
                    InputField(name="input", description="Input from stepB", variable="stepB.output", func="identity")
                ],
                output_fields=[OutputField(name="output", description="Output from A", type="str", func="upper")],
            ),
            "stepB": ModelStep(
                id="stepB",
                model="gpt-4o-mini",
                provider="OpenAI",
                call_type="llm",
                system_prompt="StepB processing",
                input_fields=[
                    InputField(name="input", description="Input from stepA", variable="stepA.output", func="identity")
                ],
                output_fields=[OutputField(name="output", description="Output from B", type="str", func="upper")],
            ),
        },
        inputs=[],  # no external inputs
        outputs={"output": "stepB.output"},
    )
    try:
        outputs = execute_workflow(workflow_cycle, {})
        print("Workflow outputs:", outputs)
    except WorkflowError as e:
        print("Workflow failed with error:", e)

    print("\nExample 3: Unknown Variable Dependency Workflow")
    # Example 3: A workflow that references a variable not provided as an input
    # or produced by any step — expected to fail.
    workflow_unknown = Workflow(
        steps={
            "stepX": ModelStep(
                id="stepX",
                model="gpt-4o-mini",
                provider="OpenAI",
                call_type="llm",
                system_prompt="StepX processing",
                input_fields=[
                    InputField(
                        name="input", description="Non-existent input", variable="nonexistent.value", func="identity"
                    )
                ],
                output_fields=[OutputField(name="output", description="Output from X", type="str", func="upper")],
            )
        },
        inputs=[],  # no external inputs
        outputs={"output": "stepX.output"},
    )
    try:
        outputs = execute_workflow(workflow_unknown, {})
        print("Workflow outputs:", outputs)
    except WorkflowError as e:
        print("Workflow failed with error:", e)
409
+
410
+
411
# NOTE(review): second __main__ guard in this module (see the example block
# above) — both run on direct execution. The ModelStep below duplicates the
# earlier example's step; consider consolidating.
if __name__ == "__main__":
    # create example of model_step
    model_step = ModelStep(
        id="step1",
        model="gpt-4o-mini",
        provider="OpenAI",
        call_type="llm",
        system_prompt="You are a simple NLP tool that takes a string, and a number N, and return the first N entities in the string, and the total count of entities in the string.",
        input_fields=[
            InputField(name="sentence", description="The sentence to process", variable="sentence", func=None),
            InputField(name="n", description="The number of entities to return", variable="n", func=None),
        ],
        output_fields=[
            OutputField(
                name="entities",
                description="The first N entities in the string as a list of strings",
                type="list[str]",
                func=None,
            ),
            OutputField(name="count", description="The total count of entities in the string", type="int", func=None),
        ],
    )

    # Exercise input processing only (no model call here).
    processed_inputs = {"sentence": "Abdul Akbar is a good person, but Jesus is the son of God.", "n": 3}
    processed_inputs = create_processed_inputs(model_step, processed_inputs)
    print(processed_inputs)

    run_examples()

# %%
src/workflows/factory.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ from workflows.structs import Field, InputField, ModelStep, OutputField, Workflow
3
+
4
# Default system prompt for the initial single-step quizbowl (tossup) agent.
INITIAL_SYS_PROMPT = """You are a helpful performant question answering bot.
Given a question clue, output your most likely guess in a couple words with a calibrated confidence for the guess.
"""
7
+
8
+
9
def create_simple_workflow():
    """Placeholder factory — not yet implemented (returns None)."""
    pass
11
+
12
+
13
def create_first_step_input_fields() -> list[InputField]:
    """Return the default input fields for a workflow's first step.

    The single field wires the step's "question" input to the external
    "question_text" workflow variable.
    """
    question_field = InputField(
        name="question",
        description="The question text progressively revealed to the agent so far.",
        variable="question_text",
    )
    return [question_field]
21
+
22
+
23
def create_empty_input_field() -> list[InputField]:
    """Return a single blank input field bound to the "question_text" variable."""
    blank_field = InputField(name="", description="", variable="question_text")
    return [blank_field]
25
+
26
+
27
def create_quizbowl_simple_step_initial_setup():
    """Return the default, unconfigured single LLM step for the tossup setup.

    Model and provider are intentionally left blank; the caller/UI fills them in.
    """
    return ModelStep(
        id="simple_step",
        name="Quizbowl Simple Step",
        model="",
        provider="",
        temperature=0.7,
        call_type="llm",
        system_prompt=INITIAL_SYS_PROMPT,
        input_fields=[
            InputField(name="question", description="The question to answer", variable="question"),
        ],
        output_fields=[
            OutputField(name="answer", description="The most likely answer", type="str"),
            OutputField(name="confidence", description="The confidence of the answer", type="float"),
        ],
    )
44
+
45
+
46
def create_new_llm_step(step_id: str, name: str) -> ModelStep:
    """Create a blank LLM step (gpt-4o / OpenAI defaults) for the workflow editor.

    Args:
        step_id: Unique identifier for the step within its workflow.
        name: Human-readable step name.
    """
    return ModelStep(
        id=step_id,
        name=name,
        model="gpt-4o",
        provider="OpenAI",
        call_type="llm",
        temperature=0.7,
        system_prompt="",
        input_fields=create_empty_input_field(),
        output_fields=[OutputField(name="", description="")],
    )
58
+
59
+
60
def create_first_llm_step() -> ModelStep:
    """Create the first default LLM step ("A") of a new workflow.

    Like create_new_llm_step, but pre-wired to the "question_text" input via
    create_first_step_input_fields.
    """
    return ModelStep(
        id="A",
        name="",
        model="gpt-4o",
        provider="OpenAI",
        call_type="llm",
        temperature=0.7,
        system_prompt="",
        # BUG FIX: create_first_step_input_fields() already returns a
        # list[InputField]; wrapping it in another list produced the nested
        # [[InputField]] and broke the input_fields contract.
        input_fields=create_first_step_input_fields(),
        output_fields=[OutputField(name="", description="")],
    )
72
+
73
+
74
def create_quizbowl_simple_workflow():
    """Build the default single-step tossup workflow (question -> answer/confidence)."""
    return Workflow(
        inputs=["question_text"],
        outputs={"answer": "A.answer", "confidence": "A.confidence"},
        steps={
            "A": ModelStep(
                id="A",
                name="Tossup Agent",
                model="gpt-4o-mini",
                provider="OpenAI",
                call_type="llm",
                temperature=0.3,
                system_prompt="You are a helpful assistant that can answer questions.",
                input_fields=[InputField(name="question", description="The question text", variable="question_text")],
                output_fields=[
                    OutputField(
                        name="answer",
                        description="The best guess at the answer to the question",
                        type="str",
                    ),
                    OutputField(
                        name="confidence",
                        description="The confidence in the answer, ranging from 0 to 1 in increments of 0.05.",
                        type="float",
                    ),
                ],
            )
        },
    )
103
+
104
+
105
# System prompt for the single-step bonus-question agent; instructs the model
# to emit a plain-text ANSWER / CONFIDENCE / EXPLANATION block.
BONUS_SYS_PROMPT = """You are a quizbowl player answering bonus questions. For each part:
1. Read the leadin and part carefully
2. Provide a concise answer
3. Rate your confidence (0-1)
4. Explain your reasoning

Format your response as:
ANSWER: <your answer>
CONFIDENCE: <0-1>
EXPLANATION: <your reasoning>"""
115
+
116
+
117
def create_quizbowl_bonus_simple_workflow() -> Workflow:
    """Create a simple single-step workflow for bonus questions.

    Takes the bonus leadin and one part as inputs; produces answer,
    confidence, and a short explanation.
    """
    return Workflow(
        inputs=["leadin", "part"],
        outputs={"answer": "A.answer", "confidence": "A.confidence", "explanation": "A.explanation"},
        steps={
            "A": ModelStep(
                id="A",
                name="Bonus Agent",
                model="gpt-4o-mini",
                provider="OpenAI",
                temperature=0.3,
                call_type="llm",
                system_prompt=BONUS_SYS_PROMPT,
                input_fields=[
                    InputField(
                        name="question_leadin",
                        description="The leadin text for the bonus question",
                        variable="leadin",
                    ),
                    InputField(
                        name="question_part",
                        description="The specific part text to answer",
                        variable="part",
                    ),
                ],
                output_fields=[
                    OutputField(name="answer", description="The predicted answer", type="str"),
                    OutputField(name="confidence", description="Confidence in the answer (0-1)", type="float"),
                    OutputField(name="explanation", description="Short explanation for the answer", type="str"),
                ],
            )
        },
    )
src/workflows/qb/__init__.py ADDED
File without changes
src/workflows/qb/simple_agent.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import Any, Iterable
3
+
4
+ # from litellm import completion
5
+ from llms import completion
6
+ from workflows.executors import execute_model_step, execute_workflow
7
+ from workflows.structs import ModelStep, Workflow
8
+
9
+
10
# NOTE(review): this function appears to be an orphaned copy of a class method
# (e.g. QuizbowlAgent._get_agent_response) — it is defined at module level yet
# takes `self` and reads `self.model` / `self.temperature`, and nothing visible
# in this module calls it. Confirm whether it can be deleted.
def _get_agent_response(self, prompt: str, system_prompt: str) -> tuple:
    """Call the chat-completion endpoint and time the request.

    Args:
        self: Any object providing `model` and `temperature` attributes.
        prompt: The user message content.
        system_prompt: The system message content.

    Returns:
        A 2-tuple of (raw completion response, elapsed seconds).
    """
    messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]

    start_time = time.time()
    response = completion(
        model=self.model,
        messages=messages,
        temperature=self.temperature,
        max_tokens=150,  # Limit token usage for faster responses
    )
    response_time = time.time() - start_time

    return response, response_time
24
+
25
+
26
def _get_model_step_response(
    model_step: ModelStep, available_vars: dict[str, Any]
) -> tuple[dict[str, Any], str, float]:
    """Execute a single model step and time the call.

    Returns:
        A 3-tuple of (parsed step outputs, full raw content string from the
        executor, elapsed wall-clock seconds).
    """
    tic = time.time()
    outputs, raw_content = execute_model_step(model_step, available_vars, return_full_content=True)
    elapsed = time.time() - tic
    return outputs, raw_content, elapsed
34
+
35
+
36
def _get_workflow_response(workflow: Workflow, available_vars: dict[str, Any]) -> tuple[dict[str, Any], str, float]:
    """Execute a full workflow and time the call.

    Returns:
        A 3-tuple of (workflow outputs, full raw content string from the
        executor, elapsed wall-clock seconds).
    """
    tic = time.time()
    outputs, raw_content = execute_workflow(workflow, available_vars, return_full_content=True)
    elapsed = time.time() - tic
    return outputs, raw_content, elapsed
42
+
43
+
44
class SimpleTossupAgent:
    """Answers progressively revealed tossup questions with a single-step
    workflow, buzzing once model confidence reaches a threshold."""

    # Workflow input variable the agent fills in itself at run time.
    external_input_variable = "question_text"
    # Output variables the wrapped workflow is *required* to produce.
    output_variables = ["answer", "confidence"]

    def __init__(self, workflow: "Workflow", buzz_threshold: float):
        """Validate and wrap a single-step tossup workflow.

        Args:
            workflow: A Workflow containing exactly one model step.
            buzz_threshold: Confidence in [0, 1] at which the agent buzzes.

        Raises:
            ValueError: If the workflow is missing the required external input
                or any of the required output variables.
        """
        steps = list(workflow.steps.values())
        assert len(steps) == 1, "Only one step is allowed in a simple workflow"
        self.model_step = steps[0]
        self.buzz_threshold = buzz_threshold

        if self.external_input_variable not in workflow.inputs:
            raise ValueError(f"External input variable {self.external_input_variable} not found in model step inputs")

        # BUGFIX: validate the class-level *required* outputs. The previous
        # code first overwrote `self.output_variables` with
        # `workflow.outputs.keys()` and then checked those same keys against
        # `workflow.outputs`, so the check could never fail.
        for out_var in SimpleTossupAgent.output_variables:
            if out_var not in workflow.outputs:
                raise ValueError(f"Output variable {out_var} not found in the workflow outputs")

        self.output_variables = list(workflow.outputs.keys())

    def run(self, question_runs: list[str], early_stop: bool = True) -> Iterable[dict]:
        """
        Process a tossup question and decide when to buzz based on confidence.

        Args:
            question_runs: Progressive reveals of the question text
            early_stop: Whether to stop after the first buzz

        Yields:
            Dict with answer, confidence, and whether to buzz
        """
        for i, question_text in enumerate(question_runs):
            response, content, response_time = _get_model_step_response(
                self.model_step, {self.external_input_variable: question_text}
            )
            buzz = response["confidence"] >= self.buzz_threshold
            result = {
                "answer": response["answer"],
                "confidence": response["confidence"],
                "buzz": buzz,
                "question_fragment": question_text,
                "position": i + 1,  # 1-based index of the reveal we answered on
                "full_response": content,
                "response_time": response_time,
            }

            yield result

            # If we've reached the confidence threshold, buzz and stop.
            if early_stop and buzz:
                return
94
+
95
+
96
class SimpleBonusAgent:
    """Answers a single bonus part (given its leadin) using a single-step
    workflow."""

    # Workflow input variables the agent fills in itself at run time.
    external_input_variables = ["leadin", "part"]
    # Output variables the wrapped workflow is *required* to produce.
    output_variables = ["answer", "confidence", "explanation"]

    def __init__(self, workflow: "Workflow"):
        """Validate and wrap a single-step bonus workflow.

        Args:
            workflow: A Workflow containing exactly one model step.

        Raises:
            ValueError: If the workflow is missing any required external input
                or any of the required output variables.
        """
        steps = list(workflow.steps.values())
        assert len(steps) == 1, "Only one step is allowed in a simple workflow"
        self.model_step = steps[0]

        # Validate input variables
        for input_var in self.external_input_variables:
            if input_var not in workflow.inputs:
                raise ValueError(f"External input variable {input_var} not found in model step inputs")

        # BUGFIX: validate the class-level *required* outputs before
        # overwriting the instance attribute. The previous code assigned
        # `self.output_variables = list(workflow.outputs.keys())` first and
        # then checked those same keys against `workflow.outputs`, so the
        # check could never fail.
        for out_var in SimpleBonusAgent.output_variables:
            if out_var not in workflow.outputs:
                raise ValueError(f"Output variable {out_var} not found in the workflow outputs")

        self.output_variables = list(workflow.outputs.keys())

    def run(self, leadin: str, part: str) -> dict:
        """
        Process a bonus part with the given leadin.

        Args:
            leadin: The leadin text for the bonus question
            part: The specific part text to answer

        Returns:
            Dict with answer, confidence, and explanation
        """
        response, content, response_time = _get_model_step_response(
            self.model_step,
            {
                "leadin": leadin,
                "part": part,
            },
        )

        return {
            "answer": response["answer"],
            "confidence": response["confidence"],
            "explanation": response["explanation"],
            "full_response": content,
            "response_time": response_time,
        }
142
+
143
+
144
# Example usage / manual smoke test. Requires network access, valid model
# credentials, and the HuggingFace dataset below; not exercised by tests.
if __name__ == "__main__":
    # Load the Quizbowl dataset
    from datasets import load_dataset

    from workflows.factory import create_quizbowl_bonus_step_initial_setup, create_quizbowl_simple_step_initial_setup

    ds_name = "umdclip/leaderboard_co_set"
    ds = load_dataset(ds_name, split="train")

    # Create the agents
    # NOTE(review): the factory names suggest these return a *step*, while the
    # agents validate a *Workflow* (steps/inputs/outputs) — confirm the
    # factories return Workflow objects.
    tossup_step = create_quizbowl_simple_step_initial_setup()
    tossup_step.model = "gpt-4"
    tossup_step.provider = "openai"
    tossup_agent = SimpleTossupAgent(workflow=tossup_step, buzz_threshold=0.9)

    bonus_step = create_quizbowl_bonus_step_initial_setup()
    bonus_step.model = "gpt-4"
    bonus_step.provider = "openai"
    bonus_agent = SimpleBonusAgent(workflow=bonus_step)

    # Example for tossup mode
    print("\n=== TOSSUP MODE EXAMPLE ===")
    sample_question = ds[30]
    print(sample_question["question_runs"][-1])
    print(sample_question["gold_label"])
    print()
    question_runs = sample_question["question_runs"]

    results = tossup_agent.run(question_runs, early_stop=True)
    for result in results:
        print(result["full_response"])
        print(f"Guess at position {result['position']}: {result['answer']}")
        print(f"Confidence: {result['confidence']}")
        if result["buzz"]:
            print("Buzzed!\n")

    # Example for bonus mode
    print("\n=== BONUS MODE EXAMPLE ===")
    sample_bonus = ds[31]  # Assuming this is a bonus question
    leadin = sample_bonus["leadin"]
    parts = sample_bonus["parts"]

    print(f"Leadin: {leadin}")
    for i, part in enumerate(parts):
        print(f"\nPart {i + 1}: {part['part']}")
        result = bonus_agent.run(leadin, part["part"])
        print(f"Answer: {result['answer']}")
        print(f"Confidence: {result['confidence']}")
        print(f"Explanation: {result['explanation']}")
        print(f"Response time: {result['response_time']:.2f}s")
src/workflows/quizbowl_agent.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import json
3
+ import os
4
+ import time
5
+ from typing import Dict, Iterable, List, Optional, Tuple, Union
6
+
7
+ import litellm
8
+ from datasets import load_dataset
9
+ from litellm import completion
10
+
11
+ litellm.drop_params = True
12
+
13
# SECURITY FIX: an OpenAI API key was previously hard-coded here and committed
# to version control; that key must be treated as compromised and revoked.
# Credentials are now expected to arrive via the OPENAI_API_KEY environment
# variable (set it in your shell or a git-ignored .env file). We only warn —
# rather than raise — so the module stays importable without credentials.
if not os.environ.get("OPENAI_API_KEY"):
    import warnings

    warnings.warn(
        "OPENAI_API_KEY is not set; OpenAI API calls will fail until it is provided.",
        stacklevel=2,
    )
17
+
18
# Default system prompt for tossup play: instructs the model to answer the
# progressively revealed question and reply with a strict JSON object holding
# "answer" and "confidence" (parsed by QuizbowlAgent._extract_confidence_and_answer).
DEFAULT_SYS_PROMPT = """
You are a Quizbowl expert. You will be given a question that's progressively revealed.
Your goal is to identify the answer as quickly as possible with high confidence.
Respond with a JSON object with two fields:
1. "answer": Your best guess for the answer
2. "confidence": Your confidence in your answer from 0.0 to 1.0

DO NOT include any explanation. ONLY return the JSON object.
"""
27
+
28
+
29
class QuizbowlAgent:
    """
    An agent for playing Quizbowl with two modes:
    1. Tossup mode: Fast and direct with confidence calibration for buzzing
    2. Bonus round mode: Provides guess, rationale, and confidence
    """

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        buzz_threshold: float = 0.85,
        temperature: float = 0.2,
        system_prompt: Optional[str] = None,
    ):
        """
        Initialize the QuizbowlAgent.

        Args:
            model: The LLM model to use for answering
            buzz_threshold: Confidence threshold for buzzing in tossup mode (0-1)
            temperature: Temperature for model sampling
            system_prompt: System prompt for tossup play; defaults to
                DEFAULT_SYS_PROMPT when omitted (backward-compatible default).
        """
        self.model = model
        self.buzz_threshold = buzz_threshold
        self.temperature = temperature
        # Resolve the default lazily so a custom prompt fully replaces it.
        self.system_prompt = DEFAULT_SYS_PROMPT if system_prompt is None else system_prompt

    def _process_question_runs(self, question_runs: List[str]) -> List[str]:
        """Process question runs to extract increasing amounts of text."""
        # For simpler testing, just return the runs as they are in the dataset
        return question_runs

    def _get_agent_response(self, prompt: str, system_prompt: str) -> Tuple:
        """Call the chat-completion endpoint and time the request.

        Returns:
            A 2-tuple of (raw completion response, elapsed seconds).
        """
        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]

        start_time = time.time()
        response = completion(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            max_tokens=150,  # Limit token usage for faster responses
        )
        response_time = time.time() - start_time

        return response, response_time

    def _extract_confidence_and_answer(self, content: str) -> Tuple[str, float]:
        """Extract the answer and confidence score from the model response.

        Tries strict JSON first; falls back to line-based scraping with a
        default confidence of 0.5 when the content is not valid JSON.
        """
        try:
            # Try to parse JSON from the response
            data = json.loads(content)
            answer = data.get("answer", "")
            confidence = float(data.get("confidence", 0.0))
            return answer, confidence
        except (json.JSONDecodeError, ValueError):
            # Fallback if parsing fails: first line is the answer.
            lines = content.strip().split("\n")
            answer = lines[0] if lines else ""
            confidence = 0.5  # Default confidence

            # Try to extract confidence from text
            for line in lines:
                if "confidence:" in line.lower():
                    try:
                        confidence = float(line.lower().split("confidence:")[1].strip())
                    except (ValueError, IndexError):
                        pass

            return answer, confidence

    def tossup_mode(self, question_runs: List[str]) -> Iterable[Dict]:
        """
        Process a tossup question and decide when to buzz based on confidence.

        Args:
            question_runs: Progressive reveals of the question text

        Yields:
            Dict with answer, confidence, and whether to buzz
        """
        for i, question_text in enumerate(question_runs):
            prompt = f"Question: {question_text}\n\nProvide your answer and confidence level:"

            # CONSISTENCY FIX: use the configured prompt (as tossup_mode_top5
            # does) instead of always using the module-level default.
            response, response_time = self._get_agent_response(prompt, self.system_prompt)
            content = response.choices[0].message.content

            answer, confidence = self._extract_confidence_and_answer(content)

            result = {
                "answer": answer,
                "confidence": confidence,
                "buzz": confidence >= self.buzz_threshold,
                "question_fragment": question_text,
                "position": i + 1,
                "full_response": content,
                "response_time": response_time,
            }

            yield result

            # If we've reached the confidence threshold, buzz and stop
            if confidence >= self.buzz_threshold:
                return

    def tossup_mode_top5(self, question_runs: List[str]) -> Iterable[Dict]:
        """
        Process a tossup question and provide the top 5 guesses with confidence levels.

        Args:
            question_runs: Progressive reveals of the question text

        Yields:
            Dict with top 5 answers, their confidences, and whether to buzz
        """
        for i, question_text in enumerate(question_runs):
            prompt = f"Question: {question_text}\n\nProvide your top 5 answers and confidence levels."

            response, response_time = self._get_agent_response(prompt, self.system_prompt)
            content = response.choices[0].message.content

            try:
                # Try to parse JSON from the response
                data = json.loads(content)
                guesses = data.get("guesses", [])
            except (json.JSONDecodeError, ValueError):
                # Fallback if parsing fails
                guesses = []

            result = {
                "guesses": guesses,
                "buzz": any(guess["confidence"] >= self.buzz_threshold for guess in guesses),
                "question_fragment": question_text,
                "position": i + 1,
                "full_response": content,
                "response_time": response_time,
            }

            yield result

            # If any guess reaches the confidence threshold, buzz and stop
            if result["buzz"]:
                return

    def bonus_round_mode(self, question: str) -> Dict:
        """
        Process a bonus round question with detailed analysis.

        Args:
            question: The bonus question text

        Returns:
            Dict with answer, rationale, and confidence
        """
        system_prompt = """
        You are a Quizbowl expert answering a bonus question. Provide:
        1. Your direct answer
        2. A very brief and crisp one line rationale for your answer (key clues that led to it)
        3. Your confidence level (0.0-1.0)

        Respond with a JSON object with these three fields:
        {
            "answer": "Your answer here",
            "rationale": "Your reasoning here",
            "confidence": 0.XX
        }
        """

        prompt = f"Bonus Question: {question}\n\nProvide your answer, rationale, and confidence:"

        # BUGFIX: _get_agent_response returns (response, elapsed); the previous
        # code assigned the tuple directly and then crashed on
        # `response.choices`. Unpack it properly.
        response, response_time = self._get_agent_response(prompt, system_prompt)
        content = response.choices[0].message.content

        try:
            # Try to parse JSON
            result = json.loads(content)
            # Ensure all fields are present
            if not all(k in result for k in ["answer", "rationale", "confidence"]):
                raise ValueError("Missing fields in response")
        except (json.JSONDecodeError, ValueError):
            # If parsing fails, extract manually from "key: value" lines.
            lines = content.strip().split("\n")
            result = {"answer": "", "rationale": "", "confidence": 0.5}

            for line in lines:
                if line.lower().startswith("answer:"):
                    result["answer"] = line[7:].strip()
                elif line.lower().startswith("rationale:"):
                    result["rationale"] = line[10:].strip()
                elif line.lower().startswith("confidence:"):
                    try:
                        result["confidence"] = float(line[11:].strip())
                    except ValueError:
                        pass

        return result
227
+
228
+
229
+ # %%
230
+ # Example usage
231
+ if __name__ == "__main__":
232
+ # Load the Quizbowl dataset
233
+ ds_name = "umdclip/leaderboard_co_set"
234
+ ds = load_dataset(ds_name, split="train")
235
+
236
+ # Create the agent
237
+ agent = QuizbowlAgent(model="gpt-4-turbo", buzz_threshold=0.85)
238
+
239
+ # Example for tossup mode
240
+ print("\n=== TOSSUP MODE EXAMPLE ===")
241
+ sample_question = ds[0]
242
+ print(sample_question["question_runs"][-1])
243
+ print(sample_question["gold_label"])
244
+ question_runs = sample_question["question_runs"]
245
+
246
+ results = agent.tossup_mode(question_runs)
247
+ for result in results:
248
+ print(f"Guess at position {result['position']}: {result['answer']}")
249
+ print(f"Confidence: {result['confidence']}")
250
+ if result["buzz"]:
251
+ print("Buzzed!\n")
252
+
253
+ results = agent.tossup_mode_top5(question_runs)
254
+ for result in results:
255
+ guesses = [f"{guess['answer']} ({guess['confidence']})" for guess in result["guesses"]]
256
+ print(f"Guesses at position {result['position']}: {', '.join(guesses)}")
257
+ if result["buzz"]:
258
+ print("Buzzed!")
259
+
260
+ # Example for bonus round mode
261
+ print("\n=== BONUS ROUND MODE EXAMPLE ===")
262
+ bonus_question = sample_question["question_runs"][-1]
263
+
264
+ bonus_result = agent.bonus_round_mode(bonus_question)
265
+ print(f"Answer: {bonus_result['answer']}")
266
+ print(f"Rationale: {bonus_result['rationale']}")
267
+ print(f"Confidence: {bonus_result['confidence']}")
268
+
269
+ # %%
src/workflows/structs.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ from typing import Any, Literal, Optional
3
+
4
+ from pydantic import BaseModel, Field, model_validator
5
+
6
+ """
7
+ Core data structures for defining workflows and their components.
8
+
9
+ This module defines the primary classes used to model workflows, steps, and their
10
+ input/output fields. These data structures serve as the foundation for workflow
11
+ definition, validation, and execution throughout the workflows package.
12
+
13
+ The primary components are:
14
+ - InputField: Represents an input to a model step with name and source variable
15
+ - OutputField: Represents an output from a model step with name and type
16
+ - ModelStep: Represents a single step in a workflow with inputs and outputs
17
+ - Workflow: A collection of interconnected steps with defined inputs and outputs
18
+
19
+ All classes use Pydantic's BaseModel for validation and serialization support.
20
+ """
21
# Discriminator used by ModelStep helpers to select the input vs output field list.
FieldType = Literal["input", "output"]

SUPPORTED_TYPES = Literal["str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"]
"""Supported field types for input and output fields"""
25
+
26
+
27
class InputField(BaseModel):
    """
    Defines an input field for a model step.

    An input field specifies what data a step requires, where it comes from,
    and optional pre-processing to apply before use.

    Attributes:
        name: The name of the input field within the step's context
        description: Human-readable description of the input's purpose
        variable: Reference to the source variable (format: "{step_id}.{field_name}" or external input name)
        func: Optional function name to transform the input value before use
    """

    name: str
    description: str
    # Resolved either from workflow-level external inputs or from another
    # step's output ("{step_id}.{field_name}").
    variable: str

    # function to call on the input before passing it to the model
    func: str | None = None
47
+
48
+
49
class OutputField(BaseModel):
    """
    Defines an output field produced by a model step.

    An output field specifies a value that the step will produce, including
    its data type and optional post-processing.

    Attributes:
        name: The name of the output field within the step's context
        description: Human-readable description of the output's purpose
        type: The data type of the output (one of SUPPORTED_TYPES)
        func: Optional function name to transform the raw output value
    """

    name: str
    # Data type of the produced value; defaults to plain string.
    type: SUPPORTED_TYPES = Field(default="str")
    description: str

    # function to call on the output string from the model
    func: str | None = None
69
+
70
+
71
class ModelStep(BaseModel):
    """
    Represents a single step in a workflow.

    A model step encapsulates the details of a specific operation within a workflow,
    including what model to use, what inputs it requires, and what outputs it produces.

    The update/add/delete helpers return *new* ModelStep instances and leave
    this instance untouched, so steps can be treated as immutable values.

    Attributes:
        id: Unique identifier for this step within a workflow
        model: The model to use for this step (e.g., "gpt-4")
        provider: The provider of the model (e.g., "openai")
        call_type: The type of operation (e.g., "llm", "search")
        system_prompt: Instructions for the model
        input_fields: List of input fields required by this step
        output_fields: List of output fields produced by this step
    """

    id: str
    name: str
    model: str
    provider: str
    call_type: str = "llm"  # llm, search, etc # TODO: make this enum or provide explicit options using Literal

    # TODO: Validate that this is not None for call_type = llm
    temperature: Optional[float] = None

    system_prompt: str
    input_fields: list[InputField]
    output_fields: list[OutputField]

    def fields(self, field_type: FieldType) -> list[InputField | OutputField]:
        """Return the input or output field list selected by `field_type`."""
        return self.input_fields if field_type == "input" else self.output_fields

    def get_full_model_name(self) -> str:
        """Return a human-readable "<provider> <model>" label."""
        return f"{self.provider} {self.model}"

    def get_produced_variables(self) -> list[str]:
        """Return "{step_id}.{field_name}" identifiers for all *named* outputs."""
        return [f"{self.id}.{field.name}" for field in self.output_fields if field.name]

    def update(self, update: dict[str, Any]) -> "ModelStep":
        """Return a copy of this step with the given attribute updates applied."""
        return self.model_copy(update=update)

    def update_property(self, field: str, value: Any) -> "ModelStep":
        "Update the `field` key of the model step with `value`."
        return self.update({field: value})

    def update_field(self, field_type: FieldType, index: int, key: str, value: str) -> "ModelStep":
        """Return a copy with one attribute of the field at `index` updated.

        BUGFIX: the previous implementation mutated this step's own field list
        in place before copying, silently modifying the "old" step as well.
        A fresh list is now built so `self` stays untouched. Out-of-range
        indices yield an unmodified copy.
        """
        if field_type == "input":
            attr = "input_fields"
        elif field_type == "output":
            attr = "output_fields"
        else:
            raise ValueError(f"Invalid field type: {field_type}")

        fields = list(getattr(self, attr))
        if index < len(fields):
            fields[index] = fields[index].model_copy(update={key: value})
        return self.model_copy(update={attr: fields})

    @staticmethod
    def create_new_field(field_type: FieldType, input_var: str | None = None) -> InputField | OutputField:
        """Create a blank field of the requested kind.

        NOTE(review): for inputs, `variable=input_var` may be None while
        InputField.variable is typed `str` — confirm callers always pass a
        variable name.
        """
        if field_type == "input":
            return InputField(name="", description="", variable=input_var)
        elif field_type == "output":
            return OutputField(name="", description="")
        else:
            raise ValueError(f"Invalid field type: {field_type}")

    def add_field(self, field_type: FieldType, index: int = -1, input_var: str | None = None) -> "ModelStep":
        """Add a new blank field and return the updated step.

        Args:
            field_type: Type of field to add ('input' or 'output').
            index: Position after which to insert the new field (-1 to append).
            input_var: Source variable for a new input field.

        Returns:
            A new ModelStep with the updated fields.

        BUGFIX: uses a deep copy — a shallow `model_copy()` shares the field
        list with the original, so inserting would have mutated both steps.
        """
        new_step = self.model_copy(deep=True)
        fields = new_step.input_fields if field_type == "input" else new_step.output_fields
        new_field = ModelStep.create_new_field(field_type, input_var)
        if index == -1:
            fields.append(new_field)
        else:
            fields.insert(index + 1, new_field)
        return new_step

    def delete_field(self, field_type: FieldType, index: int) -> "ModelStep":
        """
        Delete an input or output field and return the updated step.

        Args:
            field_type: Type of field to delete ('input' or 'output').
            index: Index of the field to delete. [-1 to delete the last field]

        Returns:
            A new ModelStep with the updated fields.

        BUGFIX: uses a deep copy — a shallow `model_copy()` shares the field
        list with the original, so popping would have mutated both steps.
        """
        new_step = self.model_copy(deep=True)
        fields = new_step.input_fields if field_type == "input" else new_step.output_fields
        fields.pop(index)
        return new_step
169
+
170
+
171
class Workflow(BaseModel):
    """
    Represents a complete workflow composed of interconnected steps.

    A workflow defines a directed acyclic graph of model steps, where outputs
    from earlier steps can be used as inputs to later steps.

    Attributes:
        inputs: List of external input variable names required by the workflow
        outputs: Mapping of workflow output name -> producing variable
            ("{step_id}.{field_name}"), or None while unassigned
        steps: Dictionary mapping step IDs to ModelStep instances

    Variable references use the format "{step_id}.{field_name}" to uniquely
    identify values within the workflow.
    """

    # variables of form {node}.{field}
    inputs: list[str] = Field(default_factory=list)

    # variables of form {node}.{field}
    outputs: dict[str, str | None] = Field(default_factory=dict)
    steps: dict[str, ModelStep] = Field(default_factory=dict)

    def model_dump(self, *args, **kwargs):
        # Serialize steps as a list (their dict-insertion order) instead of a
        # dict; `dictify_steps` below reverses this when loading.
        data = super().model_dump(*args, **kwargs)
        data["steps"] = list(data["steps"].values())
        return data

    @model_validator(mode="before")
    def dictify_steps(cls, data):
        # Accept the serialized list form of `steps` and re-key it by step id,
        # rejecting duplicates.
        # NOTE(review): assumes list items are dicts with an "id" key — confirm
        # ModelStep instances are never passed in list form.
        if "steps" in data and isinstance(data["steps"], list):
            steps_dict = {}
            for step in data["steps"]:
                if step["id"] in steps_dict:
                    raise ValueError(f"Duplicate step ID: {step['id']}")
                steps_dict[step["id"]] = step
            data["steps"] = steps_dict
        return data

    def get_step_variables(self, step_id: str) -> list[str]:
        """Get all variables from a specific step."""
        step = self.steps[step_id]
        variables = []
        for output in step.output_fields:
            # Unnamed outputs are placeholders and produce no variable.
            if output.name == "":
                continue
            output_var = f"{step.id}.{output.name}"
            variables.append(output_var)
        return variables

    def get_available_variables(self) -> list[str]:
        """Get all output variables from all steps."""
        # External inputs plus every named step output (order not guaranteed).
        variables = set(self.inputs)
        for step in self.steps.values():
            variables.update(self.get_step_variables(step.id))
        return list(variables)
227
+
228
+
229
+ # %%
src/workflows/utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ from typing import Any
3
+
4
+ from workflows.errors import CyclicDependencyError, UnknownVariableError, WorkflowError
5
+ from workflows.structs import ModelStep, Workflow
6
+
7
+ """
8
+ Utilities for workflow dependency management and execution order determination.
9
+
10
+ This module provides functions for analyzing workflows, determining dependencies between steps,
11
+ and calculating the correct execution order to ensure all dependencies are satisfied.
12
+ Key functionality includes:
13
+
14
+ - Variable to step mapping: Identifying which step produces each variable
15
+ - Dependency graph creation: Building a graph representing dependencies between steps
16
+ - Topological sorting: Determining a valid execution order based on dependencies
17
+ - Cycle detection: Identifying cyclic dependencies that would prevent execution
18
+
19
+ These utilities form the foundation for workflow validation and execution in the
20
+ workflow_executor module.
21
+ """
22
+
23
+
24
+ def _create_variable_step_mapping(workflow: Workflow) -> dict[str, str]:
25
+ """
26
+ Creates a mapping from produced variable names to the model step that produces them.
27
+
28
+ Args:
29
+ workflow (Workflow): The workflow containing steps and their input/output fields.
30
+
31
+ Returns:
32
+ dict[str, str]: A dictionary where keys are variable names (formatted as "{step_id}.{output name}")
33
+ and values are the step IDs that produce them.
34
+
35
+ Raises:
36
+ WorkflowError: If there are duplicate step IDs or if a variable is produced by multiple steps.
37
+
38
+ Example:
39
+ For a workflow with steps "extract" and "summarize" each producing outputs:
40
+ >>> _create_variable_step_mapping(workflow)
41
+ {'extract.keywords': 'extract', 'summarize.summary': 'summarize'}
42
+ """
43
+ variable_step_map: dict[str, str] = {} # variable name -> step id
44
+ for step_id, step in workflow.steps.items():
45
+ for output in step.output_fields:
46
+ var_name = f"{step_id}.{output.name}"
47
+ if var_name in variable_step_map:
48
+ raise WorkflowError(f"Variable '{output.name}' has duplicate entry in step {step_id}")
49
+ variable_step_map[var_name] = step_id
50
+ return variable_step_map
51
+
52
+
53
def create_dependency_graph(workflow: Workflow, input_values: dict[str, Any]) -> dict[str, set[str]]:
    """
    Build the step-dependency graph of a workflow.

    A step depends on another step when one of its input fields references a
    variable that the other step produces. Variables supplied externally via
    `input_values` create no dependency.

    Args:
        workflow (Workflow): The workflow containing steps and their input/output fields.
        input_values (dict[str, Any]): External input values provided to the workflow.

    Returns:
        dict[str, set[str]]: Maps each step ID to the set of step IDs it depends on.

    Raises:
        UnknownVariableError: If an input field references a variable that is
            neither provided externally nor produced by any step.

    Example:
        >>> create_dependency_graph(workflow, {})
        {'extract': set(), 'classify': {'extract'}}
    """
    producer_of = _create_variable_step_mapping(workflow)
    graph: dict[str, set[str]] = {sid: set() for sid in workflow.steps}
    for sid, step in workflow.steps.items():
        for input_field in step.input_fields:
            var = input_field.variable
            # Externally supplied variables introduce no edge.
            if var in input_values:
                continue
            producer = producer_of.get(var)
            if producer is None:
                raise UnknownVariableError(f"Variable '{var}' is not provided externally nor produced by any step")
            # Ignore self-references so a step never depends on itself.
            if producer != sid:
                graph[sid].add(producer)
    return graph
99
+
100
+
101
def topological_sort(dependencies: dict[str, set[str]]) -> list[str]:
    """
    Order nodes so every node appears after all of its dependencies.

    Implements Kahn's algorithm: repeatedly emit nodes with no remaining
    unprocessed dependencies, decrementing the in-degree of their dependents.
    If some nodes are never emitted, the graph contains a cycle.

    Args:
        dependencies (dict[str, set[str]]): Maps each node to the set of nodes
            it depends on.

    Returns:
        list[str]: The nodes in a valid execution order.

    Raises:
        CyclicDependencyError: If the dependency graph contains a cycle.

    Example:
        >>> topological_sort({'A': set(), 'B': {'A'}, 'C': {'B'}})
        ['A', 'B', 'C']
    """
    # In-degree = number of unprocessed dependencies per node.
    in_degree: dict[str, int] = {node: len(deps) for node, deps in dependencies.items()}

    # Reverse adjacency: node -> nodes that depend on it (insertion order kept).
    dependents: dict[str, list[str]] = {node: [] for node in dependencies}
    for node, deps in dependencies.items():
        for dep in deps:
            dependents[dep].append(node)

    # FIFO queue of nodes whose dependencies are all satisfied.
    ready = deque(node for node, degree in in_degree.items() if degree == 0)
    order: list[str] = []

    while ready:
        current = ready.popleft()
        order.append(current)
        for child in dependents[current]:
            in_degree[child] -= 1
            if in_degree[child] == 0:
                ready.append(child)

    # Any node left unprocessed sits on a cycle.
    if len(order) != len(dependencies):
        raise CyclicDependencyError()
    return order
src/workflows/validators.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import keyword
2
+ import re
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
6
+
7
+ from .structs import InputField, ModelStep, OutputField, Workflow
8
+
9
# Output-field types a step is allowed to declare: flat scalars and flat lists thereof.
SUPPORTED_TYPES = {"str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"}

# Constants for validation
MAX_FIELD_NAME_LENGTH = 50  # upper bound on input/output field name length
MAX_DESCRIPTION_LENGTH = 200  # upper bound on a field description's length
MAX_SYSTEM_PROMPT_LENGTH = 4000  # upper bound on an LLM step's system prompt length
MIN_TEMPERATURE = 0.0  # inclusive lower bound for LLM sampling temperature
MAX_TEMPERATURE = 1.0  # inclusive upper bound for LLM sampling temperature
17
+
18
+
19
class ValidationErrorType(Enum):
    """Types of validation errors that can occur"""

    STEP = "step"  # malformed or incomplete step definition
    DAG = "dag"  # dependency-graph problems (cycles, orphaned steps)
    VARIABLE = "variable"  # bad variable references between steps/inputs/outputs
    TYPE = "type"  # output field declared with an unsupported type
    GENERAL = "general"  # workflow-level structural problems (no steps, no inputs, ...)
    NAMING = "naming"  # identifier is not a valid Python name
    LENGTH = "length"  # a string exceeded its maximum allowed length
    RANGE = "range"  # a numeric value fell outside its allowed range
30
+
31
+
32
@dataclass
class ValidationError:
    """Represents a validation error with type and message"""

    error_type: ValidationErrorType  # category of the failure (see ValidationErrorType)
    message: str  # human-readable description of the problem
    step_id: Optional[str] = None  # offending step id, when the error is step-specific
    field_name: Optional[str] = None  # offending field name, when the error is field-specific
40
+
41
+
42
class WorkflowValidationError(Exception):
    """Base class for workflow validation errors.

    Carries the full list of :class:`ValidationError` instances so callers can
    report every problem, not just the first one.
    """

    def __init__(self, errors: List[ValidationError]):
        # Keep the individual errors accessible for structured reporting.
        self.errors = errors
        super().__init__(f"Workflow validation failed with {len(errors)} errors")
48
+
49
+
50
class WorkflowValidator:
    """Validates workflows for correctness and consistency.

    The validator checks a :class:`Workflow` for structural soundness:
    required fields on every step, valid Python identifiers for step/field
    names, length and range limits, well-formed variable references, and an
    acyclic step/variable dependency graph.

    Usage:
        validator = WorkflowValidator()
        ok = validator.validate(workflow)
        if not ok:
            report(validator.errors)  # list[ValidationError] for the last call

    All ``validate*`` methods are fail-fast: they stop at the first problem,
    append a :class:`ValidationError` describing it to ``self.errors``, and
    return ``False``.
    """

    def __init__(self):
        # Errors accumulated during the most recent validate() call.
        self.errors: List[ValidationError] = []
        # Workflow currently being validated; set by validate().
        self.workflow: Optional[Workflow] = None

    def validate(self, workflow: Workflow) -> bool:
        """Main validation entry point.

        Args:
            workflow: The workflow to validate.

        Returns:
            True if the workflow is valid; False otherwise (see ``self.errors``).
        """
        self.errors = []
        self.workflow = workflow

        # Basic workflow validation (inputs/outputs/steps present and consistent)
        if not self._validate_workflow_basic(workflow):
            return False

        # If it's a single-step workflow, use simple validation
        if len(workflow.steps) == 1:
            return self.validate_simple_workflow(workflow)

        # Otherwise use complex validation (dependency graph, cycles, orphans)
        return self.validate_complex_workflow(workflow)

    def validate_simple_workflow(self, workflow: Workflow) -> bool:
        """Validates a single-step workflow."""
        if not self.workflow:
            return False

        # Get the single step
        step = next(iter(workflow.steps.values()))

        # Validate the step itself
        if not self._validate_step(step):
            return False

        # Validate input variables: external inputs must be plain identifiers
        for input_var in workflow.inputs:
            if not self._is_valid_external_input(input_var):
                self.errors.append(
                    ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
                )
                return False

        # Validate output variables
        for output_name, output_var in workflow.outputs.items():
            if not output_var:
                self.errors.append(
                    ValidationError(ValidationErrorType.VARIABLE, f"Missing output variable for {output_name}")
                )
                return False

            # Check if output variable references a valid step output
            if not self._is_valid_variable_reference(output_var):
                self.errors.append(
                    ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
                )
                return False

            # Verify the output field exists in the step
            _, field_name = self._parse_variable_reference(output_var)
            if not any(field.name == field_name for field in step.output_fields):
                self.errors.append(
                    ValidationError(
                        ValidationErrorType.VARIABLE,
                        f"Output field '{field_name}' not found in step '{step.id}'",
                        step.id,
                        field_name,
                    )
                )
                return False

        return True

    def validate_complex_workflow(self, workflow: Workflow) -> bool:
        """Validates a multi-step workflow, including its dependency graph."""
        if not self.workflow:
            return False

        # Validate each step
        for step in workflow.steps.values():
            if not self._validate_step(step):
                return False

        # Validate input variables
        for input_var in workflow.inputs:
            if not self._is_valid_external_input(input_var):
                self.errors.append(
                    ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
                )
                return False

        # Validate output variables
        for output_name, output_var in workflow.outputs.items():
            if not output_var:
                self.errors.append(
                    ValidationError(ValidationErrorType.VARIABLE, f"Missing output variable for {output_name}")
                )
                return False

            if not self._is_valid_variable_reference(output_var):
                self.errors.append(
                    ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
                )
                return False

            # Verify the output field exists in the referenced step
            step_id, field_name = self._parse_variable_reference(output_var)
            if step_id not in workflow.steps:
                self.errors.append(
                    ValidationError(ValidationErrorType.VARIABLE, f"Referenced step '{step_id}' not found")
                )
                return False

            ref_step = workflow.steps[step_id]
            if not any(field.name == field_name for field in ref_step.output_fields):
                self.errors.append(
                    ValidationError(
                        ValidationErrorType.VARIABLE,
                        f"Output field '{field_name}' not found in step '{step_id}'",
                        step_id,
                        field_name,
                    )
                )
                return False

        # Build dependency graph: step_id -> set of step_ids it depends on
        dep_graph: Dict[str, Set[str]] = {}
        for step_id, step in workflow.steps.items():
            dep_graph[step_id] = self._get_step_dependencies(step)

        # Check for cycles in step dependencies (DFS with a path set)
        visited = set()
        path = set()

        def has_cycle(node: str) -> bool:
            if node in path:
                return True
            if node in visited:
                return False

            visited.add(node)
            path.add(node)

            for neighbor in dep_graph.get(node, set()):
                if has_cycle(neighbor):
                    return True

            path.remove(node)
            return False

        # Check each step for cycles
        for step_id in workflow.steps:
            if has_cycle(step_id):
                self.errors.append(
                    ValidationError(ValidationErrorType.DAG, f"Circular dependency detected involving step: {step_id}")
                )
                return False

        # Check for orphaned steps: steps that no other step depends on and
        # that produce no workflow output.
        used_steps = set()
        for deps in dep_graph.values():
            used_steps.update(deps)
        for step_id in workflow.steps:
            if step_id not in used_steps and not any(
                output_var and self._parse_variable_reference(output_var)[0] == step_id
                for output_var in workflow.outputs.values()
            ):
                self.errors.append(ValidationError(ValidationErrorType.DAG, f"Orphaned step detected: {step_id}"))
                return False

        # Validate variable dependencies
        if not self._validate_variable_dependencies(workflow):
            return False

        return True

    def _validate_workflow_basic(self, workflow: Workflow) -> bool:
        """Validates basic workflow properties common to all workflows."""
        # Check for at least one input
        if not workflow.inputs:
            self.errors.append(
                ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one input")
            )
            return False

        if not workflow.outputs:
            self.errors.append(
                ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one output")
            )
            return False

        for output_var in workflow.outputs.values():
            if output_var is None:
                self.errors.append(ValidationError(ValidationErrorType.GENERAL, "Output variable cannot be None"))
                return False

        # Check for empty workflow
        if not workflow.steps:
            self.errors.append(ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one step"))
            return False

        # Check for step ID consistency between the mapping key and the step itself
        for step_id, step in workflow.steps.items():
            if step_id != step.id:
                self.errors.append(
                    ValidationError(ValidationErrorType.STEP, f"Step ID mismatch: {step_id} != {step.id}", step_id)
                )
                return False
        return True

    def _validate_step(self, step: ModelStep) -> bool:
        """Validates a single step: required fields, naming, LLM settings, fields."""
        # Validate required fields
        if not step.id or not step.name or not step.model or not step.provider or not step.call_type:
            self.errors.append(ValidationError(ValidationErrorType.STEP, "Step missing required fields", step.id))
            return False

        # Validate step ID
        if not self._is_valid_identifier(step.id):
            self.errors.append(
                ValidationError(
                    ValidationErrorType.NAMING,
                    f"Invalid step ID format: {step.id}. Must be a valid Python identifier.",
                    step.id,
                )
            )
            return False

        # Validate temperature for LLM call type
        if step.call_type == "llm":
            if step.temperature is None:
                self.errors.append(
                    ValidationError(ValidationErrorType.STEP, "LLM step must specify temperature", step.id)
                )
                return False

            if not MIN_TEMPERATURE <= step.temperature <= MAX_TEMPERATURE:
                self.errors.append(
                    ValidationError(
                        ValidationErrorType.RANGE,
                        f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}",
                        step.id,
                    )
                )
                return False

        # Validate system prompt for LLM call type
        if step.call_type == "llm":
            if not step.system_prompt:
                self.errors.append(
                    ValidationError(ValidationErrorType.STEP, "LLM step must specify system prompt", step.id)
                )
                return False

            if len(step.system_prompt) > MAX_SYSTEM_PROMPT_LENGTH:
                self.errors.append(
                    ValidationError(
                        ValidationErrorType.LENGTH,
                        f"System prompt exceeds maximum length of {MAX_SYSTEM_PROMPT_LENGTH} characters",
                        step.id,
                    )
                )
                return False

        # Validate input fields (and reject duplicate names within the step)
        input_names = set()
        for field in step.input_fields:
            if not self._validate_input_field(field):
                return False
            if field.name in input_names:
                self.errors.append(
                    ValidationError(
                        ValidationErrorType.STEP, f"Duplicate input field name: {field.name}", step.id, field.name
                    )
                )
                return False
            input_names.add(field.name)

        # Validate output fields (and reject duplicate names within the step)
        output_names = set()
        for field in step.output_fields:
            if not self._validate_output_field(field):
                return False
            if field.name in output_names:
                self.errors.append(
                    ValidationError(
                        ValidationErrorType.STEP, f"Duplicate output field name: {field.name}", step.id, field.name
                    )
                )
                return False
            output_names.add(field.name)

        return True

    def _validate_input_field(self, field: InputField) -> bool:
        """Validates an input field: presence, naming, lengths, variable reference."""
        # Validate required fields
        if not field.name or not field.description or not field.variable:
            self.errors.append(
                ValidationError(ValidationErrorType.STEP, "Input field missing required fields", field_name=field.name)
            )
            return False

        # Validate field name
        if not self._is_valid_identifier(field.name):
            self.errors.append(
                ValidationError(
                    ValidationErrorType.NAMING,
                    f"Invalid field name format: {field.name}. Must be a valid Python identifier.",
                    field_name=field.name,
                )
            )
            return False

        # Validate field name length
        if len(field.name) > MAX_FIELD_NAME_LENGTH:
            self.errors.append(
                ValidationError(
                    ValidationErrorType.LENGTH,
                    f"Field name exceeds maximum length of {MAX_FIELD_NAME_LENGTH} characters",
                    field_name=field.name,
                )
            )
            return False

        # Validate description length
        if len(field.description) > MAX_DESCRIPTION_LENGTH:
            self.errors.append(
                ValidationError(
                    ValidationErrorType.LENGTH,
                    f"Description exceeds maximum length of {MAX_DESCRIPTION_LENGTH} characters",
                    field_name=field.name,
                )
            )
            return False

        # Validate variable reference
        if not self._is_valid_variable_reference(field.variable):
            self.errors.append(
                ValidationError(
                    ValidationErrorType.VARIABLE,
                    f"Invalid variable reference: {field.variable}",
                    field_name=field.name,
                )
            )
            return False

        return True

    def _validate_output_field(self, field: OutputField) -> bool:
        """Validates an output field: presence, naming, lengths, supported type."""
        # Validate required fields
        if not field.name or not field.description:
            self.errors.append(
                ValidationError(
                    ValidationErrorType.STEP, "Output field missing required fields", field_name=field.name
                )
            )
            return False

        # Validate field name
        if not self._is_valid_identifier(field.name):
            self.errors.append(
                ValidationError(
                    ValidationErrorType.NAMING,
                    f"Invalid field name format: {field.name}. Must be a valid Python identifier.",
                    field_name=field.name,
                )
            )
            return False

        # Validate field name length
        if len(field.name) > MAX_FIELD_NAME_LENGTH:
            self.errors.append(
                ValidationError(
                    ValidationErrorType.LENGTH,
                    f"Field name exceeds maximum length of {MAX_FIELD_NAME_LENGTH} characters",
                    field_name=field.name,
                )
            )
            return False

        # Validate description length
        if len(field.description) > MAX_DESCRIPTION_LENGTH:
            self.errors.append(
                ValidationError(
                    ValidationErrorType.LENGTH,
                    f"Description exceeds maximum length of {MAX_DESCRIPTION_LENGTH} characters",
                    field_name=field.name,
                )
            )
            return False

        # Validate type against the closed set of supported output types
        if field.type not in SUPPORTED_TYPES:
            self.errors.append(
                ValidationError(
                    ValidationErrorType.TYPE, f"Unsupported output type: {field.type}", field_name=field.name
                )
            )
            return False

        return True

    def _validate_simple_workflow_variables(self, workflow: Workflow) -> bool:
        """Validates variables in a simple (single-step) workflow."""
        step = next(iter(workflow.steps.values()))

        # Validate input variables
        for input_var in workflow.inputs:
            if not self._is_valid_external_input(input_var):
                self.errors.append(
                    ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
                )
                return False

        # Validate output variables
        for output_name, output_var in workflow.outputs.items():
            if output_var and not self._is_valid_variable_reference(output_var):
                self.errors.append(
                    ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
                )
                return False

        return True

    def _validate_variable_dependencies(self, workflow: Workflow) -> bool:
        """Validates variable dependencies between steps (cycles, unknown inputs)."""
        # Build variable dependency graph: input variable -> step outputs it feeds
        var_graph: Dict[str, Set[str]] = {}

        for step_id, step in workflow.steps.items():
            for field in step.input_fields:
                if field.variable not in var_graph:
                    var_graph[field.variable] = set()

                # Add dependency from input variable to step's outputs
                for output in step.output_fields:
                    var_graph[field.variable].add(f"{step_id}.{output.name}")

        # Check for cycles in variable dependencies (DFS with a path set)
        visited = set()
        path = set()

        def has_cycle(node: str) -> bool:
            if node in path:
                return True
            if node in visited:
                return False

            visited.add(node)
            path.add(node)

            for neighbor in var_graph.get(node, set()):
                if has_cycle(neighbor):
                    return True

            path.remove(node)
            return False

        # Check each variable for cycles
        for var in var_graph:
            if has_cycle(var):
                self.errors.append(
                    ValidationError(ValidationErrorType.VARIABLE, f"Circular variable dependency detected: {var}")
                )
                return False

        # Validate external input existence: every dot-free variable reference
        # must name a declared workflow input.
        external_inputs = set(workflow.inputs)
        for step in workflow.steps.values():
            for field in step.input_fields:
                step_id, field_name = self._parse_variable_reference(field.variable)
                if not step_id and field_name not in external_inputs:
                    self.errors.append(
                        ValidationError(
                            ValidationErrorType.VARIABLE,
                            f"External input '{field_name}' not found in workflow inputs",
                            field_name=field_name,
                        )
                    )
                    return False

        return True

    def _get_step_dependencies(self, step: ModelStep) -> Set[str]:
        """Gets the set of step IDs that this step depends on via its inputs."""
        deps = set()
        for field in step.input_fields:
            step_id = self._parse_variable_reference(field.variable)[0]
            if step_id:
                deps.add(step_id)
        return deps

    def _parse_variable_reference(self, var: str) -> Tuple[Optional[str], str]:
        """Extracts (step_id, field_name) from a variable reference.

        A dot-free reference is an external input and yields (None, var).
        """
        parts = var.split(".")
        if len(parts) == 1:
            return None, parts[0]
        return parts[0], parts[1]

    def _is_valid_variable_reference(self, var: str) -> bool:
        """Validates that a variable reference is well-formed and resolvable.

        A single identifier is treated as an external input; a dotted pair must
        name an existing step and one of its output fields.
        """
        if not self.workflow:
            return False
        parts = var.split(".")
        if len(parts) == 1:
            return True  # External input
        if len(parts) != 2:
            return False
        step_id, field_name = parts
        return step_id in self.workflow.steps and any(
            field.name == field_name for field in self.workflow.steps[step_id].output_fields
        )

    def _is_valid_external_input(self, var: str) -> bool:
        """Validates that a variable is a valid external input name."""
        if not var:
            return False
        if not self._is_valid_identifier(var):
            return False
        if keyword.iskeyword(var):
            return False
        if "." in var:  # External inputs should not contain dots
            return False
        return True

    def _is_valid_identifier(self, name: str) -> bool:
        """Validates that a string is a valid, non-keyword Python identifier."""
        if not name:
            return False
        if keyword.iskeyword(name):
            return False
        if not name.strip():  # Check for whitespace-only strings
            return False
        return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
tests/conftest.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
import os
import sys

# Add the src directory to the PYTHONPATH so tests can import the project
# packages (e.g. `workflows`) without installing them first.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
tests/test_executors.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from unittest.mock import patch
3
+
4
+ import pytest
5
+
6
+ from workflows.errors import CyclicDependencyError, WorkflowError
7
+ from workflows.executors import (
8
+ create_processed_inputs,
9
+ execute_model_step,
10
+ execute_workflow,
11
+ lower,
12
+ upper,
13
+ )
14
+ from workflows.structs import InputField, ModelStep, OutputField, Workflow
15
+
16
+ # Tests for utility functions
17
+
18
+
19
def test_upper():
    """Test the upper function with different input types."""
    assert upper("hello") == "HELLO"
    assert upper("Hello World") == "HELLO WORLD"
    assert upper("") == ""  # empty string is a no-op
    # Non-string inputs should be returned unchanged
    assert upper(123) == 123
    assert upper([1, 2, 3]) == [1, 2, 3]
    assert upper(None) is None


def test_lower():
    """Test the lower function with different input types."""
    assert lower("HELLO") == "hello"
    assert lower("Hello World") == "hello world"
    assert lower("") == ""  # empty string is a no-op
    # Non-string inputs should be returned unchanged
    assert lower(123) == 123
    assert lower([1, 2, 3]) == [1, 2, 3]
    assert lower(None) is None
39
+
40
+
41
+ # Tests for create_processed_inputs
42
+
43
+
44
def test_create_processed_inputs_basic():
    """Test basic input processing without transformations."""
    step = ModelStep(
        id="test_step",
        model="gpt-4",
        provider="openai",
        call_type="llm",
        system_prompt="Test prompt",
        input_fields=[InputField(name="text", description="Input text", variable="input_text")],
        output_fields=[],
    )
    available_vars = {"input_text": "Hello World"}

    result = create_processed_inputs(step, available_vars)
    # The field name maps to the variable's value, unmodified.
    assert result == {"text": "Hello World"}


def test_create_processed_inputs_with_transformation():
    """Test input processing with transformation functions."""
    step = ModelStep(
        id="test_step",
        model="gpt-4",
        provider="openai",
        call_type="llm",
        system_prompt="Test prompt",
        input_fields=[
            # `func` names a registered transformation applied to the variable value
            InputField(name="upper_text", description="Uppercase text", variable="input_text", func="upper"),
            InputField(name="lower_text", description="Lowercase text", variable="input_caps", func="lower"),
        ],
        output_fields=[],
    )
    available_vars = {"input_text": "hello", "input_caps": "WORLD"}

    result = create_processed_inputs(step, available_vars)
    assert result == {"upper_text": "HELLO", "lower_text": "world"}


def test_create_processed_inputs_missing_var():
    """Test that appropriate error is raised when a variable is missing."""
    step = ModelStep(
        id="test_step",
        model="gpt-4",
        provider="openai",
        call_type="llm",
        system_prompt="Test prompt",
        input_fields=[InputField(name="text", description="Input text", variable="missing_var")],
        output_fields=[],
    )
    available_vars = {"input_text": "Hello World"}

    # "missing_var" is not among the available variables.
    with pytest.raises(KeyError):
        create_processed_inputs(step, available_vars)


def test_create_processed_inputs_unknown_func():
    """Test that appropriate error is raised when an unknown function is specified."""
    step = ModelStep(
        id="test_step",
        model="gpt-4",
        provider="openai",
        call_type="llm",
        system_prompt="Test prompt",
        input_fields=[InputField(name="text", description="Input text", variable="input_text", func="unknown_func")],
        output_fields=[],
    )
    available_vars = {"input_text": "Hello World"}

    # This should raise an error when the function isn't found
    with pytest.raises(Exception):
        create_processed_inputs(step, available_vars)
114
+
115
+
116
+ # Tests for execute_model_step
117
+
118
+
119
@patch("workflows.executors.litellm.completion")
def test_execute_model_step_success(mock_completion):
    """Test successful execution of a model step with mocked litellm response."""
    # Mock the litellm response: model output is JSON matching the output fields
    mock_response = {"choices": [{"message": {"content": json.dumps({"summary": "This is a summary"})}}]}
    mock_completion.return_value = mock_response

    # Create a test step
    step = ModelStep(
        id="summarize",
        model="gpt-3.5-turbo",
        provider="openai",
        call_type="llm",
        system_prompt="Summarize the text",
        input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
        output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
    )

    # Execute the step
    result = execute_model_step(step, {"input_text": "Long text to be summarized..."})

    # Verify the results: parsed JSON keyed by output field name
    assert result == {"summary": "This is a summary"}

    # Verify the litellm call was made correctly
    mock_completion.assert_called_once()
    args, kwargs = mock_completion.call_args
    assert kwargs["model"] == "gpt-3.5-turbo"
    assert "Summarize the text" in kwargs["messages"][0]["content"]


@patch("workflows.executors.litellm.completion")
def test_execute_model_step_error(mock_completion):
    """Test handling of errors in model step execution."""
    # Make litellm raise an exception
    mock_completion.side_effect = Exception("API Error")

    # Create a test step
    step = ModelStep(
        id="summarize",
        model="gpt-3.5-turbo",
        provider="openai",
        call_type="llm",
        system_prompt="Summarize the text",
        input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
        output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
    )

    # Execute the step - the provider error should propagate to the caller
    with pytest.raises(Exception):
        execute_model_step(step, {"input_text": "Long text to be summarized..."})
170
+
171
+
172
+ # Tests for execute_workflow
173
+
174
+
175
@patch("workflows.executors.execute_model_step")
def test_execute_workflow_simple(mock_execute_step):
    """Test execution of a simple workflow with a single step."""
    # Configure mock to return expected outputs
    mock_execute_step.return_value = {"summary": "This is a summary"}

    # Create a simple workflow
    step = ModelStep(
        id="summarize",
        model="gpt-3.5-turbo",
        provider="openai",
        call_type="llm",
        system_prompt="Summarize the text",
        input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
        output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
    )

    workflow = Workflow(steps={"summarize": step}, inputs=["input_text"], outputs={"summary": "summarize.summary"})

    # Execute the workflow
    result = execute_workflow(workflow, {"input_text": "Long text to be summarized..."})

    # Verify the results: workflow outputs resolved from the step's outputs
    assert result == {"summary": "This is a summary"}

    # Verify execute_model_step was called correctly
    mock_execute_step.assert_called_once()
202
+
203
+
204
@patch("workflows.executors.execute_model_step")
def test_execute_workflow_multi_step(mock_execute_step):
    """Test execution of a multi-step workflow with dependencies."""

    # Configure mock to return different values based on the step being run
    def side_effect(step, available_vars):
        if step.id == "extract":
            return {"entities": ["Apple", "product"]}
        elif step.id == "analyze":
            return {"sentiment": "positive"}
        return {}

    mock_execute_step.side_effect = side_effect

    # Create extract step
    extract_step = ModelStep(
        id="extract",
        model="gpt-3.5-turbo",
        provider="openai",
        call_type="llm",
        system_prompt="Extract entities",
        input_fields=[InputField(name="text", description="Text to analyze", variable="input_text")],
        output_fields=[OutputField(name="entities", description="Extracted entities", type="list[str]")],
    )

    # Create analyze step that depends on extract step (note the dotted variable)
    analyze_step = ModelStep(
        id="analyze",
        model="gpt-4",
        provider="openai",
        call_type="llm",
        system_prompt="Analyze sentiment",
        input_fields=[InputField(name="entities", description="Entities to analyze", variable="extract.entities")],
        output_fields=[OutputField(name="sentiment", description="Sentiment analysis", type="str")],
    )

    workflow = Workflow(
        steps={"extract": extract_step, "analyze": analyze_step},
        inputs=["input_text"],
        outputs={"entities": "extract.entities", "sentiment": "analyze.sentiment"},
    )

    # Execute the workflow
    result = execute_workflow(workflow, {"input_text": "Apple is launching a new product tomorrow."})

    # Verify the results: both steps' outputs surface as workflow outputs
    assert result == {"entities": ["Apple", "product"], "sentiment": "positive"}

    # Verify execute_model_step was called twice (once for each step)
    assert mock_execute_step.call_count == 2
254
+
255
+
256
def test_execute_workflow_missing_input():
    """Test that an error is raised when a required input is missing."""
    step = ModelStep(
        id="summarize",
        model="gpt-3.5-turbo",
        provider="openai",
        call_type="llm",
        system_prompt="Summarize the text",
        input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
        output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
    )

    workflow = Workflow(steps={"summarize": step}, inputs=["input_text"], outputs={"summary": "summarize.summary"})

    # Execute with missing input: "input_text" is declared but not provided
    with pytest.raises(WorkflowError, match="Missing required workflow input"):
        execute_workflow(workflow, {})
273
+
274
+
275
@patch("workflows.executors.create_dependency_graph")
def test_execute_workflow_cyclic_dependency(mock_dependency_graph):
    """Test that a cyclic dependency in the workflow raises an appropriate error."""
    # Make create_dependency_graph raise a CyclicDependencyError
    mock_dependency_graph.side_effect = CyclicDependencyError()

    step = ModelStep(
        id="test",
        model="gpt-3.5-turbo",
        provider="openai",
        call_type="llm",
        system_prompt="Test",
        input_fields=[],
        output_fields=[],
    )

    # `outputs` is a mapping of output name -> variable reference everywhere
    # else in this suite, so use an empty dict rather than a list.
    workflow = Workflow(steps={"test": step}, inputs=[], outputs={})

    # This should propagate the CyclicDependencyError
    with pytest.raises(CyclicDependencyError):
        execute_workflow(workflow, {})
tests/test_utils.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from workflows.errors import CyclicDependencyError, UnknownVariableError, WorkflowError
4
+ from workflows.utils import _create_variable_step_mapping, create_dependency_graph, topological_sort
5
+
6
+
7
+ # Dummy classes to simulate Workflow, Step, and Field
8
+ class DummyField:
9
+ def __init__(self, name, type="str", variable=None):
10
+ self.name = name
11
+ self.type = type
12
+ # For input fields, variable property is needed
13
+ self.variable = variable if variable is not None else name
14
+
15
+
16
+ class DummyStep:
17
+ def __init__(self, input_fields, output_fields):
18
+ self.input_fields = input_fields
19
+ self.output_fields = output_fields
20
+
21
+
22
+ class DummyWorkflow:
23
+ def __init__(self, steps):
24
+ # steps is a dict with key as step_id and value as DummyStep
25
+ self.steps = steps
26
+
27
+
28
+ # Tests for _create_variable_step_mapping
29
+
30
+
31
+ def test_create_variable_step_mapping_success():
32
+ # Create a workflow with two steps producing unique output variables
33
+ step_a = DummyStep(input_fields=[], output_fields=[DummyField("out1")])
34
+ step_b = DummyStep(input_fields=[], output_fields=[DummyField("out2")])
35
+ workflow = DummyWorkflow({"A": step_a, "B": step_b})
36
+ mapping = _create_variable_step_mapping(workflow)
37
+ assert mapping == {"A.out1": "A", "B.out2": "B"}
38
+
39
+
40
+ def test_create_variable_step_mapping_duplicate():
41
+ # Create a workflow where two steps produce an output with same name
42
+ step_a = DummyStep(input_fields=[], output_fields=[DummyField("out"), DummyField("out")])
43
+ workflow = DummyWorkflow({"A": step_a})
44
+ with pytest.raises(WorkflowError):
45
+ _create_variable_step_mapping(workflow)
46
+
47
+
48
+ def test_create_variable_step_mapping_empty():
49
+ """Test _create_variable_step_mapping with an empty workflow should return an empty mapping."""
50
+ workflow = DummyWorkflow({})
51
+ mapping = _create_variable_step_mapping(workflow)
52
+ assert mapping == {}
53
+
54
+
55
+ def test_create_variable_step_mapping_multiple_outputs():
56
+ """Test a workflow where a single step produces multiple outputs with unique names."""
57
+ step = DummyStep(input_fields=[], output_fields=[DummyField("out1"), DummyField("out2")])
58
+ workflow = DummyWorkflow({"A": step})
59
+ mapping = _create_variable_step_mapping(workflow)
60
+ assert mapping == {"A.out1": "A", "A.out2": "A"}
61
+
62
+
63
+ # Tests for create_dependency_graph
64
+
65
+
66
+ def test_create_dependency_graph_success_with_dependency():
67
+ # Step A produces 'A.out', which is used as input in step B
68
+ step_a = DummyStep(input_fields=[], output_fields=[DummyField("out")])
69
+ # For input_fields, explicitly set variable to reference A.out
70
+ step_b = DummyStep(input_fields=[DummyField("dummy", variable="A.out")], output_fields=[DummyField("result")])
71
+ workflow = DummyWorkflow({"A": step_a, "B": step_b})
72
+ # No external input provided for A.out so dependency must be created
73
+ deps = create_dependency_graph(workflow, input_values={})
74
+ # Step B depends on step A
75
+ assert deps["B"] == {"A"}
76
+ # Step A has no dependencies
77
+ assert deps["A"] == set()
78
+
79
+
80
+ def test_create_dependency_graph_success_with_external_input():
81
+ # Step B expects an input, but it is provided externally
82
+ step_b = DummyStep(
83
+ input_fields=[DummyField("param", variable="external_param")], output_fields=[DummyField("result")]
84
+ )
85
+ workflow = DummyWorkflow({"B": step_b})
86
+ # Provide external input for external_param
87
+ deps = create_dependency_graph(workflow, input_values={"external_param": 42})
88
+ # With external input, no dependency is needed
89
+ assert deps["B"] == set()
90
+
91
+
92
+ def test_create_dependency_graph_unknown_variable():
93
+ # Step B expects an input that is neither produced by any step nor provided externally
94
+ step_b = DummyStep(
95
+ input_fields=[DummyField("param", variable="non_existent")], output_fields=[DummyField("result")]
96
+ )
97
+ workflow = DummyWorkflow({"B": step_b})
98
+ with pytest.raises(UnknownVariableError):
99
+ _ = create_dependency_graph(workflow, input_values={})
100
+
101
+
102
+ def test_create_dependency_graph_complex():
103
+ """Test create_dependency_graph on a more complex workflow with multiple dependencies."""
104
+ # Step A produces A.out, Step B uses A.out, Step C uses B.out, and Step D uses both A.out and B.out
105
+ step_a = DummyStep(input_fields=[], output_fields=[DummyField("out")])
106
+ step_b = DummyStep(input_fields=[DummyField("inp", variable="A.out")], output_fields=[DummyField("out")])
107
+ step_c = DummyStep(input_fields=[DummyField("inp", variable="B.out")], output_fields=[DummyField("result")])
108
+ step_d = DummyStep(
109
+ input_fields=[DummyField("inp1", variable="A.out"), DummyField("inp2", variable="B.out")],
110
+ output_fields=[DummyField("final")],
111
+ )
112
+
113
+ workflow = DummyWorkflow({"A": step_a, "B": step_b, "C": step_c, "D": step_d})
114
+ # Provide external input for "B.out" so that step B's output isn't expected to come from a step
115
+ # However, to simulate dependency, assume external input is not provided for the dependencies used in step C and D
116
+ # Therefore, workflow must resolve A.out for step B, and then step B produces B.out for steps C and D.
117
+ # Let's not provide any external input, so both dependencies are created.
118
+
119
+ deps = create_dependency_graph(workflow, input_values={})
120
+ # Expected dependencies:
121
+ # B depends on A
122
+ # C depends on B
123
+ # D depends on both A and B
124
+ assert deps["B"] == {"A"}
125
+ assert deps["C"] == {"B"}
126
+ assert deps["D"] == {"A", "B"}
127
+
128
+
129
+ # Tests for topological_sort
130
+
131
+
132
+ def test_topological_sort_success():
133
+ # Create a simple dependency graph: A -> B -> C
134
+ deps = {"A": set(), "B": {"A"}, "C": {"B"}}
135
+ order = topological_sort(deps)
136
+ # Check that order satisfies dependencies: A before B, B before C
137
+ assert order.index("A") < order.index("B") < order.index("C")
138
+
139
+
140
+ def test_topological_sort_cycle():
141
+ # Create a cyclic dependency: A -> B and B -> A
142
+ deps = {"A": {"B"}, "B": {"A"}}
143
+ with pytest.raises(CyclicDependencyError):
144
+ _ = topological_sort(deps)
145
+
146
+
147
+ def test_topological_sort_single_node():
148
+ """Test topological_sort on a graph with a single node and no dependencies."""
149
+ deps = {"A": set()}
150
+ order = topological_sort(deps)
151
+ assert order == ["A"]
152
+
153
+
154
+ def test_topological_sort_disconnected():
155
+ """Test topological_sort on a graph with disconnected nodes (no dependencies among them)."""
156
+ deps = {"A": set(), "B": set(), "C": set()}
157
+ order = topological_sort(deps)
158
+ # The order can be in any permutation, but must contain all nodes
159
+ assert set(order) == {"A", "B", "C"}
tests/test_validators.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+
3
+ import pytest
4
+ from pydantic import ValidationError as PydanticValidationError
5
+
6
+ from workflows.structs import InputField, ModelStep, OutputField, Workflow
7
+ from workflows.validators import ValidationError, ValidationErrorType, WorkflowValidator
8
+
9
+
10
+ # Test Data
11
+ def create_basic_step(step_id: str = "step1") -> ModelStep:
12
+ """Creates a basic valid step for testing"""
13
+ return ModelStep(
14
+ id=step_id,
15
+ name="Test Step",
16
+ model="gpt-4",
17
+ provider="openai",
18
+ call_type="llm",
19
+ temperature=0.7,
20
+ system_prompt="Test prompt",
21
+ input_fields=[],
22
+ output_fields=[],
23
+ )
24
+
25
+
26
+ def create_basic_workflow(steps: List[ModelStep] | None = None) -> Workflow:
27
+ """Creates a basic valid workflow for testing"""
28
+ if steps is None:
29
+ steps = [create_basic_step()]
30
+ return Workflow(inputs=[], outputs={}, steps={step.id: step for step in steps})
31
+
32
+
33
+ # Additional Test Data
34
+ def create_step_with_fields(
35
+ step_id: str, input_fields: List[InputField], output_fields: List[OutputField]
36
+ ) -> ModelStep:
37
+ """Creates a step with specific input and output fields"""
38
+ return ModelStep(
39
+ id=step_id,
40
+ name="Test Step",
41
+ model="gpt-4",
42
+ provider="openai",
43
+ call_type="llm",
44
+ temperature=0.7,
45
+ system_prompt="Test prompt",
46
+ input_fields=input_fields,
47
+ output_fields=output_fields,
48
+ )
49
+
50
+
51
+ def create_valid_workflow() -> Workflow:
52
+ # Create a step with input and output fields
53
+ step = create_step_with_fields(
54
+ "step1",
55
+ [InputField(name="input", description="test", variable="external_input")],
56
+ [OutputField(name="output", description="test", type="str")],
57
+ )
58
+
59
+ # Create workflow with the single step
60
+ workflow = create_basic_workflow([step])
61
+ workflow.inputs = ["external_input"]
62
+ workflow.outputs = {"output": "step1.output"}
63
+ return workflow
64
+
65
+
66
+ # Basic Workflow Validation Tests
67
+ class TestBasicWorkflowValidation:
68
+ def test_empty_workflow(self):
69
+ """Test validation of empty workflow"""
70
+ validator = WorkflowValidator()
71
+ workflow = Workflow(inputs=["input"], outputs={"output": "input"}, steps={})
72
+ assert not validator.validate(workflow)
73
+ assert len(validator.errors) == 1
74
+ assert validator.errors[0].error_type == ValidationErrorType.GENERAL
75
+ assert "must contain at least one step" in validator.errors[0].message
76
+
77
+ def test_workflow_without_inputs(self):
78
+ """Test validation of workflow without inputs"""
79
+ validator = WorkflowValidator()
80
+ workflow = create_basic_workflow()
81
+ workflow.inputs = []
82
+ workflow.outputs = {"output": "step1.field"}
83
+ assert not validator.validate(workflow)
84
+ assert len(validator.errors) == 1
85
+ assert validator.errors[0].error_type == ValidationErrorType.GENERAL
86
+ assert "must contain at least one input" in validator.errors[0].message
87
+
88
+ def test_workflow_without_outputs(self):
89
+ """Test validation of workflow without outputs"""
90
+ validator = WorkflowValidator()
91
+ workflow = create_basic_workflow()
92
+ workflow.inputs = ["input"]
93
+ workflow.outputs = {}
94
+ assert not validator.validate(workflow)
95
+ assert len(validator.errors) == 1
96
+ assert validator.errors[0].error_type == ValidationErrorType.GENERAL
97
+ assert "must contain at least one output" in validator.errors[0].message
98
+
99
+ def test_single_step_workflow(self):
100
+ """Test validation of valid single-step workflow"""
101
+ validator = WorkflowValidator()
102
+
103
+ # Create a step with input and output fields
104
+ workflow = create_valid_workflow()
105
+
106
+ assert validator.validate(workflow)
107
+ assert len(validator.errors) == 0
108
+
109
+
110
+ # Step Validation Tests
111
+ class TestStepValidation:
112
+ def test_missing_required_fields(self):
113
+ """Test validation of step with missing required fields"""
114
+ validator = WorkflowValidator()
115
+ step = ModelStep(
116
+ id="step1",
117
+ name="", # Missing name
118
+ model="", # Missing model
119
+ provider="", # Missing provider
120
+ call_type="", # Missing call_type
121
+ temperature=0.7,
122
+ system_prompt="Test prompt",
123
+ input_fields=[],
124
+ output_fields=[],
125
+ )
126
+ workflow = create_basic_workflow([step])
127
+ workflow.inputs = ["input"]
128
+ workflow.outputs = {"output": "step1.field"}
129
+ assert not validator.validate(workflow)
130
+ assert len(validator.errors) == 1
131
+ assert validator.errors[0].error_type == ValidationErrorType.STEP
132
+
133
+ def test_invalid_step_id(self):
134
+ """Test validation of step with invalid ID format"""
135
+ validator = WorkflowValidator()
136
+ step = create_basic_step("123invalid") # Invalid ID format
137
+ workflow = create_basic_workflow([step])
138
+ workflow.inputs = ["input"]
139
+ workflow.outputs = {"output": "step1.field"}
140
+ assert not validator.validate(workflow)
141
+ assert len(validator.errors) == 1
142
+ assert validator.errors[0].error_type == ValidationErrorType.NAMING
143
+
144
+ def test_llm_temperature_validation(self):
145
+ """Test validation of LLM step temperature"""
146
+ validator = WorkflowValidator()
147
+
148
+ # Test invalid temperature
149
+ step = create_basic_step()
150
+ step.temperature = 1.5 # Invalid temperature
151
+ workflow = create_basic_workflow([step])
152
+ workflow.inputs = ["input"]
153
+ workflow.outputs = {"output": "step1.field"}
154
+ assert not validator.validate(workflow)
155
+ assert len(validator.errors) == 1
156
+ assert validator.errors[0].error_type == ValidationErrorType.RANGE
157
+
158
+ # Test missing temperature
159
+ step = create_basic_step()
160
+ step.temperature = None # Missing temperature
161
+ workflow = create_basic_workflow([step])
162
+ workflow.inputs = ["input"]
163
+ workflow.outputs = {"output": "step1.field"}
164
+ assert not validator.validate(workflow)
165
+ assert len(validator.errors) == 1
166
+ assert validator.errors[0].error_type == ValidationErrorType.STEP
167
+
168
+ def test_llm_system_prompt_validation(self):
169
+ """Test validation of LLM step system prompt"""
170
+ validator = WorkflowValidator()
171
+
172
+ # Test missing system prompt
173
+ step = create_basic_step()
174
+ step.system_prompt = "" # Missing system prompt
175
+ workflow = create_basic_workflow([step])
176
+ workflow.inputs = ["input"]
177
+ workflow.outputs = {"output": "step1.field"}
178
+ assert not validator.validate(workflow)
179
+ assert len(validator.errors) == 1
180
+ assert validator.errors[0].error_type == ValidationErrorType.STEP
181
+
182
+ # Test too long system prompt
183
+ step = create_basic_step()
184
+ step.system_prompt = "x" * 4001 # Too long
185
+ workflow = create_basic_workflow([step])
186
+ workflow.inputs = ["input"]
187
+ workflow.outputs = {"output": "step1.field"}
188
+ assert not validator.validate(workflow)
189
+ assert len(validator.errors) == 1
190
+ assert validator.errors[0].error_type == ValidationErrorType.LENGTH
191
+
192
+
193
+ # Field Validation Tests
194
+ class TestFieldValidation:
195
+ def test_input_field_validation(self):
196
+ """Test validation of input fields"""
197
+ validator = WorkflowValidator()
198
+
199
+ # Test missing required fields
200
+ step = create_basic_step()
201
+ step.input_fields = [InputField(name="", description="", variable="")]
202
+ workflow = create_basic_workflow([step])
203
+ workflow.inputs = ["input"]
204
+ workflow.outputs = {"output": "step1.field"}
205
+ assert not validator.validate(workflow)
206
+ assert len(validator.errors) == 1
207
+ assert validator.errors[0].error_type == ValidationErrorType.STEP
208
+
209
+ # Test invalid field name
210
+ step = create_basic_step()
211
+ step.input_fields = [InputField(name="123invalid", description="test", variable="test")]
212
+ workflow = create_basic_workflow([step])
213
+ workflow.inputs = ["input"]
214
+ workflow.outputs = {"output": "step1.field"}
215
+ assert not validator.validate(workflow)
216
+ assert len(validator.errors) == 1
217
+ assert validator.errors[0].error_type == ValidationErrorType.NAMING
218
+
219
+ # Test too long description
220
+ step = create_basic_step()
221
+ step.input_fields = [InputField(name="test", description="x" * 201, variable="test")]
222
+ workflow = create_basic_workflow([step])
223
+ workflow.inputs = ["input"]
224
+ workflow.outputs = {"output": "step1.field"}
225
+ assert not validator.validate(workflow)
226
+ assert len(validator.errors) == 1
227
+ assert validator.errors[0].error_type == ValidationErrorType.LENGTH
228
+
229
+ def test_output_field_validation(self):
230
+ """Test validation of output fields"""
231
+ validator = WorkflowValidator()
232
+
233
+ # Test missing required fields
234
+ step = create_basic_step()
235
+ step.output_fields = [OutputField(name="", description="", type="str")]
236
+ workflow = create_basic_workflow([step])
237
+ workflow.inputs = ["input"]
238
+ workflow.outputs = {"output": "step1.field"}
239
+ assert not validator.validate(workflow)
240
+ assert len(validator.errors) == 1
241
+ assert validator.errors[0].error_type == ValidationErrorType.STEP
242
+
243
+ # Test invalid field name
244
+ step = create_basic_step()
245
+ step.output_fields = [OutputField(name="123invalid", description="test", type="str")]
246
+ workflow = create_basic_workflow([step])
247
+ workflow.inputs = ["input"]
248
+ workflow.outputs = {"output": "step1.field"}
249
+ assert not validator.validate(workflow)
250
+ assert len(validator.errors) == 1
251
+ assert validator.errors[0].error_type == ValidationErrorType.NAMING
252
+
253
+ def test_field_name_length(self):
254
+ """Test validation of field name length"""
255
+ validator = WorkflowValidator()
256
+
257
+ # Test too long field name
258
+ step = create_basic_step()
259
+ step.input_fields = [InputField(name="x" * 51, description="test", variable="test")]
260
+ workflow = create_basic_workflow([step])
261
+ workflow.inputs = ["input"]
262
+ workflow.outputs = {"output": "step1.field"}
263
+ assert not validator.validate(workflow)
264
+ assert len(validator.errors) == 1
265
+ assert validator.errors[0].error_type == ValidationErrorType.LENGTH
266
+
267
+ def test_field_description_length(self):
268
+ """Test validation of field description length"""
269
+ validator = WorkflowValidator()
270
+
271
+ # Test too long description
272
+ step = create_basic_step()
273
+ step.input_fields = [InputField(name="test", description="x" * 201, variable="test")]
274
+ workflow = create_basic_workflow([step])
275
+ workflow.inputs = ["input"]
276
+ workflow.outputs = {"output": "step1.field"}
277
+ assert not validator.validate(workflow)
278
+ assert len(validator.errors) == 1
279
+ assert validator.errors[0].error_type == ValidationErrorType.LENGTH
280
+
281
+ def test_whitespace_only_strings(self):
282
+ """Test validation of whitespace-only strings"""
283
+ validator = WorkflowValidator()
284
+
285
+ # Test whitespace-only field name
286
+ step = create_basic_step()
287
+ step.input_fields = [InputField(name=" ", description="test", variable="test")]
288
+ workflow = create_basic_workflow([step])
289
+ workflow.inputs = ["input"]
290
+ workflow.outputs = {"output": "step1.field"}
291
+ assert not validator.validate(workflow)
292
+ assert len(validator.errors) == 1
293
+ assert validator.errors[0].error_type == ValidationErrorType.NAMING
294
+
295
+ def test_special_characters(self):
296
+ """Test validation of special characters in names"""
297
+ validator = WorkflowValidator()
298
+
299
+ # Test special characters in field name
300
+ step = create_basic_step()
301
+ step.input_fields = [InputField(name="test@field", description="test", variable="test")]
302
+ workflow = create_basic_workflow([step])
303
+ workflow.inputs = ["input"]
304
+ workflow.outputs = {"output": "step1.field"}
305
+ assert not validator.validate(workflow)
306
+ assert len(validator.errors) == 1
307
+ assert validator.errors[0].error_type == ValidationErrorType.NAMING
308
+
309
+
310
+ # Variable Reference Tests
311
+ class TestVariableReference:
312
+ def test_external_input_validation(self):
313
+ """Test validation of external input variables"""
314
+ validator = WorkflowValidator()
315
+
316
+ # Test invalid external input format
317
+ workflow = create_valid_workflow()
318
+ workflow.inputs = ["step1.field"] # Invalid format
319
+ assert not validator.validate(workflow)
320
+ assert len(validator.errors) == 1
321
+ assert validator.errors[0].error_type == ValidationErrorType.VARIABLE
322
+
323
+ def test_step_output_reference(self):
324
+ """Test validation of step output references"""
325
+ validator = WorkflowValidator()
326
+
327
+ # Test invalid output reference
328
+ workflow = create_basic_workflow()
329
+ workflow.inputs = ["input"]
330
+ workflow.outputs = {"output": "nonexistent_step.field"}
331
+ assert not validator.validate(workflow)
332
+ assert len(validator.errors) == 1
333
+ assert validator.errors[0].error_type == ValidationErrorType.VARIABLE
334
+
335
+ # Test valid output reference
336
+ step = create_basic_step()
337
+ step.output_fields = [OutputField(name="field", description="test", type="str")]
338
+ workflow = create_basic_workflow([step])
339
+ workflow.inputs = ["input"]
340
+ workflow.outputs = {"output": "step1.field"}
341
+ assert validator.validate(workflow)
342
+ assert len(validator.errors) == 0
343
+
344
+
345
+ # DAG Validation Tests
346
+ class TestDAGValidation:
347
+ def test_cycle_detection(self):
348
+ """Test detection of cycles in workflow"""
349
+ validator = WorkflowValidator()
350
+
351
+ # Create a workflow with a cycle
352
+ step1 = create_step_with_fields(
353
+ "step1",
354
+ [InputField(name="input", description="test", variable="step3.output")],
355
+ [OutputField(name="output", description="test", type="str")],
356
+ )
357
+ step2 = create_step_with_fields(
358
+ "step2",
359
+ [InputField(name="input", description="test", variable="step1.output")],
360
+ [OutputField(name="output", description="test", type="str")],
361
+ )
362
+ step3 = create_step_with_fields(
363
+ "step3",
364
+ [InputField(name="input", description="test", variable="step2.output")],
365
+ [OutputField(name="output", description="test", type="str")],
366
+ )
367
+
368
+ workflow = create_basic_workflow([step1, step2, step3])
369
+ workflow.inputs = ["input"]
370
+ workflow.outputs = {"output": "step3.output"}
371
+ assert not validator.validate(workflow)
372
+ assert len(validator.errors) == 1
373
+ assert validator.errors[0].error_type == ValidationErrorType.DAG
374
+
375
+ def test_orphaned_steps(self):
376
+ """Test detection of orphaned steps"""
377
+ validator = WorkflowValidator()
378
+
379
+ # Create a workflow with an orphaned step
380
+ step1 = create_step_with_fields(
381
+ "step1",
382
+ [InputField(name="input", description="test", variable="step2.output")],
383
+ [OutputField(name="output", description="test", type="str")],
384
+ )
385
+ step2 = create_step_with_fields(
386
+ "step2",
387
+ [InputField(name="input", description="test", variable="step1.output")],
388
+ [OutputField(name="output", description="test", type="str")],
389
+ )
390
+ step3 = create_step_with_fields(
391
+ "step3",
392
+ [],
393
+ [OutputField(name="output", description="test", type="str")],
394
+ )
395
+
396
+ workflow = create_basic_workflow([step1, step2, step3])
397
+ workflow.inputs = ["input"]
398
+ workflow.outputs = {"output": "step3.output"}
399
+ assert not validator.validate(workflow)
400
+ assert len(validator.errors) == 1
401
+ assert validator.errors[0].error_type == ValidationErrorType.DAG
402
+
403
+
404
+ # Variable Dependency Tests
405
+ class TestVariableDependencies:
406
+ def test_circular_dependencies(self):
407
+ """Test detection of circular variable dependencies"""
408
+ validator = WorkflowValidator()
409
+
410
+ # Create a workflow with circular variable dependencies
411
+ step1 = create_step_with_fields(
412
+ "step1",
413
+ [InputField(name="input", description="test", variable="step2.output")],
414
+ [OutputField(name="output", description="test", type="str")],
415
+ )
416
+ step2 = create_step_with_fields(
417
+ "step2",
418
+ [InputField(name="input", description="test", variable="step1.output")],
419
+ [OutputField(name="output", description="test", type="str")],
420
+ )
421
+
422
+ workflow = create_basic_workflow([step1, step2])
423
+ workflow.inputs = ["input"]
424
+ workflow.outputs = {"output": "step2.output"}
425
+ assert not validator.validate(workflow)
426
+ assert len(validator.errors) == 1
427
+ assert validator.errors[0].error_type == ValidationErrorType.DAG
428
+
429
+ def test_valid_dependencies(self):
430
+ """Test validation of valid variable dependencies"""
431
+ validator = WorkflowValidator()
432
+
433
+ # Create a workflow with valid dependencies
434
+ step1 = create_step_with_fields(
435
+ "step1",
436
+ [InputField(name="input", description="test", variable="external_input")],
437
+ [OutputField(name="output", description="test", type="str")],
438
+ )
439
+ step2 = create_step_with_fields(
440
+ "step2",
441
+ [InputField(name="input", description="test", variable="step1.output")],
442
+ [OutputField(name="output", description="test", type="str")],
443
+ )
444
+ step3 = create_step_with_fields(
445
+ "step3",
446
+ [InputField(name="input", description="test", variable="step2.output")],
447
+ [OutputField(name="output", description="test", type="str")],
448
+ )
449
+
450
+ workflow = create_basic_workflow([step1, step2, step3])
451
+ workflow.inputs = ["external_input"]
452
+ workflow.outputs = {"output": "step3.output"}
453
+ assert validator.validate(workflow)
454
+ assert len(validator.errors) == 0
455
+
456
+
457
+ # Type Compatibility Tests
458
+ class TestTypeCompatibility:
459
+ def test_basic_type_compatibility(self):
460
+ """Test validation of basic type compatibility"""
461
+ validator = WorkflowValidator()
462
+
463
+ # Create steps with type mismatch
464
+ step1 = create_step_with_fields(
465
+ "step1",
466
+ [InputField(name="input", description="test", variable="external_input")],
467
+ [OutputField(name="output", description="test", type="int")],
468
+ )
469
+ step2 = create_step_with_fields(
470
+ "step2",
471
+ [InputField(name="input", description="test", variable="step1.output")],
472
+ [OutputField(name="output", description="test", type="str")],
473
+ )
474
+
475
+ workflow = create_basic_workflow([step1, step2])
476
+ workflow.inputs = ["external_input"]
477
+ workflow.outputs = {"output": "step2.output"}
478
+ assert validator.validate(workflow)
479
+
480
+ # def test_list_type_compatibility(self):
481
+ # """Test validation of list type compatibility"""
482
+ # validator = WorkflowValidator()
483
+
484
+ # # Test compatible list types
485
+ # step1 = create_step_with_fields(
486
+ # "step1", [], [OutputField(name="output", description="test", type="list[str]")]
487
+ # )
488
+ # step2 = create_step_with_fields(
489
+ # "step2", [InputField(name="input", description="test", variable="step1.output")], []
490
+ # )
491
+
492
+ # workflow = create_basic_workflow([step1, step2])
493
+ # workflow.inputs = ["input"]
494
+ # workflow.outputs = {"output": "step2.output"}
495
+ # assert validator.validate(workflow)
496
+ # assert len(validator.errors) == 0
497
+
498
+ # # Test incompatible list types
499
+ # step1 = create_step_with_fields(
500
+ # "step1", [], [OutputField(name="output", description="test", type="list[int]")]
501
+ # )
502
+ # step2 = create_step_with_fields(
503
+ # "step2", [InputField(name="input", description="test", variable="step1.output")], []
504
+ # )
505
+
506
+ # workflow = create_basic_workflow([step1, step2])
507
+ # workflow.inputs = ["input"]
508
+ # workflow.outputs = {"output": "step2.output"}
509
+ # assert not validator.validate(workflow)
510
+ # assert len(validator.errors) == 1
511
+ # assert validator.errors[0].error_type == ValidationErrorType.TYPE
512
+
513
+
514
+ # Complex Workflow Tests
515
+ class TestComplexWorkflows:
516
+ def test_multi_output_workflow(self):
517
+ """Test validation of workflow with multiple outputs"""
518
+ validator = WorkflowValidator()
519
+
520
+ # Create a workflow with multiple outputs
521
+ step1 = create_step_with_fields(
522
+ "step1",
523
+ [],
524
+ [
525
+ OutputField(name="output1", description="test", type="str"),
526
+ OutputField(name="output2", description="test", type="int"),
527
+ ],
528
+ )
529
+ step2 = create_step_with_fields(
530
+ "step2",
531
+ [InputField(name="input", description="test", variable="step1.output1")],
532
+ [OutputField(name="output", description="test", type="str")],
533
+ )
534
+
535
+ workflow = create_basic_workflow([step1, step2])
536
+ workflow.inputs = ["input"]
537
+ workflow.outputs = {"output1": "step1.output1", "output2": "step1.output2", "output3": "step2.output"}
538
+ assert validator.validate(workflow)
539
+ assert len(validator.errors) == 0
540
+
541
+ def test_complex_dependencies(self):
542
+ """Test validation of workflow with complex dependencies"""
543
+ validator = WorkflowValidator()
544
+
545
+ # Create a workflow with complex dependencies
546
+ step1 = create_step_with_fields(
547
+ "step1",
548
+ [InputField(name="input", description="test", variable="external_input")],
549
+ [OutputField(name="output", description="test", type="str")],
550
+ )
551
+ step2 = create_step_with_fields(
552
+ "step2",
553
+ [InputField(name="input", description="test", variable="step1.output")],
554
+ [OutputField(name="output", description="test", type="str")],
555
+ )
556
+ step3 = create_step_with_fields(
557
+ "step3",
558
+ [
559
+ InputField(name="input1", description="test", variable="step1.output"),
560
+ InputField(name="input2", description="test", variable="step2.output"),
561
+ ],
562
+ [OutputField(name="output", description="test", type="str")],
563
+ )
564
+
565
+ workflow = create_basic_workflow([step1, step2, step3])
566
+ workflow.inputs = ["external_input"]
567
+ workflow.outputs = {"output": "step3.output"}
568
+ assert validator.validate(workflow)
569
+ assert len(validator.errors) == 0
570
+
571
+
572
# External Input Tests
class TestExternalInputs:
    """Validation behaviour around workflow-level (external) inputs."""

    def test_external_input_existence(self):
        """A step variable referencing an undeclared external input must fail validation."""
        checker = WorkflowValidator()

        # "missing_input" is not listed in workflow.inputs, so the reference dangles.
        dangling_step = create_step_with_fields(
            "step1", [InputField(name="input", description="test", variable="missing_input")], []
        )
        wf = create_basic_workflow([dangling_step])
        wf.inputs = ["valid_input"]
        wf.outputs = {"output": "step1.output"}

        assert not checker.validate(wf)
        assert len(checker.errors) == 1
        assert checker.errors[0].error_type == ValidationErrorType.VARIABLE

    def test_external_input_naming_conflicts(self):
        """An external input named like a step output reference must be rejected."""
        checker = WorkflowValidator()

        producer = create_step_with_fields("step1", [], [OutputField(name="output", description="test", type="str")])
        wf = create_basic_workflow([producer])
        wf.inputs = ["step1.output"]  # Conflict with step output
        wf.outputs = {"output": "step1.output"}

        assert not checker.validate(wf)
        assert len(checker.errors) == 1
        assert checker.errors[0].error_type == ValidationErrorType.VARIABLE
602
+
603
# Edge Cases
class TestEdgeCases:
    """Edge-case validation: empty workflows, empty/None outputs, duplicate names."""

    def test_empty_workflow_with_inputs(self):
        """A workflow declaring inputs but containing no steps must fail with a GENERAL error."""
        validator = WorkflowValidator()
        workflow = Workflow(inputs=["input"], outputs={}, steps={})
        assert not validator.validate(workflow)
        assert len(validator.errors) == 1
        assert validator.errors[0].error_type == ValidationErrorType.GENERAL

    def test_workflow_with_empty_outputs(self):
        """An otherwise-valid workflow with an empty outputs mapping must fail validation."""
        validator = WorkflowValidator()
        workflow = create_valid_workflow()
        workflow.outputs = {}  # No output references at all
        assert not validator.validate(workflow)
        assert len(validator.errors) == 1
        assert validator.errors[0].error_type == ValidationErrorType.GENERAL

    def test_workflow_with_none_outputs(self):
        """A workflow output mapped to None (declared but unbound) must fail validation."""
        # NOTE: docstring/comment previously said "empty outputs" — copy-paste from the
        # test above; this case exercises a None output binding, not an empty mapping.
        validator = WorkflowValidator()
        workflow = create_valid_workflow()
        workflow.outputs = {"output": None}  # None output reference
        assert not validator.validate(workflow)
        assert len(validator.errors) == 1
        assert validator.errors[0].error_type == ValidationErrorType.GENERAL

    def test_workflow_with_duplicate_output_names(self):
        """A step declaring two outputs with the same name must fail with a STEP error."""
        validator = WorkflowValidator()
        step = create_step_with_fields(
            "step1",
            [],
            [
                OutputField(name="output", description="test", type="str"),
                OutputField(name="output", description="test", type="str"),
            ],
        )
        workflow = create_basic_workflow([step])
        workflow.inputs = ["input"]
        workflow.outputs = {"output": "step1.output"}
        assert not validator.validate(workflow)
        assert len(validator.errors) == 1
        assert validator.errors[0].error_type == ValidationErrorType.STEP