Maharshi Gor committed
Commit: 0bab47c
Parent(s): e1ce295
Refactors workflow management and model configurations
Moves llms.py to the workflows directory and updates imports.
Adds logprobs support to Cohere models and renames buzz_threshold to confidence_threshold for clarity.
Enhances PipelineStateManager to include buzzer configuration and implements cycle detection in workflow utilities for improved stability.
Relates to previous refactoring efforts.
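The cycle detection added to the workflow utilities is not shown in full in this diff. As a rough illustration only, a dependency-graph cycle check along these lines could back a `detect_cycles` helper; the function and variable names below are hypothetical, not the repository's implementation.

```python
# Minimal sketch of cycle detection over a step dependency graph.
# `deps` maps each step id to the ids it depends on; names are illustrative.
def detect_cycle(deps: dict[str, set[str]]) -> bool:
    WHITE, GRAY, BLACK = 0, 1, 2          # unvisited / in progress / done
    color = {node: WHITE for node in deps}

    def visit(node: str) -> bool:
        color[node] = GRAY
        for dep in deps.get(node, ()):     # follow edges to dependencies
            if color.get(dep, WHITE) == GRAY:   # back-edge means a cycle
                return True
            if color.get(dep, WHITE) == WHITE and visit(dep):
                return True
        color[node] = BLACK
        return False

    return any(color[n] == WHITE and visit(n) for n in deps)


# Example: step A depends on B, B on C, C on A -> cycle.
assert detect_cycle({"A": {"B"}, "B": {"C"}, "C": {"A"}}) is True
assert detect_cycle({"A": {"B"}, "B": {"C"}, "C": set()}) is False
```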
- app.py +4 -3
- src/app_configs.py +5 -2
- src/components/model_pipeline/model_pipeline.py +3 -3
- src/components/model_pipeline/state_manager.py +4 -1
- src/components/model_pipeline/tossup_pipeline.py +171 -0
- src/components/quizbowl/tossup.py +58 -35
- src/workflows/README.md +58 -21
- src/workflows/configs.py +51 -0
- src/workflows/executors.py +254 -98
- src/workflows/factory.py +19 -5
- src/{llms.py → workflows/llms.py} +1 -1
- src/workflows/qb_agents.py +41 -26
- src/workflows/structs.py +49 -2
- src/workflows/utils.py +37 -3
- src/workflows/validators.py +62 -94
- tests/test_executors.py +233 -87
- tests/test_validators.py +173 -57
app.py
CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
 from apscheduler.schedulers.background import BackgroundScheduler
 from huggingface_hub import snapshot_download

-from app_configs import
+from app_configs import DEFAULT_SELECTIONS, THEME
 from components.quizbowl.bonus import BonusInterface
 from components.quizbowl.tossup import TossupInterface
 from display.custom_css import css_bonus, css_pipeline, css_tossup
@@ -21,6 +21,7 @@ from envs import (
     TOKEN,
 )
 from workflows import factory
+from workflows.configs import AVAILABLE_MODELS


 def restart_space():
@@ -164,12 +165,12 @@ if __name__ == "__main__":
     with gr.Tabs():
         with gr.Tab("Tossup Agents"):
             defaults = DEFAULT_SELECTIONS["tossup"] | {
-                "init_workflow": factory.
+                "init_workflow": factory.create_simple_qb_tossup_workflow(),
             }
             tossup_interface = TossupInterface(demo, tossup_ds, AVAILABLE_MODELS, defaults)
         with gr.Tab("Bonus Round Agents"):
             defaults = DEFAULT_SELECTIONS["bonus"] | {
-                "init_workflow": factory.
+                "init_workflow": factory.create_simple_qb_bonus_workflow(),
             }
             bonus_interface = BonusInterface(demo, bonus_ds, AVAILABLE_MODELS, defaults)

src/app_configs.py
CHANGED
@@ -23,12 +23,15 @@ AVAILABLE_MODELS = {
     },
     "Cohere/command-r": {
         "model": "command-r-08-2024",
+        "logprobs": True,
     },
     "Cohere/command-r-plus": {
         "model": "command-r-plus-08-2024",
+        "logprobs": True,
     },
     "Cohere/command-r7b": {
         "model": "command-r7b-12-2024",
+        "logprobs": False,
     },
 }

@@ -37,14 +40,14 @@ DEFAULT_SELECTIONS = {
         "simple_workflow": False,
         "model": "OpenAI/gpt-4o-mini",
         "temperature": 0.2,
-        "
+        "confidence_threshold": 0.85,
         "early_stop": True,
     },
     "bonus": {
         "simple_workflow": False,
         "model": "OpenAI/gpt-4o-mini",
         "temperature": 0.2,
-        "
+        "confidence_threshold": 0.85,
         "early_stop": True,
     },
 }
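The new per-model `logprobs` flags above are what downstream UI code can consult to decide whether a token-probability threshold applies to the selected model. A hedged sketch of that lookup (dictionary abridged from the config above; `supports_logprobs` is an illustrative helper, not part of the repository):

```python
# Gate log-probability-based buzzing on the model's "logprobs" flag.
AVAILABLE_MODELS = {
    "Cohere/command-r": {"model": "command-r-08-2024", "logprobs": True},
    "Cohere/command-r7b": {"model": "command-r7b-12-2024", "logprobs": False},
    "OpenAI/gpt-4o-mini": {"model": "gpt-4o-mini-2024-07-18"},
}

def supports_logprobs(model_name: str) -> bool:
    # A missing key defaults to False, matching .get("logprobs", False) in the UI code.
    return AVAILABLE_MODELS.get(model_name, {}).get("logprobs", False)

print(supports_logprobs("Cohere/command-r"))    # True
print(supports_logprobs("Cohere/command-r7b"))  # False
print(supports_logprobs("OpenAI/gpt-4o-mini"))  # False
```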
src/components/model_pipeline/model_pipeline.py
CHANGED
@@ -156,10 +156,10 @@ class PipelineInterface:
         )
         return add_step_btn

-    def
+    def _render_output_panel(self, available_variables: list[str], pipeline_state: PipelineState):
         dropdowns = {}
         variable_options = [UNSELECTED_VAR_NAME] + [v for v in available_variables if v not in self.input_variables]
-        with gr.Column(elem_classes="step-accordion"):
+        with gr.Column(elem_classes="step-accordion control-panel"):
             with gr.Row(elem_classes="output-fields-header"):
                 gr.Markdown("#### Final output variables mapping:")
             with gr.Row(elem_classes="output-fields-row"):
@@ -260,7 +260,7 @@ class PipelineInterface:
             concurrency_id="render_output_fields",
         )
         def render_output_fields(available_variables, pipeline_state):
-            self.
+            self._render_output_panel(available_variables, pipeline_state)

         export_btn = gr.Button("Export Pipeline", elem_classes="export-button")
         # components.append(export_btn)
src/components/model_pipeline/state_manager.py
CHANGED
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field

 from components import utils
 from workflows.factory import create_new_llm_step
-from workflows.structs import ModelStep, Workflow
+from workflows.structs import ModelStep, TossupWorkflow, Workflow


 def make_step_id(step_number: int):
@@ -133,6 +133,9 @@ class PipelineStateManager:
     def get_formatted_config(self, state: PipelineState, format: Literal["json", "yaml"] = "yaml"):
         """Get the full pipeline configuration."""
         config = state.workflow.model_dump(exclude_defaults=True)
+        if isinstance(state.workflow, TossupWorkflow):
+            buzzer_config = state.workflow.buzzer.model_dump(exclude_defaults=False)
+            config["buzzer"] = buzzer_config
         if format == "yaml":
             return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
         else:
src/components/model_pipeline/tossup_pipeline.py
ADDED
@@ -0,0 +1,171 @@
import gradio as gr
import numpy as np

from app_configs import AVAILABLE_MODELS, UNSELECTED_VAR_NAME
from workflows.structs import Buzzer, TossupWorkflow

from .model_pipeline import PipelineInterface, PipelineState, PipelineUIState


def toggleable_slider(
    value, minimum, maximum, step, toggle_value=False, label=None, info=None, min_width=200, scale=1
):
    with gr.Column(elem_classes="toggleable", min_width=min_width, scale=scale):
        show_label = label is not None
        checkbox = gr.Checkbox(label=label, value=toggle_value, container=False, info=info, show_label=show_label)
        slider = gr.Slider(
            minimum=minimum,
            maximum=maximum,
            value=value,
            step=step,
            label="",
            interactive=True,
            show_label=False,
            container=False,
        )
        checkbox.change(fn=lambda x: gr.update(interactive=x), inputs=[checkbox], outputs=[slider])
    return checkbox, slider


class TossupPipelineState(PipelineState):
    workflow: TossupWorkflow


class TossupPipelineInterface(PipelineInterface):
    def __init__(
        self,
        workflow: TossupWorkflow,
        ui_state: PipelineUIState | None = None,
        model_options: list[str] = None,
        simple: bool = False,
        show_pipeline_selector: bool = False,
        defaults: dict = {},
    ):
        super().__init__(workflow, ui_state, model_options, simple, show_pipeline_selector)
        self.defaults = defaults

    def update_buzzer(
        self,
        state: TossupPipelineState,
        confidence_threshold: float,
        method: str,
        tokens_prob: float | None,
    ):
        """Update the buzzer."""
        if tokens_prob and tokens_prob > 1e-5:
            log_prob_thresh = float(np.log(tokens_prob)) if tokens_prob > 0 else None
        else:
            log_prob_thresh = None
        state.workflow.buzzer = state.workflow.buzzer.model_copy(
            update={
                "method": method,
                "confidence_threshold": confidence_threshold,
                "log_prob_threshold": log_prob_thresh,
            }
        )
        Buzzer.model_validate(state.workflow.buzzer)
        return state

    def update_prob_slider(self, state: TossupPipelineState, answer_var: str, tokens_prob: float | None):
        """Update the probability slider based on the answer variable."""
        if answer_var == UNSELECTED_VAR_NAME:
            return gr.update(interactive=True)
        step_id = answer_var.split(".")[0]
        model_name = state.workflow.steps[step_id].model
        model_config = AVAILABLE_MODELS[model_name]
        is_model_with_logprobs = model_config.get("logprobs", False)
        buzzer = state.workflow.buzzer
        tokens_prob_threshold = tokens_prob if is_model_with_logprobs else None
        state = self.update_buzzer(
            state,
            confidence_threshold=buzzer.confidence_threshold,
            method=buzzer.method,
            tokens_prob=tokens_prob_threshold,
        )
        return state, gr.update(interactive=not is_model_with_logprobs)

    def _render_output_panel(self, available_variables: list[str], pipeline_state: TossupPipelineState):
        dropdowns = {}
        variable_options = [UNSELECTED_VAR_NAME] + [v for v in available_variables if v not in self.input_variables]
        with gr.Column(elem_classes="step-accordion control-panel"):
            with gr.Row(elem_classes="output-fields-header"):
                gr.Markdown("#### Final output variables mapping:")
            with gr.Row(elem_classes="output-fields-row"):
                for output_field in self.required_output_variables:
                    value = pipeline_state.workflow.outputs.get(output_field, UNSELECTED_VAR_NAME)
                    dropdown = gr.Dropdown(
                        label=output_field,
                        value=value,
                        choices=variable_options,
                        interactive=True,
                        elem_classes="output-field-variable",
                        # show_label=False,
                    )
                    dropdown.change(
                        self.sm.update_output_variables,
                        inputs=[self.pipeline_state, gr.State(output_field), dropdown],
                        outputs=[self.pipeline_state],
                    )
                    dropdowns[output_field] = dropdown
            with gr.Row(elem_classes="output-fields-header"):
                gr.Markdown("#### Buzzer settings:")
            with gr.Row(elem_classes="control-panel"):
                self.confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=self.defaults.get("confidence_threshold", 0.85),
                    step=0.01,
                    label="Confidence Threshold",
                )
                self.buzzer_method_dropdown = gr.Dropdown(
                    choices=["AND", "OR"],
                    value=self.defaults.get("buzzer_method", "AND"),
                    label="Method",
                    interactive=True,
                    min_width=80,
                    scale=0,
                )
                self.prob_slider = gr.Slider(
                    value=self.defaults.get("logits_prob", 0.0),
                    label="Probability threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.001,
                )

        def update_choices(available_variables):
            """Update the choices for the dropdowns"""
            return [
                gr.update(choices=available_variables, value=None, selected=None) for dropdown in dropdowns.values()
            ]

        self.variables_state.change(
            update_choices,
            inputs=[self.variables_state],
            outputs=list(dropdowns.values()),
        )

        gr.on(
            triggers=[
                self.confidence_slider.input,
                self.buzzer_method_dropdown.input,
                self.prob_slider.input,
            ],
            fn=self.update_buzzer,
            inputs=[
                self.pipeline_state,
                self.confidence_slider,
                self.buzzer_method_dropdown,
                self.prob_slider,
            ],
            outputs=[self.pipeline_state],
        )

        # TODO: Do Add model step change triggers as well. (Model name change triggers)
        answer_dropdown = dropdowns["answer"]
        if answer_dropdown is not None:
            answer_dropdown.change(
                self.update_prob_slider,
                inputs=[self.pipeline_state, answer_dropdown, self.prob_slider],
                outputs=[self.pipeline_state, self.prob_slider],
            )
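The `update_buzzer` method above converts the UI's token-probability slider into a log-probability threshold and stores it alongside the confidence threshold and an AND/OR combination method. As a hedged illustration of those semantics, `should_buzz` below is a stand-in for the actual `Buzzer` decision logic, whose real signature is not shown in this commit; the conversion helper mirrors `update_buzzer`.

```python
import numpy as np

def to_log_prob_threshold(tokens_prob: float | None) -> float | None:
    # Mirrors update_buzzer: values at or below ~1e-5 disable the log-prob check.
    if tokens_prob and tokens_prob > 1e-5:
        return float(np.log(tokens_prob))
    return None

def should_buzz(confidence, logprob, confidence_threshold, log_prob_threshold, method="AND"):
    # Combine the confidence check and (optional) log-probability check via AND/OR.
    checks = [confidence >= confidence_threshold]
    if log_prob_threshold is not None:
        checks.append(logprob >= log_prob_threshold)
    return all(checks) if method == "AND" else any(checks)

thresh = to_log_prob_threshold(0.6)            # about -0.51
print(should_buzz(0.9, -0.3, 0.85, thresh))    # True: both checks pass
print(should_buzz(0.9, -1.2, 0.85, thresh))    # False with AND: log-prob check fails
```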
src/components/quizbowl/tossup.py
CHANGED
@@ -8,10 +8,11 @@ from datasets import Dataset
 from loguru import logger

 from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState, PipelineUIState
+from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
 from display.formatting import styled_error
 from submission import submit
-from workflows.qb_agents import QuizBowlTossupAgent
-from workflows.structs import ModelStep,
+from workflows.qb_agents import QuizBowlTossupAgent, TossupResult
+from workflows.structs import ModelStep, TossupWorkflow

 from . import commons
 from .plotting import (
@@ -26,6 +27,13 @@ from .utils import evaluate_prediction
 # TODO: ^^ Same for Bonus


+class ScoredTossupResult(TossupResult):
+    """Result of a tossup question with evaluation score and position."""
+
+    score: int  # Correctness score of the answer
+    token_position: int  # Position in the question where prediction was made
+
+
 def add_model_scores(model_outputs: list[dict], clean_answers: list[str], run_indices: list[int]) -> list[dict]:
     """Add model scores to the model outputs."""
     for output, run_idx in zip(model_outputs, run_indices):
@@ -90,12 +98,12 @@ def process_tossup_results(results: list[dict], top_k_mode: bool = False) -> pd.
 )


-def validate_workflow(workflow:
+def validate_workflow(workflow: TossupWorkflow):
     """
     Validate that a workflow is properly configured for the tossup task.

     Args:
-        workflow (
+        workflow (TossupWorkflow): The workflow to validate

     Raises:
         ValueError: If the workflow is not properly configured
@@ -180,40 +188,36 @@ class TossupInterface:
         self.output_state = gr.State(value="{}")
         self.render()

-    def
+    def _render_pipeline_interface(self, workflow: TossupWorkflow, simple: bool = True):
         """Render the model interface."""
         with gr.Row():
             self.model_selector = commons.get_pipeline_selector([])
-        self.pipeline_interface =
+        self.pipeline_interface = TossupPipelineInterface(
             workflow,
             simple=simple,
            model_options=list(self.model_options.keys()),
+            defaults=self.defaults,
         )
-        with gr.Row():
-            self.buzz_t_slider = gr.Slider(
-                minimum=0.5,
-                maximum=1.0,
-                value=self.defaults["buzz_threshold"],
-                step=0.01,
-                label="Buzz Threshold",
-            )
-            self.early_stop_checkbox = gr.Checkbox(
-                value=self.defaults["early_stop"],
-                label="Early Stop",
-                info="Stop early if already buzzed",
-            )

     def _render_qb_interface(self):
         """Render the quizbowl interface."""
         with gr.Row(elem_classes="bonus-header-row form-inline"):
             self.qid_selector = commons.get_qid_selector(len(self.ds))
+            self.early_stop_checkbox = gr.Checkbox(
+                value=self.defaults["early_stop"],
+                label="Early Stop",
+                info="Stop if already buzzed",
+                scale=0,
+            )
             self.run_btn = gr.Button("Run on Tossup Question", variant="secondary")
         self.question_display = gr.HTML(label="Question", elem_id="tossup-question-display")
+        self.error_display = gr.HTML(label="Error", elem_id="tossup-error-display", visible=False)
         with gr.Row():
             self.confidence_plot = gr.Plot(
                 label="Buzz Confidence",
                 format="webp",
             )
+        self.model_outputs_display = gr.JSON(label="Model Outputs", value="{}", visible=False)
         self.results_table = gr.DataFrame(
             label="Model Outputs",
             value=pd.DataFrame(columns=["Token Position", "Correct?", "Confidence", "Prediction"]),
@@ -240,7 +244,7 @@ class TossupInterface:
         with gr.Row():
             # Model Panel
             with gr.Column(scale=1):
-                self.
+                self._render_pipeline_interface(workflow, simple=self.defaults["simple_workflow"])

             with gr.Column(scale=1):
                 self._render_qb_interface()
@@ -268,13 +272,15 @@ class TossupInterface:
         except Exception as e:
             return f"Error loading question: {str(e)}"

-    def get_model_outputs(
+    def get_model_outputs(
+        self, example: dict, pipeline_state: PipelineState, early_stop: bool
+    ) -> list[ScoredTossupResult]:
         """Get the model outputs for a given question ID."""
         question_runs = []
         tokens = example["question"].split()
         for run_idx in example["run_indices"]:
             question_runs.append(" ".join(tokens[: run_idx + 1]))
-        agent = QuizBowlTossupAgent(pipeline_state.workflow
+        agent = QuizBowlTossupAgent(pipeline_state.workflow)
         outputs = list(agent.run(question_runs, early_stop=early_stop))
         outputs = add_model_scores(outputs, example["clean_answers"], example["run_indices"])
         return outputs
@@ -297,7 +303,6 @@
         self,
         question_id: int,
         pipeline_state: PipelineState,
-        buzz_threshold: float,
         early_stop: bool = True,
     ) -> tuple[str, Any, Any]:
         """Run the agent in tossup mode with a system prompt."""
@@ -307,24 +312,34 @@
             if not self.ds or question_id < 0 or question_id >= len(self.ds):
                 return "Invalid question ID or dataset not loaded", None, None
             example = self.ds[question_id]
-            outputs = self.get_model_outputs(example, pipeline_state,
+            outputs = self.get_model_outputs(example, pipeline_state, early_stop)

             # Process results and prepare visualization data
             tokens_html, plot_data, output_state = initialize_eval_interface(example, outputs)
             df = process_tossup_results(outputs)
+            step_outputs = [output["step_outputs"] for output in outputs]
             return (
                 tokens_html,
-                gr.update(value=plot_data, label=f"Buzz Confidence on Question {question_id + 1}"),
                 gr.update(value=output_state),
+                gr.update(value=plot_data, label=f"Buzz Confidence on Question {question_id + 1}"),
                 gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}"),
+                gr.update(value=step_outputs, label=f"Step Outputs for Question {question_id + 1}", visible=True),
+                gr.update(visible=False),
             )
         except Exception as e:
             import traceback

-            error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
-            return
+            error_msg = styled_error(f"Error: {str(e)}\n{traceback.format_exc()}")
+            return (
+                gr.skip(),
+                gr.skip(),
+                gr.skip(),
+                gr.skip(),
+                gr.update(visible=False),
+                gr.update(visible=True, value=error_msg),
+            )

-    def evaluate(self, pipeline_state: PipelineState,
+    def evaluate(self, pipeline_state: PipelineState, progress: gr.Progress = gr.Progress()):
         """Evaluate the tossup questions."""
         try:
             # Validate inputs
@@ -336,7 +351,7 @@
             token_positions = []
             correctness = []
             for example in progress.tqdm(self.ds, desc="Evaluating tossup questions"):
-                model_outputs = self.get_model_outputs(example, pipeline_state,
+                model_outputs = self.get_model_outputs(example, pipeline_state, early_stop=True)
                 if model_outputs[-1]["buzz"]:
                     buzz_counts += 1
                     if model_outputs[-1]["score"] == 1:
@@ -355,12 +370,19 @@
             )
             plot_data = create_scatter_pyplot(token_positions, correctness)
             return (
-                gr.update(value=df, label="Scores on Sample Set"),
                 gr.update(value=plot_data, label="Buzz Positions on Sample Set"),
+                gr.update(value=df, label="Scores on Sample Set"),
+                gr.update(visible=False),
             )
         except Exception as e:
+            import traceback
+
             logger.exception(f"Error evaluating tossups: {e.args}")
-            return
+            return (
+                gr.skip(),
+                gr.skip(),
+                gr.update(visible=True, value=styled_error(f"Error: {str(e)}\n{traceback.format_exc()}")),
+            )

     def submit_model(
         self, model_name: str, description: str, pipeline_state: PipelineState, profile: gr.OAuthProfile = None
@@ -401,21 +423,22 @@
             inputs=[
                 self.qid_selector,
                 self.pipeline_interface.pipeline_state,
-                self.buzz_t_slider,
                 self.early_stop_checkbox,
             ],
             outputs=[
                 self.question_display,
-                self.confidence_plot,
                 self.output_state,
+                self.confidence_plot,
                 self.results_table,
+                self.model_outputs_display,
+                self.error_display,
             ],
         )

         self.eval_btn.click(
             fn=self.evaluate,
-            inputs=[self.pipeline_interface.pipeline_state
-            outputs=[self.results_table, self.
+            inputs=[self.pipeline_interface.pipeline_state],
+            outputs=[self.confidence_plot, self.results_table, self.error_display],
         )

         self.submit_btn.click(
src/workflows/README.md
CHANGED
@@ -12,25 +12,44 @@ The workflows subpackage enables the creation and execution of workflows where m

 Contains the core data structures used throughout the workflow system:

-- `
+- `InputField`: Represents an input field with name, description, and variable reference
+- `OutputField`: Represents an output field with name, type, and description
 - `ModelStep`: Represents a single step in a workflow with input fields, output fields, and model details
 - `Workflow`: A collection of ModelSteps with their identifiers
+- `TossupWorkflow`: Specialized workflow for quizbowl tossup questions with buzzing capability
+
+### `configs.py`
+
+Provides configuration settings and constants:
+
+- `AVAILABLE_MODELS`: Supported model configurations from various providers
+- `TYPE_MAP`: Mapping of supported field types to Python types
+- `FUNCTION_MAP`: Built-in transformation functions for input/output processing

 ### `utils.py`

 Provides utility functions for workflow operations:

-- `_create_variable_step_mapping`: Maps variables to the steps that produce them
 - `create_dependency_graph`: Builds a dependency graph representing the execution order constraints
 - `topological_sort`: Sorts steps in execution order based on their dependencies
+- `detect_cycles`: Identifies cyclic dependencies in workflow definitions

-### `
+### `executors.py`

 Handles the execution of workflows:

--
--
--
+- `execute_model_step`: Executes a single model step with input processing and output collection
+- `execute_simple_workflow`: Handles single-step workflows
+- `execute_multi_step_workflow`: Manages multi-step workflows with dependency resolution
+- `execute_workflow`: Main entry point that routes to appropriate executor based on workflow complexity
+
+### `validators.py`
+
+Provides workflow validation functionality:
+
+- `ValidationErrorType`: Enumeration of possible validation error types
+- `WorkflowValidationError`: Base class for validation errors
+- Validation functions for steps, DAGs, variables, and types

 ### `errors.py`

@@ -43,36 +62,50 @@ Defines custom exceptions for workflow-related errors:
 ## Usage Example

 ```python
-from workflows.structs import
+from workflows.structs import InputField, ModelStep, OutputField, Workflow

 # Define a workflow with two steps
 step1 = ModelStep(
-
-
-
-
+    id="step1",
+    model="gpt-4o-mini",
+    provider="OpenAI",
+    call_type="llm",
+    system_prompt="Step1 processing",
+    input_fields=[InputField(name="value", description="Input value", variable="input.value")],
+    output_fields=[OutputField(name="result", description="Processed result", type="str", func="upper")],
 )

 step2 = ModelStep(
-
-
-
-
+    id="step2",
+    model="gpt-4o-mini",
+    provider="OpenAI",
+    call_type="llm",
+    system_prompt="Step2 processing",
+    input_fields=[InputField(name="result", description="Result from step1", variable="step1.result")],
+    output_fields=[OutputField(name="final", description="Final output", type="str", func="lower")],
 )

-workflow = Workflow(
+workflow = Workflow(
+    steps={"step1": step1, "step2": step2},
+    inputs=["input.value"],
+    outputs={"final": "step2.final"}
+)

 # Execute the workflow
-from workflows.
+from workflows.executors import execute_workflow

 result = execute_workflow(
     workflow=workflow,
-    input_values={"
+    input_values={"input.value": "Hello, World!"},
+    return_full_content=True,
+    logprob_step="step2"
 )

 # Access results
-
-
+final_output = result["final_outputs"]["final"]
+intermediate_results = result["intermediate_outputs"]
+step_contents = result["step_contents"]
+logprob = result["logprob"]
 ```

 ## Error Handling
@@ -82,6 +115,8 @@ The workflows system provides robust error handling:
 - Detects cyclic dependencies in workflow definitions
 - Validates input/output variable references
 - Ensures all required inputs are provided
+- Supports custom validation rules through the validation system
+- Provides detailed error messages for debugging

 ## Extending the Workflows System

@@ -89,4 +124,6 @@ To extend the workflows system:

 1. Add new model step types by extending the `ModelStep` class
 2. Create custom field types by extending validation in the execution logic
-3. Implement additional error types in `errors.py` for specialized error handling
+3. Implement additional error types in `errors.py` for specialized error handling
+4. Add new transformation functions to `FUNCTION_MAP` in `configs.py`
+5. Create specialized workflow types by extending the `Workflow` class
src/workflows/configs.py
ADDED
@@ -0,0 +1,51 @@
"""
Configuration settings for the workflows package.

This module contains configuration settings and constants used across the workflows package,
including model configurations, workflow settings, and other package-wide constants.
"""

AVAILABLE_MODELS = {
    "OpenAI/gpt-4o": {
        "model": "gpt-4o-2024-11-20",
    },
    "OpenAI/gpt-4o-mini": {
        "model": "gpt-4o-mini-2024-07-18",
    },
    "OpenAI/gpt-3.5-turbo": {
        "model": "gpt-3.5-turbo-0125",
    },
    "Anthropic/claude-3-7-sonnet": {
        "model": "claude-3-7-sonnet-20250219",
    },
    "Anthropic/claude-3-5-sonnet": {
        "model": "claude-3-5-sonnet-20241022",
    },
    "Anthropic/claude-3-5-haiku": {
        "model": "claude-3-5-haiku-20241022",
    },
    "Cohere/command-r": {
        "model": "command-r-08-2024",
    },
    "Cohere/command-r-plus": {
        "model": "command-r-plus-08-2024",
    },
    "Cohere/command-r7b": {
        "model": "command-r7b-12-2024",
    },
}

# Function mapping for input/output transformations
TYPE_MAP = {
    "str": str,
    "int": int,
    "float": float,
    "bool": bool,
}

FUNCTION_MAP = {
    "upper": str.upper,
    "lower": str.lower,
    "len": len,
    "split": str.split,
}
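The `TYPE_MAP` and `FUNCTION_MAP` tables above are the lookup points for field type casting and named input/output transformations. A hedged sketch of how such a lookup could be applied to a field value; `cast_and_transform` is illustrative and not the executors.py implementation.

```python
# Apply a declared type and an optional named transformation to a field value.
TYPE_MAP = {"str": str, "int": int, "float": float, "bool": bool}
FUNCTION_MAP = {"upper": str.upper, "lower": str.lower, "len": len, "split": str.split}

def cast_and_transform(value, type_str: str = "str", func: str | None = None):
    value = TYPE_MAP.get(type_str, str)(value)      # cast to the declared type
    if func is not None:
        value = FUNCTION_MAP[func](value)           # apply the named transformation
    return value

print(cast_and_transform("hello, world", "str", "upper"))  # HELLO, WORLD
print(cast_and_transform("42", "int"))                     # 42
print(cast_and_transform("a b c", "str", "split"))         # ['a', 'b', 'c']
```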
src/workflows/executors.py
CHANGED
@@ -1,13 +1,3 @@
|
|
1 |
-
# %%
|
2 |
-
from typing import Any
|
3 |
-
|
4 |
-
import pydantic
|
5 |
-
|
6 |
-
from llms import completion
|
7 |
-
from workflows.errors import WorkflowError
|
8 |
-
from workflows.structs import InputField, ModelStep, OutputField, Workflow
|
9 |
-
from workflows.utils import create_dependency_graph, topological_sort
|
10 |
-
|
11 |
"""
|
12 |
Core workflow execution functionality.
|
13 |
|
@@ -18,42 +8,48 @@ with the litellm library to handle model interactions.
|
|
18 |
Key components:
|
19 |
- Utility functions for input/output transformation
|
20 |
- Input processing and validation
|
21 |
-
- Model step execution
|
22 |
- Complete workflow execution with dependency resolution
|
|
|
|
|
23 |
|
24 |
The module orchestrates the execution of steps in the correct order based on their
|
25 |
-
dependencies and manages the flow of data between steps.
|
|
|
|
|
|
|
|
|
26 |
"""
|
27 |
|
|
|
28 |
|
29 |
-
|
30 |
-
if isinstance(x, str):
|
31 |
-
return x.upper()
|
32 |
-
return x
|
33 |
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
-
def lower(x):
|
36 |
-
if isinstance(x, str):
|
37 |
-
return x.lower()
|
38 |
-
return x
|
39 |
|
|
|
|
|
|
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
"float": float,
|
45 |
-
"bool": bool,
|
46 |
-
}
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
"lower": lower,
|
51 |
-
"len": len,
|
52 |
-
"split": str.split,
|
53 |
-
}
|
54 |
|
|
|
|
|
55 |
|
56 |
-
|
|
|
|
|
|
|
57 |
return TYPE_MAP.get(type_str, eval(type_str))
|
58 |
|
59 |
|
@@ -95,10 +91,72 @@ def create_processed_inputs(model_step: ModelStep, available_vars: dict[str, Any
|
|
95 |
return processed_inputs
|
96 |
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
# %%
|
99 |
def execute_model_step(
|
100 |
-
model_step: ModelStep,
|
101 |
-
|
|
|
|
|
|
|
102 |
"""
|
103 |
Executes a model step using the provided available variables.
|
104 |
|
@@ -117,10 +175,14 @@ def execute_model_step(
|
|
117 |
input/output specifications, and system prompt.
|
118 |
available_vars (dict[str, Any]): A dictionary of all variables available to this step,
|
119 |
including outputs from previous steps and external inputs.
|
|
|
|
|
|
|
|
|
120 |
|
121 |
Returns:
|
122 |
-
|
123 |
-
|
124 |
|
125 |
Raises:
|
126 |
WorkflowError: If there's an error in input processing, model execution,
|
@@ -136,8 +198,8 @@ def execute_model_step(
|
|
136 |
... input_fields=[InputField(name="text", variable="input_text", description="Text to summarize")],
|
137 |
... output_fields=[OutputField(name="summary", type="str", description="Summary of the text")]
|
138 |
... )
|
139 |
-
>>> execute_model_step(step, {"input_text": "Long text to be summarized..."})
|
140 |
-
|
141 |
"""
|
142 |
# Ensure inputs are processed using the specified functions in input_fields.
|
143 |
processed_inputs = create_processed_inputs(model_step, available_vars)
|
@@ -159,28 +221,25 @@ def execute_model_step(
|
|
159 |
system=model_step.system_prompt,
|
160 |
prompt=step_result,
|
161 |
response_format=ModelResponse,
|
|
|
162 |
)
|
163 |
-
|
164 |
-
# model=model_step.model,
|
165 |
-
# messages=[{"role": "user", "content": step_result}],
|
166 |
-
# response_format=ModelResponse,
|
167 |
-
# )
|
168 |
-
|
169 |
-
# Extract and parse the model response
|
170 |
-
# model_response_content = api_response["choices"][0]["message"]["content"]
|
171 |
-
# model_response = json.loads(model_response_content)
|
172 |
-
model_response = api_response["output"]
|
173 |
# Map the parsed response to the output fields
|
174 |
-
outputs = {field.name:
|
|
|
175 |
if return_full_content:
|
176 |
-
|
177 |
-
|
|
|
|
|
178 |
|
179 |
|
180 |
-
# %%
|
181 |
def execute_multi_step_workflow(
|
182 |
-
workflow: Workflow,
|
183 |
-
|
|
|
|
|
|
|
184 |
"""
|
185 |
Execute the given workflow as a computational graph.
|
186 |
|
@@ -203,12 +262,11 @@ def execute_multi_step_workflow(
|
|
203 |
Keys should match the required workflow.inputs.
|
204 |
return_full_content (bool, optional): If True, returns the full content of each step.
|
205 |
Defaults to False.
|
|
|
|
|
206 |
|
207 |
Returns:
|
208 |
-
A
|
209 |
-
- A dictionary of the workflow's outputs, with keys matching the variables defined in workflow.outputs.
|
210 |
-
- A dictionary of all computed values during workflow execution, including intermediate results.
|
211 |
-
- A dictionary of step contents, only populated if return_full_content is True.
|
212 |
|
213 |
Raises:
|
214 |
UnknownVariableError: If an input_field references a variable that is not
|
@@ -252,49 +310,75 @@ def execute_multi_step_workflow(
|
|
252 |
|
253 |
# Step 4: Execute steps in topological order.
|
254 |
step_contents: dict[str, Any] = {}
|
|
|
255 |
for step_id in execution_order:
|
256 |
step = workflow.steps[step_id]
|
257 |
-
|
258 |
-
outputs = execute_model_step(step, computed_values, return_full_content=return_full_content)
|
259 |
# Execute the step
|
|
|
|
|
|
|
|
|
|
|
260 |
if return_full_content:
|
261 |
-
|
262 |
-
|
263 |
-
outputs = {f"{step_id}.{k}": v for k, v in outputs.items()}
|
264 |
computed_values.update(outputs)
|
265 |
|
266 |
# Step 5: Gather and return workflow outputs.
|
267 |
final_outputs: dict[str, Any] = {}
|
268 |
for target, var in workflow.outputs.items():
|
269 |
if var not in computed_values:
|
270 |
-
raise WorkflowError(
|
|
|
|
|
271 |
final_outputs[target] = computed_values[var]
|
272 |
|
273 |
-
|
274 |
-
|
275 |
-
|
|
|
|
|
|
|
276 |
|
277 |
|
278 |
def execute_simple_workflow(
|
279 |
-
workflow: Workflow,
|
280 |
-
|
281 |
-
|
|
|
|
|
|
|
|
|
282 |
|
283 |
-
This is
|
|
|
|
|
284 |
|
285 |
Args:
|
286 |
-
workflow: The workflow to execute
|
287 |
-
input_values:
|
288 |
-
|
|
|
|
|
|
|
|
|
289 |
|
290 |
Returns:
|
291 |
-
|
292 |
-
|
293 |
-
- computed_values: Dictionary of all computed values
|
294 |
-
- step_contents: Dictionary of step contents (if return_full_content=True)
|
295 |
|
296 |
Raises:
|
297 |
-
WorkflowError: If the workflow has more than one step
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
"""
|
299 |
if len(workflow.steps) != 1:
|
300 |
raise WorkflowError("Simple workflow must have exactly one step")
|
@@ -302,19 +386,17 @@ def execute_simple_workflow(
|
|
302 |
# Get the single step
|
303 |
step = list(workflow.steps.values())[0]
|
304 |
|
|
|
|
|
305 |
# Validate inputs
|
306 |
for var in workflow.inputs:
|
307 |
if var not in input_values:
|
308 |
raise WorkflowError(f"Missing required workflow input: {var}")
|
309 |
|
310 |
# Execute the step
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
else:
|
315 |
-
step_outputs = execute_model_step(step, input_values, return_full_content=False)
|
316 |
-
step_contents = {}
|
317 |
-
|
318 |
# Prepare the final outputs
|
319 |
final_outputs = {}
|
320 |
for target, var in workflow.outputs.items():
|
@@ -328,27 +410,101 @@ def execute_simple_workflow(
|
|
328 |
raise WorkflowError(f"Invalid output mapping: {var} does not match step ID {step.id}")
|
329 |
|
330 |
# Prepare computed values (prefixed with step ID)
|
331 |
-
computed_values = input_values.
|
332 |
-
computed_values.update({f"{step.id}.{k}": v for k, v in step_outputs.items()})
|
333 |
|
334 |
-
return
|
|
|
|
|
|
|
|
|
|
|
335 |
|
336 |
|
337 |
def execute_workflow(
|
338 |
-
workflow: Workflow,
|
339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
340 |
if len(workflow.steps) > 1:
|
341 |
-
return execute_multi_step_workflow(workflow, input_values, return_full_content)
|
342 |
else:
|
343 |
-
return execute_simple_workflow(workflow, input_values, return_full_content)
|
344 |
|
345 |
|
346 |
def run_examples():
|
347 |
"""
|
348 |
-
Runs
|
349 |
-
|
350 |
-
|
351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
352 |
"""
|
353 |
print("Example 1: Successful Workflow Execution")
|
354 |
# Example 1: Simple linear workflow.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
"""
|
2 |
Core workflow execution functionality.
|
3 |
|
|
|
8 |
Key components:
|
9 |
- Utility functions for input/output transformation
|
10 |
- Input processing and validation
|
11 |
+
- Model step execution with support for log probabilities
|
12 |
- Complete workflow execution with dependency resolution
|
13 |
+
- Support for both simple (single-step) and multi-step workflows
|
14 |
+
- Structured output collection with intermediate results
|
15 |
|
16 |
The module orchestrates the execution of steps in the correct order based on their
|
17 |
+
dependencies and manages the flow of data between steps. It supports:
|
18 |
+
- Full content tracking for debugging
|
19 |
+
- Log probability calculation for specific steps
|
20 |
+
- Flexible input/output transformations
|
21 |
+
- Error handling and validation
|
22 |
"""
|
23 |
|
24 |
+
from typing import Any, TypedDict
|
25 |
|
26 |
+
import pydantic
|
|
|
|
|
|
|
27 |
|
28 |
+
from .configs import FUNCTION_MAP, TYPE_MAP
|
29 |
+
from .errors import WorkflowError
|
30 |
+
from .llms import completion
|
31 |
+
from .structs import InputField, ModelStep, OutputField, Workflow
|
32 |
+
from .utils import create_dependency_graph, topological_sort
|
33 |
|
|
|
|
|
|
|
|
|
34 |
|
35 |
+
def get_type(type_str: str) -> type:
|
36 |
+
"""
|
37 |
+
Converts a type string to its corresponding Python type.
|
38 |
|
39 |
+
This function maps type strings to their actual Python type objects. It first checks
|
40 |
+
the TYPE_MAP dictionary for predefined mappings, and if not found, falls back to
|
41 |
+
evaluating the type string directly.
|
|
|
|
|
|
|
42 |
|
43 |
+
Args:
|
44 |
+
type_str (str): A string representation of a type (e.g., "str", "int", "list[str]")
|
|
|
|
|
|
|
|
|
45 |
|
46 |
+
Returns:
|
47 |
+
type: The corresponding Python type object
|
48 |
|
49 |
+
Note:
|
50 |
+
Uses eval() for non-predefined types, which has security implications if used
|
51 |
+
with untrusted input. This is intended for internal use with validated type strings.
|
52 |
+
"""
|
53 |
return TYPE_MAP.get(type_str, eval(type_str))
|
54 |
|
55 |
|
|
|
91 |
return processed_inputs
|
92 |
|
93 |
|
94 |
+
class ModelStepResult(TypedDict):
|
95 |
+
"""
|
96 |
+
Result of executing a model step.
|
97 |
+
|
98 |
+
This TypedDict contains the outputs and metadata from executing a single model step,
|
99 |
+
including the processed output values, the full response content, and log probability
|
100 |
+
information when requested.
|
101 |
+
|
102 |
+
Attributes:
|
103 |
+
outputs (dict[str, Any]): A dictionary of processed outputs from the model step,
|
104 |
+
with keys matching the output field names.
|
105 |
+
content (str | None): The full content of the model's response, only populated
|
106 |
+
if return_full_content is True.
|
107 |
+
logprob (float | None): The log probability of the model step output, only populated
|
108 |
+
if logprobs is True.
|
109 |
+
"""
|
110 |
+
|
111 |
+
# A dictionary of processed outputs from the model step,
|
112 |
+
# with keys matching the output field names.
|
113 |
+
outputs: dict[str, Any]
|
114 |
+
|
115 |
+
# The full content of the model step.
|
116 |
+
content: str | None
|
117 |
+
|
118 |
+
# The log probability of the model step output if requested.
|
119 |
+
logprob: float | None
|
120 |
+
|
121 |
+
|
122 |
+
class WorkflowOutput(TypedDict):
|
123 |
+
"""
|
124 |
+
Result of executing a complete workflow.
|
125 |
+
|
126 |
+
This TypedDict contains the outputs and metadata from executing a workflow,
|
127 |
+
including final outputs, intermediate values, step contents, and log probabilities.
|
128 |
+
|
129 |
+
Attributes:
|
130 |
+
final_outputs (dict[str, Any]): The final output values produced by the workflow,
|
131 |
+
with keys matching the names defined in workflow.outputs.
|
132 |
+
intermediate_outputs (dict[str, Any]): All computed values during workflow execution,
|
133 |
+
including both external inputs and outputs from all steps.
|
134 |
+
step_contents (dict[str, Any]): Full response content for each step, keyed by step ID.
|
135 |
+
Only populated if return_full_content is True.
|
136 |
+
logprob (float | None): The log probability of the specified step's output.
|
137 |
+
Only populated if logprob_step is specified.
|
138 |
+
"""
|
139 |
+
|
140 |
+
# A dictionary of the workflow's outputs, with keys matching the variables defined in workflow.outputs.
|
141 |
+
final_outputs: dict[str, Any]
|
142 |
+
|
143 |
+
# A dictionary of all computed values during workflow execution, including intermediate results.
|
144 |
+
intermediate_outputs: dict[str, Any]
|
145 |
+
|
146 |
+
# A dictionary of step contents, only populated if return_full_content is True.
|
147 |
+
step_contents: dict[str, Any]
|
148 |
+
|
149 |
+
# The log probability of the workflow output if requested.
|
150 |
+
logprob: float | None
|
151 |
+
|
152 |
+
|
153 |
# %%
|
154 |
def execute_model_step(
|
155 |
+
model_step: ModelStep,
|
156 |
+
available_vars: dict[str, Any],
|
157 |
+
return_full_content: bool = False,
|
158 |
+
logprobs: bool = False,
|
159 |
+
) -> ModelStepResult:
|
160 |
"""
|
161 |
Executes a model step using the provided available variables.
|
162 |
|
|
|
175 |
input/output specifications, and system prompt.
|
176 |
available_vars (dict[str, Any]): A dictionary of all variables available to this step,
|
177 |
including outputs from previous steps and external inputs.
|
178 |
+
return_full_content (bool, optional): If True, includes the full model response content
|
179 |
+
in the result. Defaults to False.
|
180 |
+
logprobs (bool, optional): If True, calculates and returns log probability information
|
181 |
+
for the model response. Defaults to False.
|
182 |
|
183 |
Returns:
|
184 |
+
ModelStepResult: A TypedDict containing processed outputs, optional full content,
|
185 |
+
and optional log probability information.
|
186 |
|
187 |
Raises:
|
188 |
WorkflowError: If there's an error in input processing, model execution,
|
|
|
198 |
... input_fields=[InputField(name="text", variable="input_text", description="Text to summarize")],
|
199 |
... output_fields=[OutputField(name="summary", type="str", description="Summary of the text")]
|
200 |
... )
|
201 |
+
>>> result = execute_model_step(step, {"input_text": "Long text to be summarized..."})
|
202 |
+
>>> summary = result["outputs"]["summary"]
|
203 |
"""
|
204 |
# Ensure inputs are processed using the specified functions in input_fields.
|
205 |
processed_inputs = create_processed_inputs(model_step, available_vars)
|
|
|
221 |
system=model_step.system_prompt,
|
222 |
prompt=step_result,
|
223 |
response_format=ModelResponse,
|
224 |
+
logprobs=logprobs,
|
225 |
)
|
226 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
# Map the parsed response to the output fields
|
228 |
+
outputs = {field.name: api_response["output"][field.name] for field in model_step.output_fields}
|
229 |
+
result = ModelStepResult(outputs=outputs, content=None, logprob=None)
|
230 |
if return_full_content:
|
231 |
+
result["content"] = api_response["content"]
|
232 |
+
if logprobs:
|
233 |
+
result["logprob"] = api_response["log_prob"]
|
234 |
+
return result
|
235 |
|
236 |
|
|
|
 def execute_multi_step_workflow(
+    workflow: Workflow,
+    input_values: dict[str, Any],
+    return_full_content: bool = False,
+    logprob_step: str | None = None,
+) -> WorkflowOutput:
     """
     Execute the given workflow as a computational graph.

             Keys should match the required workflow.inputs.
         return_full_content (bool, optional): If True, returns the full content of each step.
             Defaults to False.
+        logprob_step (str, optional): The ID of the step to use for log probability calculation.
+            Defaults to None.

     Returns:
+        WorkflowOutput: A dictionary of workflow outputs, including final outputs, intermediate outputs, and step contents.

     Raises:
         UnknownVariableError: If an input_field references a variable that is not

     # Step 4: Execute steps in topological order.
     step_contents: dict[str, Any] = {}
+    logprob = None
     for step_id in execution_order:
         step = workflow.steps[step_id]
+        return_logprobs = logprob_step == step_id
         # Execute the step
+        result = execute_model_step(
+            step, computed_values, return_full_content=return_full_content, logprobs=return_logprobs
+        )
+        if return_logprobs:
+            logprob = result["logprob"]
         if return_full_content:
+            step_contents[step_id] = result["content"]
+        outputs = {f"{step_id}.{k}": v for k, v in result["outputs"].items()}
         computed_values.update(outputs)

     # Step 5: Gather and return workflow outputs.
     final_outputs: dict[str, Any] = {}
     for target, var in workflow.outputs.items():
         if var not in computed_values:
+            raise WorkflowError(
+                f"Workflow output variable {var} was not produced. Computed values: {computed_values.keys()}"
+            )
         final_outputs[target] = computed_values[var]

+    return WorkflowOutput(
+        final_outputs=final_outputs,
+        intermediate_outputs=computed_values,
+        step_contents=step_contents,
+        logprob=logprob,
+    )


 def execute_simple_workflow(
+    workflow: Workflow,
+    input_values: dict[str, Any],
+    return_full_content: bool = False,
+    logprob_step: bool | str = False,
+) -> WorkflowOutput:
+    """
+    Execute a simple workflow with a single step.

+    This is an optimized version of workflow execution for workflows containing only one step.
+    It bypasses the dependency graph building and topological sorting steps, providing a more
+    direct execution path for simple workflows.

     Args:
+        workflow (Workflow): The workflow to execute, which must contain exactly one step.
+        input_values (dict[str, Any]): External input values to be used by the workflow.
+            Keys should match the required workflow.inputs.
+        return_full_content (bool, optional): If True, includes the full model response content
+            in the result. Defaults to False.
+        logprobs (bool, optional): If True, calculates and returns log probability information
+            for the model response. Defaults to False.

     Returns:
+        WorkflowOutput: A TypedDict containing the workflow outputs, intermediate values,
+            optional step contents, and optional log probability information.

     Raises:
+        WorkflowError: If the workflow has more than one step or if required inputs are missing.
+
+    Example:
+        >>> workflow = Workflow(
+        ...     steps={"extract": ModelStep(...)},
+        ...     inputs=["text"],
+        ...     outputs={"entities": "extract.entities"}
+        ... )
+        >>> result = execute_simple_workflow(workflow, {"text": "Apple is launching a new product."})
+        >>> entities = result["final_outputs"]["entities"]
     """
     if len(workflow.steps) != 1:
         raise WorkflowError("Simple workflow must have exactly one step")

     # Get the single step
     step = list(workflow.steps.values())[0]

+    logprobs = logprob_step is True or logprob_step == step.id
+
     # Validate inputs
     for var in workflow.inputs:
         if var not in input_values:
             raise WorkflowError(f"Missing required workflow input: {var}")

     # Execute the step
+    step_result = execute_model_step(step, input_values, return_full_content=return_full_content, logprobs=logprobs)
+    step_outputs = step_result["outputs"]
+    step_contents = {step.id: step_result["content"]} if return_full_content else {}

     # Prepare the final outputs
     final_outputs = {}
     for target, var in workflow.outputs.items():

             raise WorkflowError(f"Invalid output mapping: {var} does not match step ID {step.id}")

     # Prepare computed values (prefixed with step ID)
+    computed_values = input_values | {f"{step.id}.{k}": v for k, v in step_outputs.items()}

+    return WorkflowOutput(
+        final_outputs=final_outputs,
+        intermediate_outputs=computed_values,
+        step_contents=step_contents,
+        logprob=step_result.get("logprob"),
+    )


 def execute_workflow(
+    workflow: Workflow,
+    input_values: dict[str, Any],
+    return_full_content: bool = False,
+    logprob_step: str | bool = False,
+) -> WorkflowOutput:
+    """
+    Main entry point for executing workflows of any complexity.
+
+    This function serves as a router that delegates to the appropriate specialized
+    execution function based on the complexity of the workflow:
+    - For single-step workflows, it calls execute_simple_workflow
+    - For multi-step workflows, it calls execute_multi_step_workflow
+
+    This abstraction allows callers to use a consistent interface regardless of
+    the workflow's complexity.
+
+    Args:
+        workflow (Workflow): The workflow to execute, containing steps, their
+            dependencies, and input/output specifications.
+        input_values (dict[str, Any]): External input values to be used by the workflow.
+            Keys should match the required workflow.inputs.
+        return_full_content (bool, optional): If True, includes the full model response
+            content in the result. Defaults to False.
+        logprob_step (str | bool, optional): Either a string with the ID of the step for which
+            to calculate log probability, or a boolean flag.
+            If False, no log probabilities are calculated.
+            Defaults to False.
+
+    Returns:
+        WorkflowOutput: A TypedDict containing the workflow outputs, intermediate values,
+            optional step contents, and optional log probability information.
+
+    Raises:
+        WorkflowError: For any workflow-related errors, such as missing required inputs,
+            circular dependencies, or invalid variable references.
+
+    Example:
+        >>> workflow = Workflow(
+        ...     steps={"extract": ModelStep(...), "analyze": ModelStep(...)},
+        ...     inputs=["text"],
+        ...     outputs={"sentiment": "analyze.sentiment"}
+        ... )
+        >>> result = execute_workflow(
+        ...     workflow,
+        ...     {"text": "Apple is launching a new product."},
+        ...     return_full_content=True,
+        ...     logprob_step="analyze"
+        ... )
+        >>> print(result["final_outputs"]["sentiment"])
+        "positive"
+    """
     if len(workflow.steps) > 1:
+        return execute_multi_step_workflow(workflow, input_values, return_full_content, logprob_step)
     else:
+        return execute_simple_workflow(workflow, input_values, return_full_content, logprob_step)


 def run_examples():
     """
+    Runs example workflows demonstrating key functionality and error handling.
+
+    This function creates and executes three different example workflows to showcase:
+
+    1. Successful workflow execution:
+       - A linear two-step workflow with proper dependency flow
+       - Input transformation using the 'upper' function
+       - Output transformation using the 'lower' function
+       - Proper variable passing between steps
+
+    2. Cyclic dependency detection:
+       - A workflow with two steps that depend on each other circularly
+       - Demonstrates the error handling for cyclic dependencies
+       - Shows how the system prevents infinite execution loops
+
+    3. Unknown variable detection:
+       - A workflow that references a variable not provided as input or by any step
+       - Demonstrates validation of variable references
+       - Shows error handling for missing dependencies
+
+    Each example prints its result or the error encountered, making this function
+    useful for testing and demonstration purposes.
+
+    Returns:
+        None: This function prints its results and doesn't return a value.
     """
     print("Example 1: Successful Workflow Execution")
     # Example 1: Simple linear workflow.
src/workflows/factory.py
CHANGED
@@ -1,5 +1,14 @@
 # %%
-from
+from .structs import (
+    Buzzer,
+    BuzzerMethod,
+    CallType,
+    InputField,
+    ModelStep,
+    OutputField,
+    TossupWorkflow,
+    Workflow,
+)
 
 INITIAL_SYS_PROMPT = """You are a helpful performant question answering bot.
 Given a question clue, output your most likely guess in a couple words with a calibrated confidence for the guess.
@@ -71,8 +80,8 @@ def create_first_llm_step() -> ModelStep:
     )
 
 
-def
-    return
+def create_simple_qb_tossup_workflow():
+    return TossupWorkflow(
         inputs=["question_text"],
         outputs={"answer": "A.answer", "confidence": "A.confidence"},
         steps={
@@ -99,6 +108,11 @@ def create_quizbowl_simple_workflow():
                 ],
             )
         },
+        buzzer=Buzzer(
+            confidence_threshold=0.75,
+            prob_threshold=None,
+            method=BuzzerMethod.AND,
+        ),
     )
 
 
@@ -114,7 +128,7 @@ CONFIDENCE: <0-1>
 EXPLANATION: <your reasoning>"""
 
 
-def
+def create_simple_qb_bonus_workflow() -> Workflow:
     """Create a simple model step for bonus questions."""
     return Workflow(
         inputs=["leadin", "part"],
@@ -126,7 +140,7 @@ def create_quizbowl_bonus_simple_workflow() -> Workflow:
             model="gpt-4o-mini",
            provider="OpenAI",
            temperature=0.3,
-            call_type=
+            call_type=CallType.LLM,
            system_prompt=BONUS_SYS_PROMPT,
            input_fields=[
                InputField(
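For context, a small sketch of how the renamed factory entry point and its new buzzer default could be exercised; the confidence values passed to run() are illustrative:

from workflows import factory

workflow = factory.create_simple_qb_tossup_workflow()
print(workflow.buzzer.confidence_threshold)   # 0.75, per the default above
print(workflow.buzzer.run(confidence=0.80))   # True: 0.80 >= 0.75 and no prob_threshold is set
print(workflow.buzzer.run(confidence=0.60))   # False: below the confidence threshold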
src/{llms.py → workflows/llms.py}
RENAMED
@@ -13,7 +13,7 @@ from openai import OpenAI
 from pydantic import BaseModel, Field
 from rich import print as rprint
 
-from
+from .configs import AVAILABLE_MODELS
 
 
 def _openai_is_json_mode_supported(model_name: str) -> bool:
src/workflows/qb_agents.py
CHANGED
@@ -1,20 +1,36 @@
 import time
-from typing import Any, Iterable
+from typing import Any, Iterable, TypedDict
 
-from
-from
+from .executors import WorkflowOutput, execute_workflow
+from .structs import TossupWorkflow, Workflow
 
 
-def _get_workflow_response(
-    workflow: Workflow, available_vars: dict[str, Any]
-) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any], float]:
+def _get_workflow_response(workflow: Workflow, available_vars: dict[str, Any]) -> tuple[WorkflowOutput, float]:
     """Get response from executing a complete workflow."""
     start_time = time.time()
-        workflow, available_vars, return_full_content=True
-    )
+    workflow_output = execute_workflow(workflow, available_vars, return_full_content=True)
     response_time = time.time() - start_time
-    return
+    return workflow_output, response_time
+
+
+class TossupResult(TypedDict):
+    answer: str
+    confidence: float
+    buzz: bool
+    question_fragment: str
+    position: int
+    step_contents: list[str]
+    response_time: float
+    step_outputs: dict[str, Any]
+
+
+class BonusResult(TypedDict):
+    answer: str
+    confidence: float
+    explanation: str
+    response_time: float
+    step_contents: list[str]
+    step_outputs: dict[str, Any]
 
 
 class QuizBowlTossupAgent:
@@ -23,7 +39,7 @@ class QuizBowlTossupAgent:
     external_input_variable = "question_text"
     output_variables = ["answer", "confidence"]
 
-    def __init__(self, workflow:
+    def __init__(self, workflow: TossupWorkflow):
         """Initialize the multi-step tossup agent.
 
         Args:
@@ -31,7 +47,6 @@ class QuizBowlTossupAgent:
             buzz_threshold: Confidence threshold for buzzing
         """
         self.workflow = workflow
-        self.buzz_threshold = buzz_threshold
         self.output_variables = list(workflow.outputs.keys())
 
         # Validate input variables
@@ -43,7 +58,7 @@ class QuizBowlTossupAgent:
         if out_var not in workflow.outputs:
             raise ValueError(f"Output variable {out_var} not found in workflow outputs")
 
-    def run(self, question_runs: list[str], early_stop: bool = True) -> Iterable[
+    def run(self, question_runs: list[str], early_stop: bool = True) -> Iterable[TossupResult]:
         """Process a tossup question and decide when to buzz based on confidence.
 
         Args:
@@ -63,26 +78,26 @@ class QuizBowlTossupAgent:
         """
         for i, question_text in enumerate(question_runs):
             # Execute the complete workflow
-
+            workflow_output, response_time = _get_workflow_response(
                 self.workflow, {self.external_input_variable: question_text}
             )
-            buzz = final_outputs["confidence"]
-            result = {
+            final_outputs = workflow_output["final_outputs"]
+            buzz = self.workflow.buzzer.run(final_outputs["confidence"], logprob=final_outputs.get("logprob"))
+            result: TossupResult = {
+                "position": i + 1,
                 "answer": final_outputs["answer"],
                 "confidence": final_outputs["confidence"],
                 "buzz": buzz,
                 "question_fragment": question_text,
+                "step_contents": workflow_output["step_contents"],
+                "step_outputs": workflow_output["intermediate_outputs"],  # Include intermediate step outputs
                 "response_time": response_time,
-                "step_outputs": computed_values,  # Include intermediate step outputs
             }
 
             yield result
 
             # If we've reached the confidence threshold, buzz and stop
-            if early_stop and buzz:
+            if early_stop and result["buzz"]:
                 return
 
 
@@ -111,7 +126,7 @@ class QuizBowlBonusAgent:
         if out_var not in workflow.outputs:
             raise ValueError(f"Output variable {out_var} not found in workflow outputs")
 
-    def run(self, leadin: str, part: str) ->
+    def run(self, leadin: str, part: str) -> BonusResult:
         """Process a bonus part with the given leadin.
 
         Args:
@@ -127,21 +142,21 @@ class QuizBowlBonusAgent:
             - response_time: Time taken for response
             - step_outputs: Outputs from each step
         """
-
+        workflow_output, response_time = _get_workflow_response(
             self.workflow,
             {
                 "leadin": leadin,
                 "part": part,
            },
         )
-
+        final_outputs = workflow_output["final_outputs"]
         return {
             "answer": final_outputs["answer"],
             "confidence": final_outputs["confidence"],
             "explanation": final_outputs["explanation"],
-            "step_contents": step_contents,
+            "step_contents": workflow_output["step_contents"],
             "response_time": response_time,
-            "step_outputs":
+            "step_outputs": workflow_output["intermediate_outputs"],  # Include intermediate step outputs
         }

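A hedged sketch of driving the updated tossup agent: the buzz decision now comes from the workflow's own Buzzer rather than a buzz_threshold argument. The question fragments below are made up for illustration:

from workflows.factory import create_simple_qb_tossup_workflow
from workflows.qb_agents import QuizBowlTossupAgent

agent = QuizBowlTossupAgent(create_simple_qb_tossup_workflow())
question_runs = [
    "This play's protagonist sees his father's ghost on the battlements of Elsinore",
    "This play's protagonist sees his father's ghost on the battlements of Elsinore. Name this Shakespeare tragedy.",
]
for result in agent.run(question_runs, early_stop=True):
    # Each yielded TossupResult carries the answer, confidence, and buzz decision per run.
    print(result["position"], result["answer"], result["confidence"], result["buzz"])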
src/workflows/structs.py
CHANGED
@@ -1,6 +1,8 @@
 # %%
+from enum import Enum
 from typing import Any, Literal, Optional
 
+import numpy as np
 from pydantic import BaseModel, Field, model_validator
 
 """
@@ -20,6 +22,7 @@ All classes use Pydantic's BaseModel for validation and serialization support.
 """
 FieldType = Literal["input", "output"]
 
+
 SUPPORTED_TYPES = Literal["str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"]
 """Supported field types for input and output fields"""
 
@@ -68,6 +71,12 @@ class OutputField(BaseModel):
     func: str | None = None
 
 
+class CallType(str, Enum):
+    LLM = "llm"
+    SEARCH = "search"
+    PYTHON_FUNC = "python_func"
+
+
 class ModelStep(BaseModel):
     """
     Represents a single step in a workflow.
@@ -89,7 +98,7 @@ class ModelStep(BaseModel):
     name: str
     model: str
     provider: str
-    call_type:
+    call_type: CallType = CallType.LLM
 
     # TODO: Validate that this is not None for call_type = llm
     temperature: Optional[float] = None
@@ -231,4 +240,42 @@ class Workflow(BaseModel):
         return list(variables)
 
 
+class BuzzerMethod(str, Enum):
+    AND = "AND"
+    OR = "OR"
+
+
+class Buzzer(BaseModel):
+    """Configuration for when to buzz in a tossup question."""
+
+    method: BuzzerMethod = BuzzerMethod.AND  # Logic to combine thresholds
+    confidence_threshold: float = Field(default=0.8, ge=0.0, le=1.0)  # Minimum confidence to trigger a buzz
+    prob_threshold: float | None = None  # Optional log probability threshold
+
+    def run(self, confidence: float, prob: float | None = None, logprob: float | None = None) -> bool:
+        """Run the buzzer logic."""
+        if logprob is not None and prob is not None:
+            raise ValueError("Cannot provide both logprob and prob")
+        if logprob is not None:
+            prob = np.exp(logprob)
+        if self.prob_threshold is None:
+            return confidence >= self.confidence_threshold
+        if self.method == BuzzerMethod.AND:
+            return confidence >= self.confidence_threshold and prob >= self.prob_threshold
+        elif self.method == BuzzerMethod.OR:
+            return confidence >= self.confidence_threshold or prob >= self.prob_threshold
+        else:
+            raise ValueError(f"Invalid buzzer method: {self.method}")
+
+    @model_validator(mode="after")
+    def validate_method_with_log_prob(cls, data):
+        """Validate that if prob_threshold is None, method must be 'and'."""
+        if data.prob_threshold is None and data.method != BuzzerMethod.AND:
+            raise ValueError("If prob_threshold is None, method must be 'and'")
+        return data
+
+
+class TossupWorkflow(Workflow):
+    """Workflow specialized for tossup questions with buzzing capability."""
+
+    buzzer: Buzzer
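To make the threshold combination concrete, a small sketch of Buzzer.run with both a confidence and a log-probability signal; the numbers are illustrative:

import numpy as np
from workflows.structs import Buzzer, BuzzerMethod

buzzer = Buzzer(method=BuzzerMethod.AND, confidence_threshold=0.75, prob_threshold=0.5)
print(buzzer.run(confidence=0.9, logprob=np.log(0.6)))  # True: 0.9 >= 0.75 and exp(logprob) = 0.6 >= 0.5
print(buzzer.run(confidence=0.9, logprob=np.log(0.4)))  # False under AND: 0.4 < 0.5
print(Buzzer(confidence_threshold=0.75).run(confidence=0.8))  # True: no prob_threshold, confidence alone decides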
src/workflows/utils.py
CHANGED
@@ -1,8 +1,8 @@
 from collections import deque
-from typing import Any
+from typing import Any, Iterable
 
-from
-from
+from .errors import CyclicDependencyError, UnknownVariableError, WorkflowError
+from .structs import Workflow
 
 """
 Utilities for workflow dependency management and execution order determination.
@@ -98,6 +98,40 @@ def create_dependency_graph(workflow: Workflow, input_values: dict[str, Any]) ->
     return dependencies
 
 
+def detect_cycles(dep_graph: dict[str, Iterable[str]]) -> str | None:
+    """Detects cycles in the dependency graph.
+    Args:
+        dep_graph: A dictionary where the keys are node IDs and the values are the dependent node IDs
+    Returns:
+        The first step id of a model_step that is part of a cycle, None if no cycles are found
+    """
+    # Check for cycles in step dependencies
+    visited = set()
+    path = set()
+
+    def has_cycle(node: str) -> bool:
+        if node in path:
+            return True
+        if node in visited:
+            return False
+
+        visited.add(node)
+        path.add(node)
+
+        for neighbor in dep_graph.get(node, set()):
+            if has_cycle(neighbor):
+                return True
+
+        path.remove(node)
+        return False
+
+    # Check each step for cycles
+    for node_id in dep_graph:
+        if has_cycle(node_id):
+            return node_id
+    return None
+
+
 def topological_sort(dependencies: dict[str, set[str]]) -> list[str]:
     """
     Performs a topological sort on a dependency graph and detects cycles using Kahn's algorithm.
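A small sketch of the extracted detect_cycles helper on toy graphs; node names are illustrative. It returns the first node found on a cycle, or None:

from workflows.utils import detect_cycles

dag = {"analyze": {"extract"}, "extract": set()}
cyclic = {"a": {"b"}, "b": {"a"}}
print(detect_cycles(dag))     # None: the graph is acyclic
print(detect_cycles(cyclic))  # "a": the first node discovered on the a <-> b cycle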
src/workflows/validators.py
CHANGED
@@ -2,9 +2,10 @@ import keyword
 import re
 from dataclasses import dataclass
 from enum import Enum
-from typing import
+from typing import Optional
 
-from .structs import InputField, ModelStep, OutputField, Workflow
+from .structs import CallType, InputField, ModelStep, OutputField, Workflow
+from .utils import detect_cycles
 
 SUPPORTED_TYPES = {"str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"}
 
@@ -13,7 +14,7 @@ MAX_FIELD_NAME_LENGTH = 50
 MAX_DESCRIPTION_LENGTH = 200
 MAX_SYSTEM_PROMPT_LENGTH = 4000
 MIN_TEMPERATURE = 0.0
-MAX_TEMPERATURE =
+MAX_TEMPERATURE = 10.0
 
 
 class ValidationErrorType(Enum):
@@ -42,16 +43,42 @@ class ValidationError:
 class WorkflowValidationError(Exception):
     """Base class for workflow validation errors"""
 
-    def __init__(self, errors:
+    def __init__(self, errors: list[ValidationError]):
         self.errors = errors
         super().__init__(f"Workflow validation failed with {len(errors)} errors")
 
 
+def _parse_variable_reference(var: str) -> tuple[Optional[str], str]:
+    """Extracts step_id and field_name from variable reference"""
+    parts = var.split(".")
+    if len(parts) == 1:
+        return None, parts[0]
+    return parts[0], parts[1]
+
+
+def _get_step_dependencies(step: ModelStep) -> set[str]:
+    """Gets set of step IDs that this step depends on"""
+    deps = set()
+    for field in step.input_fields:
+        step_id, _ = _parse_variable_reference(field.variable)
+        if step_id:
+            deps.add(step_id)
+    return deps
+
+
+def create_step_dep_graph(workflow: Workflow) -> dict[str, set[str]]:
+    """Creates a dependency graph of steps"""
+    dep_graph: dict[str, set[str]] = {}
+    for step_id, step in workflow.steps.items():
+        dep_graph[step_id] = _get_step_dependencies(step)
+    return dep_graph
+
+
 class WorkflowValidator:
     """Validates workflows for correctness and consistency"""
 
     def __init__(self):
-        self.errors:
+        self.errors: list[ValidationError] = []
         self.workflow: Optional[Workflow] = None
 
     def validate(self, workflow: Workflow) -> bool:
@@ -106,7 +133,7 @@ class WorkflowValidator:
             return False
 
         # Verify the output field exists in the step
-        _, field_name =
+        _, field_name = _parse_variable_reference(output_var)
         if not any(field.name == field_name for field in step.output_fields):
             self.errors.append(
                 ValidationError(
@@ -153,7 +180,7 @@ class WorkflowValidator:
             return False
 
         # Verify the output field exists in the referenced step
-        step_id, field_name =
+        step_id, field_name = _parse_variable_reference(output_var)
         if step_id not in workflow.steps:
             self.errors.append(
                 ValidationError(ValidationErrorType.VARIABLE, f"Referenced step '{step_id}' not found")
@@ -172,47 +199,22 @@ class WorkflowValidator:
             )
             return False
 
-        # Check for cycles in step dependencies
-        visited = set()
-        path = set()
-
-        def has_cycle(node: str) -> bool:
-            if node in path:
-                return True
-            if node in visited:
-                return False
-
-            visited.add(node)
-            path.add(node)
-
-            for neighbor in dep_graph.get(node, set()):
-                if has_cycle(neighbor):
-                    return True
-
-            path.remove(node)
-            return False
-
-        # Check each step for cycles
-        for step_id in workflow.steps:
-            if has_cycle(step_id):
-                self.errors.append(
-                    ValidationError(ValidationErrorType.DAG, f"Circular dependency detected involving step: {step_id}")
+        dep_graph = create_step_dep_graph(workflow)
+        if cycle_step_id := detect_cycles(dep_graph):
+            self.errors.append(
+                ValidationError(
+                    ValidationErrorType.DAG, f"Circular dependency detected involving step: {cycle_step_id}"
                 )
+            )
+            return False
 
         # Check for orphaned steps (steps that aren't used by any other step)
         used_steps = set()
         for deps in dep_graph.values():
             used_steps.update(deps)
-        print("Used steps: ", used_steps)
         for step_id in workflow.steps:
             if step_id not in used_steps and not any(
-                output_var and
+                output_var and _parse_variable_reference(output_var)[0] == step_id
                 for output_var in workflow.outputs.values()
             ):
                 self.errors.append(ValidationError(ValidationErrorType.DAG, f"Orphaned step detected: {step_id}"))
@@ -277,7 +279,7 @@ class WorkflowValidator:
             return False
 
         # Validate temperature for LLM call type
-        if step.call_type ==
+        if step.call_type == CallType.LLM:
             if step.temperature is None:
                 self.errors.append(
                     ValidationError(ValidationErrorType.STEP, "LLM step must specify temperature", step.id)
@@ -295,7 +297,7 @@ class WorkflowValidator:
             return False
 
         # Validate system prompt for LLM call type
-        if step.call_type ==
+        if step.call_type == CallType.LLM:
             if not step.system_prompt:
                 self.errors.append(
                     ValidationError(ValidationErrorType.STEP, "LLM step must specify system prompt", step.id)
@@ -477,50 +479,32 @@ class WorkflowValidator:
     def _validate_variable_dependencies(self, workflow: Workflow) -> bool:
         """Validates variable dependencies between steps"""
         # Build variable dependency graph
-        var_graph:
+        var_graph: dict[str, set[str]] = {}
+
+        def create_var_dep_graph(workflow: Workflow) -> dict[str, set[str]]:
+            var_graph: dict[str, set[str]] = {}
+            for step_id, step in workflow.steps.items():
+                for field in step.input_fields:
+                    if field.variable not in var_graph:
+                        var_graph[field.variable] = set()
+                    # Add dependency from input variable to step's outputs
+                    for output in step.output_fields:
+                        var_graph[field.variable].add(f"{step_id}.{output.name}")
+            return var_graph
 
         # Check for cycles in variable dependencies
-                return True
-            if node in visited:
-                return False
-
-            visited.add(node)
-            path.add(node)
-
-            for neighbor in var_graph.get(node, set()):
-                if has_cycle(neighbor):
-                    return True
-
-            path.remove(node)
-            return False
-
-        # Check each variable for cycles
-        for var in var_graph:
-            if has_cycle(var):
-                self.errors.append(
-                    ValidationError(ValidationErrorType.VARIABLE, f"Circular variable dependency detected: {var}")
-                )
-                return False
-
+        var_graph = create_var_dep_graph(workflow)
+        if cycle_var := detect_cycles(var_graph):
+            self.errors.append(
+                ValidationError(ValidationErrorType.VARIABLE, f"Circular variable dependency detected: {cycle_var}")
+            )
+            return False
 
         # Validate external input existence
         external_inputs = set(workflow.inputs)
         for step in workflow.steps.values():
             for field in step.input_fields:
-                step_id, field_name =
+                step_id, field_name = _parse_variable_reference(field.variable)
                 if not step_id and field_name not in external_inputs:
                     self.errors.append(
                         ValidationError(
@@ -533,22 +517,6 @@ class WorkflowValidator:
 
         return True
 
-    def _get_step_dependencies(self, step: ModelStep) -> Set[str]:
-        """Gets set of step IDs that this step depends on"""
-        deps = set()
-        for field in step.input_fields:
-            step_id = self._parse_variable_reference(field.variable)[0]
-            if step_id:
-                deps.add(step_id)
-        return deps
-
-    def _parse_variable_reference(self, var: str) -> Tuple[Optional[str], str]:
-        """Extracts step_id and field_name from variable reference"""
-        parts = var.split(".")
-        if len(parts) == 1:
-            return None, parts[0]
-        return parts[0], parts[1]
-
     def _is_valid_variable_reference(self, var: str) -> bool:
         """Validates if a variable reference is properly formatted"""
         if not self.workflow:
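The validator's reference parsing now lives in a module-level helper; a quick sketch of its behaviour, with illustrative variable strings:

from workflows.validators import _parse_variable_reference

print(_parse_variable_reference("extract.entities"))  # ("extract", "entities"): output of another step
print(_parse_variable_reference("question_text"))     # (None, "question_text"): an external workflow input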
tests/test_executors.py
CHANGED
@@ -8,37 +8,33 @@ from workflows.executors import (
|
|
8 |
create_processed_inputs,
|
9 |
execute_model_step,
|
10 |
execute_workflow,
|
11 |
-
lower,
|
12 |
-
upper,
|
13 |
)
|
14 |
-
from workflows.structs import InputField, ModelStep, OutputField, Workflow
|
15 |
|
16 |
# Tests for utility functions
|
|
|
|
|
17 |
|
18 |
|
19 |
-
|
20 |
-
"""Test the upper function with different input types."""
|
21 |
-
assert upper("hello") == "HELLO"
|
22 |
-
assert upper("Hello World") == "HELLO WORLD"
|
23 |
-
assert upper("") == ""
|
24 |
-
# Non-string inputs should be returned unchanged
|
25 |
-
assert upper(123) == 123
|
26 |
-
assert upper([1, 2, 3]) == [1, 2, 3]
|
27 |
-
assert upper(None) is None
|
28 |
|
29 |
|
30 |
-
def
|
31 |
-
|
32 |
-
assert
|
33 |
-
assert
|
34 |
-
assert
|
35 |
-
|
36 |
-
assert
|
37 |
-
assert
|
38 |
-
assert
|
39 |
|
40 |
|
41 |
-
|
|
|
|
|
|
|
|
|
42 |
|
43 |
|
44 |
def test_create_processed_inputs_basic():
|
@@ -47,8 +43,8 @@ def test_create_processed_inputs_basic():
|
|
47 |
id="test_step",
|
48 |
name="Test Step",
|
49 |
model="gpt-4",
|
50 |
-
provider="
|
51 |
-
call_type=
|
52 |
system_prompt="Test prompt",
|
53 |
input_fields=[InputField(name="text", description="Input text", variable="input_text")],
|
54 |
output_fields=[],
|
@@ -65,8 +61,8 @@ def test_create_processed_inputs_with_transformation():
|
|
65 |
id="test_step",
|
66 |
name="Test Step",
|
67 |
model="gpt-4",
|
68 |
-
provider="
|
69 |
-
call_type=
|
70 |
system_prompt="Test prompt",
|
71 |
input_fields=[
|
72 |
InputField(name="upper_text", description="Uppercase text", variable="input_text", func="upper"),
|
@@ -86,8 +82,8 @@ def test_create_processed_inputs_missing_var():
|
|
86 |
id="test_step",
|
87 |
name="Test Step",
|
88 |
model="gpt-4",
|
89 |
-
provider="
|
90 |
-
call_type=
|
91 |
system_prompt="Test prompt",
|
92 |
input_fields=[InputField(name="text", description="Input text", variable="missing_var")],
|
93 |
output_fields=[],
|
@@ -104,8 +100,8 @@ def test_create_processed_inputs_unknown_func():
|
|
104 |
id="test_step",
|
105 |
name="Test Step",
|
106 |
model="gpt-4",
|
107 |
-
provider="
|
108 |
-
call_type=
|
109 |
system_prompt="Test prompt",
|
110 |
input_fields=[InputField(name="text", description="Input text", variable="input_text", func="unknown_func")],
|
111 |
output_fields=[],
|
@@ -136,7 +132,7 @@ def test_execute_model_step_success(mock_completion):
|
|
136 |
name="Summarize Text",
|
137 |
model="gpt-3.5-turbo",
|
138 |
provider="OpenAI",
|
139 |
-
call_type=
|
140 |
system_prompt="Summarize the text",
|
141 |
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
142 |
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
@@ -146,7 +142,13 @@ def test_execute_model_step_success(mock_completion):
|
|
146 |
result = execute_model_step(step, {"input_text": "Long text to be summarized..."})
|
147 |
|
148 |
# Verify the results
|
149 |
-
assert result
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
# Verify the litellm call was made correctly
|
152 |
mock_completion.assert_called_once()
|
@@ -155,6 +157,77 @@ def test_execute_model_step_success(mock_completion):
|
|
155 |
assert "Summarize the text" in kwargs["system"]
|
156 |
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
@patch("workflows.executors.completion")
|
159 |
def test_execute_model_step_error(mock_completion):
|
160 |
"""Test handling of errors in model step execution."""
|
@@ -166,8 +239,8 @@ def test_execute_model_step_error(mock_completion):
|
|
166 |
id="summarize",
|
167 |
name="Summarize Text",
|
168 |
model="gpt-3.5-turbo",
|
169 |
-
provider="
|
170 |
-
call_type=
|
171 |
system_prompt="Summarize the text",
|
172 |
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
173 |
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
@@ -185,15 +258,16 @@ def test_execute_model_step_error(mock_completion):
|
|
185 |
def test_execute_workflow_simple(mock_execute_step):
|
186 |
"""Test execution of a simple workflow with a single step."""
|
187 |
# Configure mock to return expected outputs
|
188 |
-
|
|
|
189 |
|
190 |
# Create a simple workflow
|
191 |
step = ModelStep(
|
192 |
id="summarize",
|
193 |
name="Summarize Text",
|
194 |
model="gpt-3.5-turbo",
|
195 |
-
provider="
|
196 |
-
call_type=
|
197 |
system_prompt="Summarize the text",
|
198 |
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
199 |
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
@@ -202,15 +276,21 @@ def test_execute_workflow_simple(mock_execute_step):
|
|
202 |
workflow = Workflow(steps={"summarize": step}, inputs=["input_text"], outputs={"summary": "summarize.summary"})
|
203 |
|
204 |
# Execute the workflow
|
205 |
-
|
206 |
-
workflow, {"input_text": "Long text to be summarized..."}
|
207 |
-
)
|
208 |
|
209 |
# Verify the results
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
# Verify execute_model_step was called correctly
|
215 |
mock_execute_step.assert_called_once()
|
216 |
|
@@ -220,12 +300,12 @@ def test_execute_workflow_multi_step(mock_execute_step):
|
|
220 |
"""Test execution of a multi-step workflow with dependencies."""
|
221 |
|
222 |
# Configure mock to return different values based on the step
|
223 |
-
def side_effect(step, available_vars, return_full_content=False):
|
224 |
if step.id == "extract":
|
225 |
-
return {"entities": ["Apple", "product"]}
|
226 |
elif step.id == "analyze":
|
227 |
-
return {"sentiment": "positive"}
|
228 |
-
return {}
|
229 |
|
230 |
mock_execute_step.side_effect = side_effect
|
231 |
|
@@ -234,8 +314,8 @@ def test_execute_workflow_multi_step(mock_execute_step):
|
|
234 |
id="extract",
|
235 |
name="Extract Entities",
|
236 |
model="gpt-3.5-turbo",
|
237 |
-
provider="
|
238 |
-
call_type=
|
239 |
system_prompt="Extract entities",
|
240 |
input_fields=[InputField(name="text", description="Text to analyze", variable="input_text")],
|
241 |
output_fields=[OutputField(name="entities", description="Extracted entities", type="list[str]")],
|
@@ -246,8 +326,8 @@ def test_execute_workflow_multi_step(mock_execute_step):
|
|
246 |
id="analyze",
|
247 |
name="Analyze Sentiment",
|
248 |
model="gpt-4",
|
249 |
-
provider="
|
250 |
-
call_type=
|
251 |
system_prompt="Analyze sentiment",
|
252 |
input_fields=[InputField(name="entities", description="Entities to analyze", variable="extract.entities")],
|
253 |
output_fields=[OutputField(name="sentiment", description="Sentiment analysis", type="str")],
|
@@ -260,19 +340,22 @@ def test_execute_workflow_multi_step(mock_execute_step):
|
|
260 |
)
|
261 |
|
262 |
# Execute the workflow
|
263 |
-
|
264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
)
|
266 |
|
267 |
-
# Verify the results
|
268 |
-
assert final_outputs == {"entities": ["Apple", "product"], "sentiment": "positive"}
|
269 |
-
assert computed_values == {
|
270 |
-
"input_text": "Apple is launching a new product tomorrow.",
|
271 |
-
"extract.entities": ["Apple", "product"],
|
272 |
-
"analyze.sentiment": "positive",
|
273 |
-
}
|
274 |
-
assert step_contents == {}
|
275 |
-
|
276 |
# Verify execute_model_step was called twice (once for each step)
|
277 |
assert mock_execute_step.call_count == 2
|
278 |
|
@@ -283,8 +366,8 @@ def test_execute_workflow_missing_input():
|
|
283 |
id="summarize",
|
284 |
name="Summarize Text",
|
285 |
model="gpt-3.5-turbo",
|
286 |
-
provider="
|
287 |
-
call_type=
|
288 |
system_prompt="Summarize the text",
|
289 |
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
290 |
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
@@ -297,24 +380,32 @@ def test_execute_workflow_missing_input():
|
|
297 |
execute_workflow(workflow, {})
|
298 |
|
299 |
|
300 |
-
|
301 |
-
def test_execute_workflow_cyclic_dependency(mock_dependency_graph):
|
302 |
"""Test that a cyclic dependency in the workflow raises an appropriate error."""
|
303 |
# Make create_dependency_graph raise a CyclicDependencyError
|
304 |
-
mock_dependency_graph.side_effect = CyclicDependencyError()
|
305 |
|
306 |
-
|
307 |
-
id="
|
308 |
-
name="Test Step",
|
309 |
model="gpt-3.5-turbo",
|
310 |
-
provider="
|
311 |
-
call_type=
|
312 |
system_prompt="Test",
|
313 |
-
input_fields=[],
|
314 |
-
output_fields=[],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
315 |
)
|
316 |
|
317 |
-
workflow = Workflow(steps=[
|
318 |
|
319 |
# This should propagate the CyclicDependencyError
|
320 |
with pytest.raises(CyclicDependencyError):
|
@@ -325,15 +416,20 @@ def test_execute_workflow_cyclic_dependency(mock_dependency_graph):
|
|
325 |
def test_execute_workflow_with_full_content(mock_execute_step):
|
326 |
"""Test execution of a workflow with return_full_content=True."""
|
327 |
# Configure mock to return expected outputs and content
|
328 |
-
|
|
|
|
|
|
|
|
|
|
|
329 |
|
330 |
# Create a simple workflow
|
331 |
step = ModelStep(
|
332 |
id="summarize",
|
333 |
name="Summarize Text",
|
334 |
model="gpt-3.5-turbo",
|
335 |
-
provider="
|
336 |
-
call_type=
|
337 |
system_prompt="Summarize the text",
|
338 |
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
339 |
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
@@ -342,14 +438,64 @@ def test_execute_workflow_with_full_content(mock_execute_step):
|
|
342 |
workflow = Workflow(steps=[step], inputs=["input_text"], outputs={"summary": "summarize.summary"})
|
343 |
|
344 |
# Execute the workflow with return_full_content=True
|
345 |
-
|
346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
)
|
348 |
|
349 |
-
# Verify the results
|
350 |
-
assert final_outputs == {"summary": "This is a summary"}
|
351 |
-
assert computed_values == {"input_text": "Long text to be summarized...", "summarize.summary": "This is a summary"}
|
352 |
-
assert step_contents == {"summarize": "Full model response content"}
|
353 |
-
|
354 |
# Verify execute_model_step was called correctly with return_full_content=True
|
355 |
-
mock_execute_step.assert_called_once_with(step,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
create_processed_inputs,
|
9 |
execute_model_step,
|
10 |
execute_workflow,
|
|
|
|
|
11 |
)
|
12 |
+
from workflows.structs import CallType, InputField, ModelStep, OutputField, Workflow
|
13 |
|
14 |
# Tests for utility functions
|
15 |
+
lower = str.lower
|
16 |
+
upper = str.upper
|
17 |
|
18 |
|
19 |
+
# Tests for create_processed_inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
|
22 |
+
def assert_model_step_result(result: dict, expected_result: dict):
|
23 |
+
# Verify the results
|
24 |
+
assert isinstance(result, dict)
|
25 |
+
assert "outputs" in result
|
26 |
+
assert "content" in result
|
27 |
+
assert "logprob" in result
|
28 |
+
assert result["outputs"] == expected_result["outputs"]
|
29 |
+
assert result["content"] == expected_result["content"]
|
30 |
+
assert result["logprob"] == expected_result["logprob"]
|
31 |
|
32 |
|
33 |
+
def assert_workflow_output(output: dict, expected_output: dict):
|
34 |
+
assert isinstance(output, dict)
|
35 |
+
for key in ["final_outputs", "intermediate_outputs", "step_contents", "logprob"]:
|
36 |
+
assert key in output
|
37 |
+
assert output[key] == expected_output[key]
|
38 |
|
39 |
|
40 |
def test_create_processed_inputs_basic():
|
|
|
43 |
id="test_step",
|
44 |
name="Test Step",
|
45 |
model="gpt-4",
|
46 |
+
provider="OpenAI",
|
47 |
+
call_type=CallType.LLM,
|
48 |
system_prompt="Test prompt",
|
49 |
input_fields=[InputField(name="text", description="Input text", variable="input_text")],
|
50 |
output_fields=[],
|
|
|
61 |
id="test_step",
|
62 |
name="Test Step",
|
63 |
model="gpt-4",
|
64 |
+
provider="OpenAI",
|
65 |
+
call_type=CallType.LLM,
|
66 |
system_prompt="Test prompt",
|
67 |
input_fields=[
|
68 |
InputField(name="upper_text", description="Uppercase text", variable="input_text", func="upper"),
|
|
|
82 |
id="test_step",
|
83 |
name="Test Step",
|
84 |
model="gpt-4",
|
85 |
+
provider="OpenAI",
|
86 |
+
call_type=CallType.LLM,
|
87 |
system_prompt="Test prompt",
|
88 |
input_fields=[InputField(name="text", description="Input text", variable="missing_var")],
|
89 |
output_fields=[],
|
|
|
100 |
id="test_step",
|
101 |
name="Test Step",
|
102 |
model="gpt-4",
|
103 |
+
provider="OpenAI",
|
104 |
+
call_type=CallType.LLM,
|
105 |
system_prompt="Test prompt",
|
106 |
input_fields=[InputField(name="text", description="Input text", variable="input_text", func="unknown_func")],
|
107 |
output_fields=[],
|
|
|
132 |
name="Summarize Text",
|
133 |
model="gpt-3.5-turbo",
|
134 |
provider="OpenAI",
|
135 |
+
call_type=CallType.LLM,
|
136 |
system_prompt="Summarize the text",
|
137 |
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
138 |
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
|
|
142 |
result = execute_model_step(step, {"input_text": "Long text to be summarized..."})
|
143 |
|
144 |
# Verify the results
|
145 |
+
assert isinstance(result, dict)
|
146 |
+
assert "outputs" in result
|
147 |
+
assert "content" in result
|
148 |
+
assert "logprob" in result
|
149 |
+
assert result["outputs"] == {"summary": "This is a summary"}
|
150 |
+
assert result["content"] is None
|
151 |
+
assert result["logprob"] is None
|
152 |
|
153 |
# Verify the litellm call was made correctly
|
154 |
mock_completion.assert_called_once()
|
|
|
157 |
assert "Summarize the text" in kwargs["system"]
|
158 |
|
159 |
|
160 |
+
@patch("workflows.executors.completion")
|
161 |
+
def test_execute_model_step_with_full_content(mock_completion):
|
162 |
+
"""Test execution of a model step with full content returned."""
|
163 |
+
# Mock the litellm response
|
164 |
+
mock_response = {
|
165 |
+
"content": "Full model response content",
|
166 |
+
"output": {"summary": "This is a summary"},
|
167 |
+
}
|
168 |
+
mock_completion.return_value = mock_response
|
169 |
+
|
170 |
+
# Create a test step
|
171 |
+
step = ModelStep(
|
172 |
+
id="summarize",
|
173 |
+
name="Summarize Text",
|
174 |
+
model="gpt-3.5-turbo",
|
175 |
+
provider="OpenAI",
|
176 |
+
call_type=CallType.LLM,
|
177 |
+
system_prompt="Summarize the text",
|
178 |
+
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
179 |
+
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
180 |
+
)
|
181 |
+
|
182 |
+
# Execute the step with return_full_content=True
|
183 |
+
result = execute_model_step(step, {"input_text": "Long text to be summarized..."}, return_full_content=True)
|
184 |
+
|
185 |
+
# Verify the results
|
186 |
+
assert isinstance(result, dict)
|
187 |
+
assert "outputs" in result
|
188 |
+
assert "content" in result
|
189 |
+
assert "logprob" in result
|
190 |
+
assert result["outputs"] == {"summary": "This is a summary"}
|
191 |
+
assert result["content"] == "Full model response content"
|
192 |
+
assert result["logprob"] is None
|
193 |
+
|
194 |
+
|
195 |
+
@patch("workflows.executors.completion")
|
196 |
+
def test_execute_model_step_with_logprobs(mock_completion):
|
197 |
+
"""Test execution of a model step with log probabilities."""
|
198 |
+
# Mock the litellm response with log probability
|
199 |
+
mock_response = {
|
200 |
+
"content": json.dumps({"summary": "This is a summary"}),
|
201 |
+
"output": {"summary": "This is a summary"},
|
202 |
+
"log_prob": -2.5,
|
203 |
+
}
|
204 |
+
mock_completion.return_value = mock_response
|
205 |
+
|
206 |
+
# Create a test step
|
207 |
+
step = ModelStep(
|
208 |
+
id="summarize",
|
209 |
+
name="Summarize Text",
|
210 |
+
model="gpt-3.5-turbo",
|
211 |
+
provider="OpenAI",
|
212 |
+
call_type=CallType.LLM,
|
213 |
+
system_prompt="Summarize the text",
|
214 |
+
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
215 |
+
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
216 |
+
)
|
217 |
+
|
218 |
+
# Execute the step with logprobs=True
|
219 |
+
result = execute_model_step(step, {"input_text": "Long text to be summarized..."}, logprobs=True)
|
220 |
+
|
221 |
+
# Verify the results
|
222 |
+
assert isinstance(result, dict)
|
223 |
+
assert "outputs" in result
|
224 |
+
assert "content" in result
|
225 |
+
assert "logprob" in result
|
226 |
+
assert result["outputs"] == {"summary": "This is a summary"}
|
227 |
+
assert result["content"] is None
|
228 |
+
assert result["logprob"] == -2.5
|
229 |
+
|
230 |
+
|
231 |
@patch("workflows.executors.completion")
|
232 |
def test_execute_model_step_error(mock_completion):
|
233 |
"""Test handling of errors in model step execution."""
|
|
|
239 |
id="summarize",
|
240 |
name="Summarize Text",
|
241 |
model="gpt-3.5-turbo",
|
242 |
+
provider="OpenAI",
|
243 |
+
call_type=CallType.LLM,
|
244 |
system_prompt="Summarize the text",
|
245 |
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
246 |
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
|
|
258 |
def test_execute_workflow_simple(mock_execute_step):
|
259 |
"""Test execution of a simple workflow with a single step."""
|
260 |
# Configure mock to return expected outputs
|
261 |
+
mock_result = {"outputs": {"summary": "This is a summary"}, "content": None, "logprob": None}
|
262 |
+
mock_execute_step.return_value = mock_result
|
263 |
|
264 |
# Create a simple workflow
|
265 |
step = ModelStep(
|
266 |
id="summarize",
|
267 |
name="Summarize Text",
|
268 |
model="gpt-3.5-turbo",
|
269 |
+
provider="OpenAI",
|
270 |
+
call_type=CallType.LLM,
|
271 |
system_prompt="Summarize the text",
|
272 |
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
273 |
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
|
|
276 |
workflow = Workflow(steps={"summarize": step}, inputs=["input_text"], outputs={"summary": "summarize.summary"})
|
277 |
|
278 |
# Execute the workflow
|
279 |
+
result = execute_workflow(workflow, {"input_text": "Long text to be summarized..."})
|
|
|
|
|
280 |
|
281 |
# Verify the results
|
282 |
+
assert_workflow_output(
|
283 |
+
result,
|
284 |
+
{
|
285 |
+
"final_outputs": {"summary": "This is a summary"},
|
286 |
+
"intermediate_outputs": {
|
287 |
+
"input_text": "Long text to be summarized...",
|
288 |
+
"summarize.summary": "This is a summary",
|
289 |
+
},
|
290 |
+
"step_contents": {},
|
291 |
+
"logprob": None,
|
292 |
+
},
|
293 |
+
)
|
294 |
# Verify execute_model_step was called correctly
|
295 |
mock_execute_step.assert_called_once()
|
296 |
|
|
|
300 |
"""Test execution of a multi-step workflow with dependencies."""
|
301 |
|
302 |
# Configure mock to return different values based on the step
|
303 |
+
def side_effect(step, available_vars, return_full_content=False, logprobs=False):
|
304 |
if step.id == "extract":
|
305 |
+
return {"outputs": {"entities": ["Apple", "product"]}, "content": None, "logprob": None}
|
306 |
elif step.id == "analyze":
|
307 |
+
return {"outputs": {"sentiment": "positive"}, "content": None, "logprob": None}
|
308 |
+
return {"outputs": {}, "content": None, "logprob": None}
|
309 |
|
310 |
mock_execute_step.side_effect = side_effect
|
311 |
|
|
|
314 |
id="extract",
|
315 |
name="Extract Entities",
|
316 |
model="gpt-3.5-turbo",
|
317 |
+
provider="OpenAI",
|
318 |
+
call_type=CallType.LLM,
|
319 |
system_prompt="Extract entities",
|
320 |
input_fields=[InputField(name="text", description="Text to analyze", variable="input_text")],
|
321 |
output_fields=[OutputField(name="entities", description="Extracted entities", type="list[str]")],
|
|
|
326 |
id="analyze",
|
327 |
name="Analyze Sentiment",
|
328 |
model="gpt-4",
|
329 |
+
provider="OpenAI",
|
330 |
+
call_type=CallType.LLM,
|
331 |
system_prompt="Analyze sentiment",
|
332 |
input_fields=[InputField(name="entities", description="Entities to analyze", variable="extract.entities")],
|
333 |
output_fields=[OutputField(name="sentiment", description="Sentiment analysis", type="str")],
|
|
|
340 |
)
|
341 |
|
342 |
# Execute the workflow
|
343 |
+
result = execute_workflow(workflow, {"input_text": "Apple is launching a new product tomorrow."})
|
344 |
+
|
345 |
+
assert_workflow_output(
|
346 |
+
result,
|
347 |
+
{
|
348 |
+
"final_outputs": {"entities": ["Apple", "product"], "sentiment": "positive"},
|
349 |
+
"intermediate_outputs": {
|
350 |
+
"input_text": "Apple is launching a new product tomorrow.",
|
351 |
+
"extract.entities": ["Apple", "product"],
|
352 |
+
"analyze.sentiment": "positive",
|
353 |
+
},
|
354 |
+
"step_contents": {},
|
355 |
+
"logprob": None,
|
356 |
+
},
|
357 |
)
|
358 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
# Verify execute_model_step was called twice (once for each step)
|
360 |
assert mock_execute_step.call_count == 2
|
361 |
|
|
|
366 |
id="summarize",
|
367 |
name="Summarize Text",
|
368 |
model="gpt-3.5-turbo",
|
369 |
+
provider="OpenAI",
|
370 |
+
call_type=CallType.LLM,
|
371 |
system_prompt="Summarize the text",
|
372 |
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
373 |
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
|
|
380 |
execute_workflow(workflow, {})
|
381 |
|
382 |
|
383 |
+
def test_execute_workflow_cyclic_dependency():
|
|
|
384 |
"""Test that a cyclic dependency in the workflow raises an appropriate error."""
|
385 |
# Make create_dependency_graph raise a CyclicDependencyError
|
|
|
386 |
|
387 |
+
step1 = ModelStep(
|
388 |
+
id="t1",
|
389 |
+
name="Test Step 1",
|
390 |
model="gpt-3.5-turbo",
|
391 |
+
provider="OpenAI",
|
392 |
+
call_type=CallType.LLM,
|
393 |
system_prompt="Test",
|
394 |
+
input_fields=[InputField(name="v1", description="", variable="t2.var")],
|
395 |
+
output_fields=[OutputField(name="out", description="")],
|
396 |
+
)
|
397 |
+
step2 = ModelStep(
|
398 |
+
id="t2",
|
399 |
+
name="Test Step 2",
|
400 |
+
model="gpt-3.5-turbo",
|
401 |
+
provider="OpenAI",
|
402 |
+
call_type=CallType.LLM,
|
403 |
+
system_prompt="Test",
|
404 |
+
input_fields=[InputField(name="v2", description="", variable="t1.out")],
|
405 |
+
output_fields=[OutputField(name="var", description="")],
|
406 |
)
|
407 |
|
408 |
+
workflow = Workflow(steps=[step1, step2], inputs=[], outputs={})
|
409 |
|
410 |
# This should propagate the CyclicDependencyError
|
411 |
with pytest.raises(CyclicDependencyError):
|
|
|
416 |
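CyclicDependencyError comes from the dependency-graph utilities that execute_workflow relies on (the create_dependency_graph mentioned in the comment above). As a hedged illustration of the kind of check that raises it, here is a sketch based on Kahn's algorithm over step references such as "t1.out" and "t2.var"; it is not the repository's exact implementation:

    from collections import deque

    # The real error type lives in the workflows utilities; it is redefined here
    # only to keep the sketch self-contained.
    class CyclicDependencyError(ValueError):
        pass

    def assert_acyclic(dependencies: dict[str, set[str]]) -> None:
        # dependencies maps a step id to the ids of the steps it reads from.
        indegree = {step: len(deps) for step, deps in dependencies.items()}
        ready = deque(step for step, degree in indegree.items() if degree == 0)
        resolved = 0
        while ready:
            step = ready.popleft()
            resolved += 1
            for other, deps in dependencies.items():
                if step in deps:
                    indegree[other] -= 1
                    if indegree[other] == 0:
                        ready.append(other)
        if resolved != len(dependencies):
            raise CyclicDependencyError("workflow steps form a cycle")

    # assert_acyclic({"t1": {"t2"}, "t2": {"t1"}}) raises, mirroring the test above.
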
def test_execute_workflow_with_full_content(mock_execute_step):
    """Test execution of a workflow with return_full_content=True."""
    # Configure mock to return expected outputs and content
+    mock_result = {
+        "outputs": {"summary": "This is a summary"},
+        "content": "Full model response content",
+        "logprob": None,
+    }
+    mock_execute_step.return_value = mock_result

    # Create a simple workflow
    step = ModelStep(
        id="summarize",
        name="Summarize Text",
        model="gpt-3.5-turbo",
+        provider="OpenAI",
+        call_type=CallType.LLM,
        system_prompt="Summarize the text",
        input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
        output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],

    workflow = Workflow(steps=[step], inputs=["input_text"], outputs={"summary": "summarize.summary"})

    # Execute the workflow with return_full_content=True
+    inputs = {"input_text": "Long text to be summarized..."}
+    result = execute_workflow(workflow, inputs, return_full_content=True)
+
+    assert_workflow_output(
+        result,
+        {
+            "final_outputs": {"summary": "This is a summary"},
+            "intermediate_outputs": {
+                "input_text": "Long text to be summarized...",
+                "summarize.summary": "This is a summary",
+            },
+            "step_contents": {"summarize": "Full model response content"},
+            "logprob": None,
+        },
    )

    # Verify execute_model_step was called correctly with return_full_content=True
+    mock_execute_step.assert_called_once_with(step, inputs, return_full_content=True, logprobs=False)
+
+
+@patch("workflows.executors.execute_model_step")
+def test_execute_workflow_with_logprob(mock_execute_step):
+    """Test execution of a workflow with logprob_step specified."""
+    # Configure mock to return expected outputs with logprob
+    mock_result = {"outputs": {"summary": "This is a summary"}, "content": None, "logprob": -2.5}
+    mock_execute_step.return_value = mock_result
+
+    # Create a simple workflow
+    step = ModelStep(
+        id="summarize",
+        name="Summarize Text",
+        model="gpt-3.5-turbo",
+        provider="OpenAI",
+        call_type=CallType.LLM,
+        system_prompt="Summarize the text",
+        input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
+        output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
+    )
+
+    workflow = Workflow(steps={"summarize": step}, inputs=["input_text"], outputs={"summary": "summarize.summary"})
+
+    # Execute the workflow with logprob_step specified
+    result = execute_workflow(workflow, {"input_text": "Long text to be summarized..."}, logprob_step="summarize")
+
+    # Verify the results
+    assert_workflow_output(
+        result,
+        {
+            "final_outputs": {"summary": "This is a summary"},
+            "logprob": -2.5,
+            "intermediate_outputs": {
+                "input_text": "Long text to be summarized...",
+                "summarize.summary": "This is a summary",
+            },
+            "step_contents": {},
+        },
+    )
+    # Verify execute_model_step was called with logprobs=True
+    mock_execute_step.assert_called_once()
+    args, kwargs = mock_execute_step.call_args
+    assert kwargs["logprobs"] is True
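Taken together, the last two tests pin down the executor's optional-argument contract: return_full_content=True fills step_contents keyed by step id, and logprob_step requests logprobs from exactly one step and surfaces that step's value on the result. A rough sketch of the dispatch loop this implies, assuming the execute_model_step interface mocked above and a dependency-ordered steps mapping (illustrative only, not the actual workflows/executors.py code):

    from workflows.executors import execute_model_step

    def execute_workflow_sketch(workflow, inputs, return_full_content=False, logprob_step=None):
        # Hypothetical outline of the loop the tests above exercise.
        intermediate = dict(inputs)                   # e.g. "input_text" -> raw text
        step_contents, logprob = {}, None
        for step_id, step in workflow.steps.items():  # assumes steps are already dependency-ordered
            result = execute_model_step(
                step,
                intermediate,
                return_full_content=return_full_content,
                logprobs=(step_id == logprob_step),
            )
            for name, value in result["outputs"].items():
                intermediate[f"{step_id}.{name}"] = value  # e.g. "summarize.summary"
            if return_full_content:
                step_contents[step_id] = result["content"]
            if step_id == logprob_step:
                logprob = result["logprob"]
        final_outputs = {name: intermediate[ref] for name, ref in workflow.outputs.items()}
        return {
            "final_outputs": final_outputs,
            "intermediate_outputs": intermediate,
            "step_contents": step_contents,
            "logprob": logprob,
        }
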
tests/test_validators.py
CHANGED
@@ -1,21 +1,21 @@
-from typing import
+from typing import Any

import pytest
from pydantic import ValidationError as PydanticValidationError

-from workflows.structs import InputField, ModelStep, OutputField, Workflow
-from workflows.validators import ValidationError, ValidationErrorType, WorkflowValidator
+from workflows.structs import CallType, InputField, ModelStep, OutputField, Workflow
+from workflows.validators import ValidationError, ValidationErrorType, WorkflowValidator, _parse_variable_reference


# Test Data
-def
+def create_empty_step(step_id: str = "step1") -> ModelStep:
    """Creates a basic valid step for testing"""
    return ModelStep(
        id=step_id,
        name="Test Step",
        model="gpt-4",
        provider="openai",
-        call_type=
+        call_type=CallType.LLM,
        temperature=0.7,
        system_prompt="Test prompt",
        input_fields=[],
@@ -23,16 +23,32 @@ def create_basic_step(step_id: str = "step1") -> ModelStep:
    )


-
+# Test Data
+def create_basic_step(step_id: str = "step1") -> ModelStep:
+    """Creates a basic valid step for testing"""
+    return ModelStep(
+        id=step_id,
+        name="Test Step",
+        model="gpt-4",
+        provider="openai",
+        call_type=CallType.LLM,
+        temperature=0.7,
+        system_prompt="Test prompt",
+        input_fields=[InputField(name="input", description="test", variable="external_input")],
+        output_fields=[OutputField(name="output", description="test", type="str")],
+    )
+
+
+def create_basic_workflow(steps: list[ModelStep] | None = None) -> Workflow:
    """Creates a basic valid workflow for testing"""
    if steps is None:
-        steps = [
+        steps = [create_empty_step()]
    return Workflow(inputs=[], outputs={}, steps={step.id: step for step in steps})


# Additional Test Data
def create_step_with_fields(
-    step_id: str, input_fields:
+    step_id: str, input_fields: list[InputField], output_fields: list[OutputField]
) -> ModelStep:
    """Creates a step with specific input and output fields"""
    return ModelStep(
@@ -40,7 +56,7 @@ def create_step_with_fields(
        name="Test Step",
        model="gpt-4",
        provider="openai",
-        call_type=
+        call_type=CallType.LLM,
        temperature=0.7,
        system_prompt="Test prompt",
        input_fields=input_fields,
@@ -117,15 +133,15 @@ class TestStepValidation:
            name="",  # Missing name
            model="",  # Missing model
            provider="",  # Missing provider
-            call_type=
+            call_type=CallType.LLM,  # Missing call_type
            temperature=0.7,
            system_prompt="Test prompt",
            input_fields=[],
            output_fields=[],
        )
        workflow = create_basic_workflow([step])
-        workflow.inputs = ["
-        workflow.outputs = {"output": "step1.
+        workflow.inputs = ["external_input"]
+        workflow.outputs = {"output": "step1.output"}
        assert not validator.validate(workflow)
        assert len(validator.errors) == 1
        assert validator.errors[0].error_type == ValidationErrorType.STEP
@@ -135,32 +151,34 @@ class TestStepValidation:
        validator = WorkflowValidator()
        step = create_basic_step("123invalid")  # Invalid ID format
        workflow = create_basic_workflow([step])
-        workflow.inputs = ["
-        workflow.outputs = {"output": "step1.
+        workflow.inputs = ["external_input"]
+        workflow.outputs = {"output": "step1.output"}
        assert not validator.validate(workflow)
        assert len(validator.errors) == 1
        assert validator.errors[0].error_type == ValidationErrorType.NAMING

-    def
+    def test_llm_temperature_validation_invalid(self):
        """Test validation of LLM step temperature"""
        validator = WorkflowValidator()

        # Test invalid temperature
        step = create_basic_step()
-        step.temperature =
+        step.temperature = -0.5  # Invalid temperature
        workflow = create_basic_workflow([step])
-        workflow.inputs = ["
-        workflow.outputs = {"output": "step1.
+        workflow.inputs = ["external_input"]
+        workflow.outputs = {"output": "step1.output"}
        assert not validator.validate(workflow)
        assert len(validator.errors) == 1
        assert validator.errors[0].error_type == ValidationErrorType.RANGE

+    def test_llm_temperature_validation_missing(self):
        # Test missing temperature
+        validator = WorkflowValidator()
        step = create_basic_step()
        step.temperature = None  # Missing temperature
        workflow = create_basic_workflow([step])
-        workflow.inputs = ["
-        workflow.outputs = {"output": "step1.
+        workflow.inputs = ["external_input"]
+        workflow.outputs = {"output": "step1.output"}
        assert not validator.validate(workflow)
        assert len(validator.errors) == 1
        assert validator.errors[0].error_type == ValidationErrorType.STEP
@@ -173,8 +191,8 @@ class TestStepValidation:
        step = create_basic_step()
        step.system_prompt = ""  # Missing system prompt
        workflow = create_basic_workflow([step])
-        workflow.inputs = ["
-        workflow.outputs = {"output": "step1.
+        workflow.inputs = ["external_input"]
+        workflow.outputs = {"output": "step1.output"}
        assert not validator.validate(workflow)
        assert len(validator.errors) == 1
        assert validator.errors[0].error_type == ValidationErrorType.STEP
@@ -183,8 +201,8 @@
        step = create_basic_step()
        step.system_prompt = "x" * 4001  # Too long
        workflow = create_basic_workflow([step])
-        workflow.inputs = ["
-        workflow.outputs = {"output": "step1.
+        workflow.inputs = ["external_input"]
+        workflow.outputs = {"output": "step1.output"}
        assert not validator.validate(workflow)
        assert len(validator.errors) == 1
        assert validator.errors[0].error_type == ValidationErrorType.LENGTH
@@ -477,39 +495,6 @@ class TestTypeCompatibility:
        workflow.outputs = {"output": "step2.output"}
        assert validator.validate(workflow)

-    # def test_list_type_compatibility(self):
-    #     """Test validation of list type compatibility"""
-    #     validator = WorkflowValidator()
-
-    #     # Test compatible list types
-    #     step1 = create_step_with_fields(
-    #         "step1", [], [OutputField(name="output", description="test", type="list[str]")]
-    #     )
-    #     step2 = create_step_with_fields(
-    #         "step2", [InputField(name="input", description="test", variable="step1.output")], []
-    #     )
-
-    #     workflow = create_basic_workflow([step1, step2])
-    #     workflow.inputs = ["input"]
-    #     workflow.outputs = {"output": "step2.output"}
-    #     assert validator.validate(workflow)
-    #     assert len(validator.errors) == 0
-
-    #     # Test incompatible list types
-    #     step1 = create_step_with_fields(
-    #         "step1", [], [OutputField(name="output", description="test", type="list[int]")]
-    #     )
-    #     step2 = create_step_with_fields(
-    #         "step2", [InputField(name="input", description="test", variable="step1.output")], []
-    #     )
-
-    #     workflow = create_basic_workflow([step1, step2])
-    #     workflow.inputs = ["input"]
-    #     workflow.outputs = {"output": "step2.output"}
-    #     assert not validator.validate(workflow)
-    #     assert len(validator.errors) == 1
-    #     assert validator.errors[0].error_type == ValidationErrorType.TYPE
-

# Complex Workflow Tests
class TestComplexWorkflows:
@@ -569,6 +554,92 @@ class TestComplexWorkflows:
        assert len(validator.errors) == 0


+# Log Probability Validation Tests
+class TestLogProbabilityValidation:
+    def test_logprob_step_validation(self):
+        """Test validation of log probability step references"""
+        validator = WorkflowValidator()
+
+        # Create a workflow with multiple steps
+        step1 = create_step_with_fields(
+            "step1",
+            [InputField(name="input", description="test", variable="external_input")],
+            [OutputField(name="output", description="test", type="str")],
+        )
+        step2 = create_step_with_fields(
+            "step2",
+            [InputField(name="input", description="test", variable="step1.output")],
+            [OutputField(name="output", description="test", type="str")],
+        )
+
+        workflow = create_basic_workflow([step1, step2])
+        workflow.inputs = ["external_input"]
+        workflow.outputs = {"output": "step2.output"}
+
+        # Validate the workflow first
+        assert validator.validate(workflow)
+        validator.errors = []  # Clear any previous errors
+
+        # Test that a valid step ID is accepted
+        valid_logprob_step = "step1"
+        assert valid_logprob_step in workflow.steps
+        # A validator for logprob_step would check if the step exists in workflow.steps
+
+        # Test that an invalid step ID is caught
+        invalid_logprob_step = "nonexistent_step"
+        assert invalid_logprob_step not in workflow.steps
+        # A validator for logprob_step would report an error for a non-existent step
+
+
+# Output Structure Tests
+class TestOutputStructure:
+    def test_workflow_output_structure(self):
+        """Test the expected structure of workflow outputs"""
+        # Sample output dictionary matching WorkflowOutput structure
+        output: dict[str, dict | None] = {
+            "final_outputs": {},
+            "intermediate_outputs": {},
+            "step_contents": {},
+            "logprob": None,
+        }
+
+        # Verify that all expected keys are present
+        assert "final_outputs" in output
+        assert "intermediate_outputs" in output
+        assert "step_contents" in output
+        assert "logprob" in output
+
+        # Test with populated values
+        output = {
+            "final_outputs": {"output": "result"},
+            "intermediate_outputs": {"step1.output": "result", "input": "value"},
+            "step_contents": {"step1": "Full content"},
+            "logprob": -2.5,
+        }
+
+        assert output["final_outputs"] == {"output": "result"}
+        assert output["intermediate_outputs"]["step1.output"] == "result"
+        assert output["step_contents"]["step1"] == "Full content"
+        assert output["logprob"] == -2.5
+
+    def test_model_step_result_structure(self):
+        """Test the expected structure of model step results"""
+        # Sample result dictionary matching ModelStepResult structure
+        result: dict[str, Any] = {"outputs": {}, "content": None, "logprob": None}
+
+        # Verify that all expected keys are present
+        assert "outputs" in result
+        assert "content" in result
+        assert "logprob" in result
+
+        # Test with populated values
+        result = {"outputs": {"field": "value"}, "content": "Full response", "logprob": -1.5}
+
+        assert result["outputs"] == {"field": "value"}
+        assert result["content"] == "Full response"
+        assert result["logprob"] == -1.5
+
+
# External Input Tests
class TestExternalInputs:
    def test_external_input_existence(self):
@@ -645,3 +716,48 @@ class TestEdgeCases:
        assert not validator.validate(workflow)
        assert len(validator.errors) == 1
        assert validator.errors[0].error_type == ValidationErrorType.STEP
+
+
+# Extended validator tests for actual implementation
+class TestExtendedValidation:
+    def test_parse_variable_reference(self):
+        """Test the _parse_variable_reference method"""
+        validator = WorkflowValidator()
+
+        # Test external input reference
+        step_id, field_name = _parse_variable_reference("input_var")
+        assert step_id is None
+        assert field_name == "input_var"
+
+        # Test step output reference
+        step_id, field_name = _parse_variable_reference("step1.output")
+        assert step_id == "step1"
+        assert field_name == "output"
+
+    def test_is_valid_identifier(self):
+        """Test the _is_valid_identifier method"""
+        validator = WorkflowValidator()
+
+        # Valid identifiers
+        assert validator._is_valid_identifier("valid_name")
+        assert validator._is_valid_identifier("ValidName")
+        assert validator._is_valid_identifier("name123")
+
+        # Invalid identifiers
+        assert not validator._is_valid_identifier("")  # Empty
+        assert not validator._is_valid_identifier(" ")  # Whitespace
+        assert not validator._is_valid_identifier("123name")  # Starts with number
+        assert not validator._is_valid_identifier("name-with-hyphens")  # Has hyphens
+        assert not validator._is_valid_identifier("name.with.dots")  # Has dots
+
+    def test_is_valid_external_input(self):
+        """Test the _is_valid_external_input method"""
+        validator = WorkflowValidator()
+
+        # Valid external inputs
+        assert validator._is_valid_external_input("input_var")
+
+        # Invalid external inputs
+        assert not validator._is_valid_external_input("")  # Empty
+        assert not validator._is_valid_external_input("input.var")  # Contains dot
+        assert not validator._is_valid_external_input("123input")  # Starts with number