Spaces:
Running
Running
Maharshi Gor
commited on
Commit
·
193db9d
1
Parent(s):
a562808
First Working commit
Browse files- .gitignore +23 -0
- Makefile +13 -0
- app.py +125 -0
- pyproject.toml +13 -0
- requirements.txt +27 -0
- src/components/__init__.py +1 -0
- src/components/model_pipeline/__init__.py +0 -0
- src/components/model_pipeline/model_pipeline.py +291 -0
- src/components/model_pipeline/state_manager.py +180 -0
- src/components/model_step/__init__.py +0 -0
- src/components/model_step/model_step.py +477 -0
- src/components/model_step/state_manager.py +152 -0
- src/components/model_step/ui_components.py +91 -0
- src/components/quizbowl/__init__.py +0 -0
- src/components/quizbowl/bonus.py +399 -0
- src/components/quizbowl/plotting.py +194 -0
- src/components/quizbowl/tossup.py +426 -0
- src/components/quizbowl/utils.py +86 -0
- src/components/utils.py +29 -0
- src/display/__init__.py +0 -0
- src/display/css_html_js.py +122 -0
- src/display/custom_css.py +413 -0
- src/display/formatting.py +27 -0
- src/display/utils.py +110 -0
- src/envs.py +86 -0
- src/submission/structs.py +58 -0
- src/submission/submit.py +170 -0
- src/utils.py +38 -0
- src/workflows/README.md +92 -0
- src/workflows/errors.py +63 -0
- src/workflows/executors.py +440 -0
- src/workflows/factory.py +150 -0
- src/workflows/qb/__init__.py +0 -0
- src/workflows/qb/simple_agent.py +194 -0
- src/workflows/quizbowl_agent.py +269 -0
- src/workflows/structs.py +229 -0
- src/workflows/utils.py +161 -0
- src/workflows/validators.py +586 -0
- tests/conftest.py +5 -0
- tests/test_executors.py +295 -0
- tests/test_utils.py +159 -0
- tests/test_validators.py +647 -0
.gitignore
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Purpose: Ignore certain files and directories in git
|
2 |
+
*.pyc
|
3 |
+
*.pyi
|
4 |
+
*.pyo
|
5 |
+
*.pyd
|
6 |
+
*.pyw
|
7 |
+
*.pyz
|
8 |
+
*.pywz
|
9 |
+
*.pywz
|
10 |
+
|
11 |
+
auto_evals/
|
12 |
+
venv/
|
13 |
+
__pycache__/
|
14 |
+
.env
|
15 |
+
.ipynb_checkpoints
|
16 |
+
*ipynb
|
17 |
+
.vscode/
|
18 |
+
|
19 |
+
eval-queue/
|
20 |
+
eval-results/
|
21 |
+
eval-queue-bk/
|
22 |
+
eval-results-bk/
|
23 |
+
logs/
|
Makefile
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.PHONY: style format
|
2 |
+
|
3 |
+
|
4 |
+
style:
|
5 |
+
python -m black --line-length 119 .
|
6 |
+
python -m isort .
|
7 |
+
ruff check --fix .
|
8 |
+
|
9 |
+
|
10 |
+
quality:
|
11 |
+
python -m black --check --line-length 119 .
|
12 |
+
python -m isort --check-only .
|
13 |
+
ruff check .
|
app.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datasets
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
from components.quizbowl.bonus import BonusInterface
|
5 |
+
from components.quizbowl.tossup import TossupInterface
|
6 |
+
from display.custom_css import css_pipeline, css_tossup
|
7 |
+
|
8 |
+
# Constants
|
9 |
+
from envs import AVAILABLE_MODELS, DEFAULT_SELECTIONS, PLAYGROUND_DATASET_NAMES, THEME
|
10 |
+
from workflows import factory
|
11 |
+
|
12 |
+
js_preamble = """
|
13 |
+
<link href="https://fonts.cdnfonts.com/css/roboto-mono" rel="stylesheet">
|
14 |
+
|
15 |
+
<script>
|
16 |
+
const gradioApp = document.getElementsByTagName('gradio-app')[0];
|
17 |
+
console.log("Gradio app:", gradioApp);
|
18 |
+
console.log(gradioApp.querySelectorAll('.token'));
|
19 |
+
console.log(document.querySelectorAll('.token'));
|
20 |
+
|
21 |
+
// Function to trigger Python callback
|
22 |
+
const setHiddenIndex = (index) => {
|
23 |
+
console.log("Setting hidden index to:", index);
|
24 |
+
const hiddenIndex = gradioApp.querySelector("#hidden-index textarea");
|
25 |
+
if (hiddenIndex) {
|
26 |
+
hiddenIndex.value = index;
|
27 |
+
let event = new Event("input", { bubbles: true});
|
28 |
+
Object.defineProperty(event, "target", { value: hiddenIndex});
|
29 |
+
hiddenIndex.dispatchEvent(event);
|
30 |
+
}
|
31 |
+
};
|
32 |
+
|
33 |
+
// Add event listeners to all tokens
|
34 |
+
function setupTokenListeners() {
|
35 |
+
const tokens = gradioApp.querySelectorAll('.token');
|
36 |
+
console.log("Tokens:", tokens);
|
37 |
+
tokens.forEach(token => {
|
38 |
+
token.addEventListener('mouseover', function() {
|
39 |
+
const index = parseInt(this.getAttribute('data-index'));
|
40 |
+
console.log("Mouseover token index:", index);
|
41 |
+
|
42 |
+
// Reset all tokens
|
43 |
+
gradioApp.querySelectorAll('.token').forEach(el => {
|
44 |
+
el.classList.remove('highlighted');
|
45 |
+
});
|
46 |
+
|
47 |
+
// Highlight this token
|
48 |
+
this.classList.add('highlighted');
|
49 |
+
|
50 |
+
// Update the hidden index to trigger the Python callback
|
51 |
+
setHiddenIndex(index);
|
52 |
+
});
|
53 |
+
});
|
54 |
+
}
|
55 |
+
console.log("Preamble complete");
|
56 |
+
|
57 |
+
document.addEventListener("DOMContentLoaded", function() {
|
58 |
+
// Setup initial listeners
|
59 |
+
console.log("DOM fully loaded and parsed");
|
60 |
+
setupTokenListeners();
|
61 |
+
|
62 |
+
// Setup a mutation observer to handle dynamically added tokens
|
63 |
+
const observer = new MutationObserver(function(mutations) {
|
64 |
+
mutations.forEach(function(mutation) {
|
65 |
+
if (mutation.addedNodes.length) {
|
66 |
+
setupTokenListeners();
|
67 |
+
}
|
68 |
+
});
|
69 |
+
});
|
70 |
+
|
71 |
+
// Start observing the token container for changes
|
72 |
+
const tokenContainer = gradioApp.querySelector('.token-container');
|
73 |
+
console.log("Token container:", tokenContainer);
|
74 |
+
if (tokenContainer) {
|
75 |
+
observer.observe(tokenContainer.parentNode, { childList: true, subtree: true });
|
76 |
+
}
|
77 |
+
console.log("Listener setup complete");
|
78 |
+
});
|
79 |
+
</script>
|
80 |
+
"""
|
81 |
+
|
82 |
+
|
83 |
+
def load_dataset(mode: str):
|
84 |
+
if mode == "tossup":
|
85 |
+
ds = datasets.load_dataset(PLAYGROUND_DATASET_NAMES["tossup"], split="eval")
|
86 |
+
ds = ds.filter(lambda x: x["qid"].split("-")[2] == "1" and int(x["qid"].split("-")[3]) <= 10)
|
87 |
+
elif mode == "bonus":
|
88 |
+
ds = datasets.load_dataset(PLAYGROUND_DATASET_NAMES["bonus"], split="eval")
|
89 |
+
ds = ds.filter(lambda x: x["qid"].split("-")[2] == "1" and int(x["qid"].split("-")[3]) <= 10)
|
90 |
+
else:
|
91 |
+
raise ValueError(f"Invalid mode: {mode}")
|
92 |
+
|
93 |
+
return ds
|
94 |
+
|
95 |
+
|
96 |
+
def main():
|
97 |
+
tossup_ds = load_dataset("tossup")
|
98 |
+
bonus_ds = load_dataset("bonus")
|
99 |
+
app = gr.Blocks(
|
100 |
+
css=css_pipeline + css_tossup,
|
101 |
+
head=js_preamble,
|
102 |
+
theme=THEME,
|
103 |
+
title="Quizbowl Bot",
|
104 |
+
)
|
105 |
+
with app:
|
106 |
+
with gr.Tabs():
|
107 |
+
with gr.Tab("Tossup Agents"):
|
108 |
+
defaults = DEFAULT_SELECTIONS["tossup"] | {
|
109 |
+
"init_workflow": factory.create_quizbowl_simple_workflow(),
|
110 |
+
"simple_workflow": False,
|
111 |
+
}
|
112 |
+
tossup_interface = TossupInterface(app, tossup_ds, AVAILABLE_MODELS, defaults)
|
113 |
+
# ModelStepComponent(value=factory.create_quizbowl_simple_step())
|
114 |
+
with gr.Tab("Bonus Round Agents"):
|
115 |
+
defaults = DEFAULT_SELECTIONS["bonus"] | {
|
116 |
+
"init_workflow": factory.create_quizbowl_bonus_simple_workflow(),
|
117 |
+
"simple_workflow": True,
|
118 |
+
}
|
119 |
+
bonus_interface = BonusInterface(app, bonus_ds, AVAILABLE_MODELS, defaults)
|
120 |
+
|
121 |
+
app.queue(api_open=True).launch()
|
122 |
+
|
123 |
+
|
124 |
+
if __name__ == "__main__":
|
125 |
+
main()
|
pyproject.toml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.ruff]
|
2 |
+
# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
|
3 |
+
select = ["E", "F"]
|
4 |
+
ignore = ["E501"] # line too long (black is taking care of this)
|
5 |
+
line-length = 119
|
6 |
+
fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
|
7 |
+
|
8 |
+
[tool.isort]
|
9 |
+
profile = "black"
|
10 |
+
line_length = 119
|
11 |
+
|
12 |
+
[tool.black]
|
13 |
+
line-length = 119
|
requirements.txt
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
APScheduler
|
2 |
+
black
|
3 |
+
datasets
|
4 |
+
gradio
|
5 |
+
modelscope_studio
|
6 |
+
gradio[oauth]
|
7 |
+
gradio_leaderboard
|
8 |
+
gradio_client
|
9 |
+
huggingface-hub>=0.18.0
|
10 |
+
matplotlib
|
11 |
+
numpy<2.0.0
|
12 |
+
pandas
|
13 |
+
python-dateutil
|
14 |
+
tqdm
|
15 |
+
transformers
|
16 |
+
tokenizers>=0.15.0
|
17 |
+
sentencepiece
|
18 |
+
litellm
|
19 |
+
openai
|
20 |
+
anthropic
|
21 |
+
cohere
|
22 |
+
langchain
|
23 |
+
langchain-core
|
24 |
+
langchain-community
|
25 |
+
langchain-anthropic
|
26 |
+
langchain-openai
|
27 |
+
langchain-cohere
|
src/components/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Components package
|
src/components/model_pipeline/__init__.py
ADDED
File without changes
|
src/components/model_pipeline/model_pipeline.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import yaml
|
5 |
+
|
6 |
+
from components.model_pipeline.state_manager import (
|
7 |
+
ModelStepUIState,
|
8 |
+
PipelineState,
|
9 |
+
PipelineStateManager,
|
10 |
+
PipelineUIState,
|
11 |
+
)
|
12 |
+
from components.model_step.model_step import ModelStepComponent
|
13 |
+
from components.utils import make_state
|
14 |
+
from workflows.structs import ModelStep, Workflow
|
15 |
+
from workflows.validators import WorkflowValidator
|
16 |
+
|
17 |
+
|
18 |
+
def validate_simple_workflow(workflow: Workflow, required_output_variables: list[str]) -> Workflow:
|
19 |
+
"""Validate the workflow."""
|
20 |
+
step = next(iter(workflow.steps.values()))
|
21 |
+
if not step.output_fields:
|
22 |
+
raise ValueError("No output fields found in the workflow")
|
23 |
+
output_field_names = {output.name for output in step.output_fields}
|
24 |
+
if not set(required_output_variables) <= output_field_names:
|
25 |
+
missing_vars = required_output_variables - output_field_names
|
26 |
+
raise ValueError(f"Missing required output variables: {missing_vars}")
|
27 |
+
return workflow
|
28 |
+
|
29 |
+
|
30 |
+
def validate_complex_workflow(workflow: Workflow, required_output_variables: list[str]) -> Workflow:
|
31 |
+
"""Validate the workflow."""
|
32 |
+
print("Validating complex workflow.")
|
33 |
+
return workflow
|
34 |
+
step = next(iter(workflow.steps.values()))
|
35 |
+
if not step.output_fields:
|
36 |
+
raise ValueError("No output fields found in the workflow")
|
37 |
+
output_field_names = {output.name for output in step.output_fields}
|
38 |
+
if not output_field_names <= set(required_output_variables):
|
39 |
+
missing_vars = output_field_names - set(required_output_variables)
|
40 |
+
raise ValueError(f"Missing required output variables: {missing_vars}")
|
41 |
+
return workflow
|
42 |
+
|
43 |
+
|
44 |
+
def parse_yaml_workflow(yaml_str: str) -> Workflow:
|
45 |
+
"""Parse a YAML workflow."""
|
46 |
+
workflow = yaml.safe_load(yaml_str)
|
47 |
+
return Workflow(**workflow)
|
48 |
+
|
49 |
+
|
50 |
+
def update_workflow_from_code(yaml_str: str, ui_state: PipelineUIState) -> PipelineState:
|
51 |
+
"""Update a workflow from a YAML string."""
|
52 |
+
workflow = parse_yaml_workflow(yaml_str)
|
53 |
+
ui_state = PipelineUIState.from_workflow(workflow)
|
54 |
+
return PipelineState(workflow=workflow, ui_state=ui_state)
|
55 |
+
|
56 |
+
|
57 |
+
class PipelineInterface:
|
58 |
+
"""UI for the pipeline."""
|
59 |
+
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
workflow: Workflow,
|
63 |
+
ui_state: PipelineUIState | None = None,
|
64 |
+
model_options: list[str] = None,
|
65 |
+
simple: bool = False,
|
66 |
+
):
|
67 |
+
self.model_options = model_options
|
68 |
+
self.simple = simple
|
69 |
+
if not ui_state:
|
70 |
+
ui_state = PipelineUIState.from_workflow(workflow)
|
71 |
+
self.ui_state = make_state(ui_state)
|
72 |
+
self.pipeline_state = make_state(PipelineState(workflow=workflow, ui_state=ui_state))
|
73 |
+
self.variables_state = make_state(workflow.get_available_variables())
|
74 |
+
|
75 |
+
self.sm = PipelineStateManager()
|
76 |
+
self.input_variables = workflow.inputs
|
77 |
+
self.required_output_variables = list(workflow.outputs.keys())
|
78 |
+
|
79 |
+
# UI elements
|
80 |
+
self.steps_container = None
|
81 |
+
self.components = []
|
82 |
+
|
83 |
+
# Render the pipeline UI
|
84 |
+
self.render()
|
85 |
+
|
86 |
+
def _render_step(
|
87 |
+
self,
|
88 |
+
model_step: ModelStep,
|
89 |
+
step_ui_state: ModelStepUIState,
|
90 |
+
available_variables: list[str],
|
91 |
+
position: int = 0,
|
92 |
+
):
|
93 |
+
with gr.Column(elem_classes="step-container"):
|
94 |
+
# Create the step component
|
95 |
+
step_interface = ModelStepComponent(
|
96 |
+
value=model_step,
|
97 |
+
ui_state=step_ui_state,
|
98 |
+
model_options=self.model_options,
|
99 |
+
input_variables=available_variables,
|
100 |
+
pipeline_state_manager=self.sm,
|
101 |
+
)
|
102 |
+
|
103 |
+
step_interface.on_model_step_change(
|
104 |
+
self.sm.update_model_step_state,
|
105 |
+
inputs=[self.pipeline_state, step_interface.model_step_state, step_interface.ui_state],
|
106 |
+
outputs=[self.pipeline_state, self.ui_state, self.variables_state],
|
107 |
+
)
|
108 |
+
|
109 |
+
step_interface.on_ui_change(
|
110 |
+
self.sm.update_model_step_ui,
|
111 |
+
inputs=[self.pipeline_state, step_interface.ui_state, gr.State(model_step.id)],
|
112 |
+
outputs=[self.pipeline_state, self.ui_state],
|
113 |
+
)
|
114 |
+
|
115 |
+
if self.simple:
|
116 |
+
return step_interface
|
117 |
+
|
118 |
+
# Add step controls below
|
119 |
+
with gr.Row(elem_classes="step-controls"):
|
120 |
+
up_button = gr.Button("⬆️ Move Up", elem_classes="step-control-btn")
|
121 |
+
down_button = gr.Button("⬇️ Move Down", elem_classes="step-control-btn")
|
122 |
+
remove_button = gr.Button("🗑️ Remove", elem_classes="step-control-btn")
|
123 |
+
|
124 |
+
buttons = (up_button, down_button, remove_button)
|
125 |
+
self._assign_step_controls(buttons, position)
|
126 |
+
|
127 |
+
return (step_interface, *buttons)
|
128 |
+
|
129 |
+
def _assign_step_controls(self, buttons: tuple[gr.Button, gr.Button, gr.Button], position: int):
|
130 |
+
up_button, down_button, remove_button = buttons
|
131 |
+
position = gr.State(position)
|
132 |
+
up_button.click(self.sm.move_up, inputs=[self.ui_state, position], outputs=self.ui_state)
|
133 |
+
down_button.click(self.sm.move_down, inputs=[self.ui_state, position], outputs=self.ui_state)
|
134 |
+
remove_button.click(
|
135 |
+
self.sm.remove_step,
|
136 |
+
inputs=[self.pipeline_state, position],
|
137 |
+
outputs=[self.pipeline_state, self.ui_state, self.variables_state],
|
138 |
+
)
|
139 |
+
|
140 |
+
def _render_add_step_button(self, position: int):
|
141 |
+
if position not in {0, -1}:
|
142 |
+
raise ValueError("Position must be 0 or -1")
|
143 |
+
row_class = "pipeline-header" if position == 0 else "pipeline-footer"
|
144 |
+
with gr.Row(elem_classes=row_class):
|
145 |
+
add_step_btn = gr.Button("➕ Add Step", elem_classes="add-step-button")
|
146 |
+
add_step_btn.click(
|
147 |
+
self.sm.add_step,
|
148 |
+
inputs=[self.pipeline_state, gr.State(position)],
|
149 |
+
outputs=[self.pipeline_state, self.ui_state, self.variables_state],
|
150 |
+
)
|
151 |
+
return add_step_btn
|
152 |
+
|
153 |
+
def _render_output_fields(self, available_variables: list[str], pipeline_state: PipelineState):
|
154 |
+
dropdowns = {}
|
155 |
+
UNSET_VALUE = "Choose variable..."
|
156 |
+
variable_options = [UNSET_VALUE] + [v for v in available_variables if v not in self.input_variables]
|
157 |
+
with gr.Column(elem_classes="step-accordion"):
|
158 |
+
with gr.Row(elem_classes="output-fields-header"):
|
159 |
+
gr.Markdown("#### Final output variables mapping:")
|
160 |
+
with gr.Row(elem_classes="output-fields-row"):
|
161 |
+
for output_field in self.required_output_variables:
|
162 |
+
value = pipeline_state.workflow.outputs[output_field]
|
163 |
+
if not value:
|
164 |
+
value = UNSET_VALUE
|
165 |
+
dropdown = gr.Dropdown(
|
166 |
+
label=output_field,
|
167 |
+
value=value,
|
168 |
+
choices=variable_options,
|
169 |
+
interactive=True,
|
170 |
+
elem_classes="output-field-variable",
|
171 |
+
# show_label=False,
|
172 |
+
)
|
173 |
+
dropdown.change(
|
174 |
+
self.sm.update_output_variables,
|
175 |
+
inputs=[self.pipeline_state, gr.State(output_field), dropdown],
|
176 |
+
outputs=[self.pipeline_state],
|
177 |
+
)
|
178 |
+
dropdowns[output_field] = dropdown
|
179 |
+
|
180 |
+
def update_choices(available_variables):
|
181 |
+
"""Update the choices for the dropdowns"""
|
182 |
+
return [
|
183 |
+
gr.update(choices=available_variables, value=None, selected=None) for dropdown in dropdowns.values()
|
184 |
+
]
|
185 |
+
|
186 |
+
self.variables_state.change(
|
187 |
+
update_choices,
|
188 |
+
inputs=[self.variables_state],
|
189 |
+
outputs=list(dropdowns.values()),
|
190 |
+
)
|
191 |
+
return dropdowns
|
192 |
+
|
193 |
+
def validate_workflow(self, state: PipelineState) -> PipelineState:
|
194 |
+
"""Validate the workflow."""
|
195 |
+
try:
|
196 |
+
if self.simple:
|
197 |
+
workflow = validate_simple_workflow(state.workflow, self.required_output_variables)
|
198 |
+
else:
|
199 |
+
workflow = validate_complex_workflow(state.workflow, self.required_output_variables)
|
200 |
+
state.workflow = workflow
|
201 |
+
return state
|
202 |
+
except ValueError as e:
|
203 |
+
raise gr.Error(e)
|
204 |
+
|
205 |
+
def _render_pipeline_header(self):
|
206 |
+
# Add Step button at top
|
207 |
+
input_variables_str = ", ".join([f"`{variable}`" for variable in self.input_variables])
|
208 |
+
output_variables_str = ", ".join([f"`{variable}`" for variable in self.required_output_variables])
|
209 |
+
if self.simple:
|
210 |
+
instruction = "Create a simple single LLM call pipeline that takes in the following input variables and outputs the following output variables:"
|
211 |
+
else:
|
212 |
+
instruction = "Create a pipeline that takes in the following input variables and outputs the following output variables:"
|
213 |
+
gr.Markdown(f"### {instruction}")
|
214 |
+
gr.Markdown(f"Input Variables: {input_variables_str}")
|
215 |
+
gr.Markdown(f"Output Variables: {output_variables_str}")
|
216 |
+
|
217 |
+
# if not self.simple:
|
218 |
+
# self._render_add_step_button(0)
|
219 |
+
|
220 |
+
def render(self):
|
221 |
+
"""Render the pipeline UI."""
|
222 |
+
# Create a placeholder for all the step components
|
223 |
+
self.all_components = []
|
224 |
+
|
225 |
+
# self.pipeline_state.change(
|
226 |
+
# lambda x, y: print(f"Pipeline state changed! UI:\n{x}\n\n Data:\n{y}"),
|
227 |
+
# inputs=[self.ui_state, self.pipeline_state],
|
228 |
+
# outputs=[],
|
229 |
+
# )
|
230 |
+
|
231 |
+
self._render_pipeline_header()
|
232 |
+
|
233 |
+
# Function to render all steps
|
234 |
+
@gr.render(inputs=[self.pipeline_state, self.ui_state])
|
235 |
+
def render_steps(state, ui_state):
|
236 |
+
"""Render all steps in the pipeline"""
|
237 |
+
workflow = state.workflow
|
238 |
+
print(f"\nRerender triggered! Current UI State:{ui_state}")
|
239 |
+
components = []
|
240 |
+
|
241 |
+
step_objects = [] # Reset step objects list
|
242 |
+
for i, step_id in enumerate(ui_state.step_ids):
|
243 |
+
step_data = workflow.steps[step_id]
|
244 |
+
step_ui_state = ui_state.steps[step_id]
|
245 |
+
available_variables = self.sm.get_all_variables(state, step_id)
|
246 |
+
sub_components = self._render_step(step_data, step_ui_state, available_variables, i)
|
247 |
+
step_objects.append(sub_components)
|
248 |
+
|
249 |
+
components.append(step_objects)
|
250 |
+
|
251 |
+
# Bottom buttons
|
252 |
+
if not self.simple:
|
253 |
+
self._render_add_step_button(-1)
|
254 |
+
|
255 |
+
@gr.render(inputs=[self.variables_state, self.pipeline_state])
|
256 |
+
def render_output_fields(available_variables, pipeline_state):
|
257 |
+
return self._render_output_fields(available_variables, pipeline_state)
|
258 |
+
|
259 |
+
export_btn = gr.Button("Export Pipeline", elem_classes="export-button")
|
260 |
+
# components.append(export_btn)
|
261 |
+
|
262 |
+
# Add a code box to display the workflow JSON
|
263 |
+
# with gr.Column(elem_classes="workflow-json-container"):
|
264 |
+
with gr.Accordion("Pipeline Preview", open=False, elem_classes="pipeline-preview") as config_accordion:
|
265 |
+
config_output = gr.Code(
|
266 |
+
label="Workflow Configuration",
|
267 |
+
language="yaml",
|
268 |
+
elem_classes="workflow-json",
|
269 |
+
interactive=True,
|
270 |
+
autocomplete=True,
|
271 |
+
)
|
272 |
+
# components.append(config_accordion)
|
273 |
+
|
274 |
+
config_output.blur(
|
275 |
+
fn=update_workflow_from_code,
|
276 |
+
inputs=[config_output, self.ui_state],
|
277 |
+
outputs=[self.pipeline_state],
|
278 |
+
)
|
279 |
+
|
280 |
+
# Connect the export button to show the workflow JSON
|
281 |
+
export_btn.click(self.validate_workflow, inputs=[self.pipeline_state], outputs=[self.pipeline_state]).success(
|
282 |
+
fn=lambda: gr.update(visible=True, open=True), outputs=[config_accordion]
|
283 |
+
)
|
284 |
+
export_btn.click(
|
285 |
+
fn=self.sm.get_formatted_config,
|
286 |
+
inputs=[self.pipeline_state, gr.State("yaml")],
|
287 |
+
outputs=[config_output],
|
288 |
+
js="() => {document.querySelector('.pipeline-preview').scrollIntoView({behavior: 'smooth'})}",
|
289 |
+
)
|
290 |
+
|
291 |
+
# self.all_components = components
|
src/components/model_pipeline/state_manager.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
from typing import Any, Literal
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import yaml
|
7 |
+
from pydantic import BaseModel, Field
|
8 |
+
|
9 |
+
from components import utils
|
10 |
+
from workflows.factory import create_new_llm_step
|
11 |
+
from workflows.structs import ModelStep, Workflow
|
12 |
+
|
13 |
+
|
14 |
+
def make_step_id(step_id: int):
|
15 |
+
"""Make a step id from a step name."""
|
16 |
+
if step_id < 26:
|
17 |
+
return chr(ord("A") + step_id)
|
18 |
+
else:
|
19 |
+
# For more than 26 steps, use AA, AB, AC, etc.
|
20 |
+
first_char = chr(ord("A") + (step_id // 26) - 1)
|
21 |
+
second_char = chr(ord("A") + (step_id % 26))
|
22 |
+
return f"{first_char}{second_char}"
|
23 |
+
|
24 |
+
|
25 |
+
class ModelStepUIState(BaseModel):
|
26 |
+
"""Represents the UI state for a model step component."""
|
27 |
+
|
28 |
+
expanded: bool = True
|
29 |
+
active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"
|
30 |
+
|
31 |
+
def update(self, key: str, value: Any) -> "ModelStepUIState":
|
32 |
+
"""Update the UI state."""
|
33 |
+
new_state = self.model_copy(update={key: value})
|
34 |
+
logging.warning("UI state updated: %s", self)
|
35 |
+
return new_state
|
36 |
+
|
37 |
+
|
38 |
+
class PipelineUIState(BaseModel):
|
39 |
+
"""Represents the UI state for a pipeline component."""
|
40 |
+
|
41 |
+
step_ids: list[str] = Field(default_factory=list)
|
42 |
+
steps: dict[str, ModelStepUIState] = Field(default_factory=dict)
|
43 |
+
|
44 |
+
def model_post_init(self, __context: utils.Any) -> None:
|
45 |
+
if not self.steps and self.step_ids:
|
46 |
+
self.steps = {step_id: ModelStepUIState() for step_id in self.step_ids}
|
47 |
+
return super().model_post_init(__context)
|
48 |
+
|
49 |
+
def get_step_position(self, step_id: str):
|
50 |
+
"""Get the position of a step in the pipeline."""
|
51 |
+
return next((i for i, step in enumerate(self.step_ids) if step == step_id), None)
|
52 |
+
|
53 |
+
@classmethod
|
54 |
+
def from_workflow(cls, workflow: Workflow):
|
55 |
+
"""Create a pipeline UI state from a workflow."""
|
56 |
+
return PipelineUIState(
|
57 |
+
step_ids=list(workflow.steps.keys()),
|
58 |
+
steps={step_id: ModelStepUIState() for step_id in workflow.steps.keys()},
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
class PipelineState(BaseModel):
|
63 |
+
"""Represents the state for a pipeline component."""
|
64 |
+
|
65 |
+
workflow: Workflow
|
66 |
+
ui_state: PipelineUIState
|
67 |
+
|
68 |
+
def insert_step(self, position: int, step: ModelStep):
|
69 |
+
if step.id in self.workflow.steps:
|
70 |
+
raise ValueError(f"Step {step.id} already exists in pipeline")
|
71 |
+
|
72 |
+
# Validate position
|
73 |
+
if position != -1 and (position < 0 or position > self.n_steps):
|
74 |
+
raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")
|
75 |
+
|
76 |
+
self.workflow.steps[step.id] = step
|
77 |
+
|
78 |
+
self.ui_state = self.ui_state.model_copy()
|
79 |
+
self.ui_state.steps[step.id] = ModelStepUIState()
|
80 |
+
if position == -1:
|
81 |
+
self.ui_state.step_ids.append(step.id)
|
82 |
+
else:
|
83 |
+
self.ui_state.step_ids.insert(position, step.id)
|
84 |
+
return self
|
85 |
+
|
86 |
+
def remove_step(self, position: int):
|
87 |
+
step_id = self.ui_state.step_ids.pop(position)
|
88 |
+
self.workflow.steps.pop(step_id)
|
89 |
+
self.ui_state = self.ui_state.model_copy()
|
90 |
+
self.ui_state.steps.pop(step_id)
|
91 |
+
self.update_output_variables_mapping()
|
92 |
+
|
93 |
+
def update_output_variables_mapping(self):
|
94 |
+
available_variables = set(self.available_variables)
|
95 |
+
for output_field in self.workflow.outputs:
|
96 |
+
if self.workflow.outputs[output_field] not in available_variables:
|
97 |
+
self.workflow.outputs[output_field] = None
|
98 |
+
return self
|
99 |
+
|
100 |
+
@property
|
101 |
+
def available_variables(self):
|
102 |
+
return self.workflow.get_available_variables()
|
103 |
+
|
104 |
+
@property
|
105 |
+
def n_steps(self):
|
106 |
+
return len(self.workflow.steps)
|
107 |
+
|
108 |
+
|
109 |
+
class PipelineStateManager:
|
110 |
+
"""Manages a pipeline of multiple steps."""
|
111 |
+
|
112 |
+
def get_formatted_config(self, state: PipelineState, format: Literal["json", "yaml"] = "yaml"):
|
113 |
+
"""Get the full pipeline configuration."""
|
114 |
+
config = state.workflow.model_dump(exclude_defaults=True)
|
115 |
+
if format == "yaml":
|
116 |
+
return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
|
117 |
+
else:
|
118 |
+
return json.dumps(config, indent=4, sort_keys=False)
|
119 |
+
|
120 |
+
def count_state(self):
|
121 |
+
return gr.State(len(self.steps))
|
122 |
+
|
123 |
+
def add_step(self, state: PipelineState, position: int = -1, name=""):
|
124 |
+
"""Create a new step and return its state."""
|
125 |
+
step_id = make_step_id(state.n_steps)
|
126 |
+
step_name = name or f"Step {state.n_steps + 1}"
|
127 |
+
new_step = create_new_llm_step(step_id=step_id, name=step_name)
|
128 |
+
state = state.insert_step(position, new_step)
|
129 |
+
return state, state.ui_state, state.available_variables
|
130 |
+
|
131 |
+
def remove_step(self, state: PipelineState, position: int):
|
132 |
+
"""Remove a step from the pipeline."""
|
133 |
+
if 0 <= position < state.n_steps:
|
134 |
+
state = state.remove_step(position)
|
135 |
+
else:
|
136 |
+
raise ValueError(f"Invalid step position: {position}")
|
137 |
+
return state, state.ui_state, state.available_variables
|
138 |
+
|
139 |
+
def move_up(self, ui_state: PipelineUIState, position: int):
|
140 |
+
"""Move a step up in the pipeline."""
|
141 |
+
utils.move_item(ui_state.step_ids, position, "up")
|
142 |
+
return ui_state.model_copy()
|
143 |
+
|
144 |
+
def move_down(self, ui_state: PipelineUIState, position: int):
|
145 |
+
"""Move a step down in the pipeline."""
|
146 |
+
utils.move_item(ui_state.step_ids, position, "down")
|
147 |
+
return ui_state.model_copy()
|
148 |
+
|
149 |
+
def update_model_step_state(self, state: PipelineState, model_step: ModelStep, ui_state: ModelStepUIState):
|
150 |
+
"""Update a step in the pipeline."""
|
151 |
+
state.workflow.steps[model_step.id] = model_step.model_copy()
|
152 |
+
state.ui_state.steps[model_step.id] = ui_state.model_copy()
|
153 |
+
state.ui_state = state.ui_state.model_copy()
|
154 |
+
state.update_output_variables_mapping()
|
155 |
+
return state, state.ui_state, state.available_variables
|
156 |
+
|
157 |
+
def update_output_variables(self, state: PipelineState, target: str, produced_variable: str):
|
158 |
+
if produced_variable == "Choose variable...":
|
159 |
+
produced_variable = None
|
160 |
+
"""Update the output variables for a step."""
|
161 |
+
state.workflow.outputs.update({target: produced_variable})
|
162 |
+
return state
|
163 |
+
|
164 |
+
def update_model_step_ui(self, state: PipelineState, step_ui: ModelStepUIState, step_id: str):
|
165 |
+
"""Update a step in the pipeline."""
|
166 |
+
state.ui_state.steps[step_id] = step_ui.model_copy()
|
167 |
+
return state, state.ui_state
|
168 |
+
|
169 |
+
def get_all_variables(self, state: PipelineState, model_step_id: str | None = None) -> list[str]:
|
170 |
+
"""Get all variables from all steps."""
|
171 |
+
available_variables = state.available_variables
|
172 |
+
if model_step_id is None:
|
173 |
+
return available_variables
|
174 |
+
else:
|
175 |
+
prefix = f"{model_step_id}."
|
176 |
+
return [var for var in available_variables if not var.startswith(prefix)]
|
177 |
+
|
178 |
+
def get_pipeline_config(self):
|
179 |
+
"""Get the full pipeline configuration."""
|
180 |
+
return self.workflow
|
src/components/model_step/__init__.py
ADDED
File without changes
|
src/components/model_step/model_step.py
ADDED
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
from gradio.components import FormComponent
|
6 |
+
|
7 |
+
from components.model_pipeline.state_manager import ModelStepUIState, PipelineState, PipelineStateManager
|
8 |
+
from utils import get_full_model_name
|
9 |
+
from workflows.structs import ModelStep
|
10 |
+
|
11 |
+
from .state_manager import ModelStepStateManager
|
12 |
+
from .ui_components import InputRowButtonGroup, OutputRowButtonGroup
|
13 |
+
|
14 |
+
|
15 |
+
def _make_accordion_label(model_step: ModelStep):
|
16 |
+
name = model_step.name if model_step.name else "Untitled"
|
17 |
+
input_field_names = [field.name for field in model_step.input_fields]
|
18 |
+
inputs_str = ", ".join(input_field_names)
|
19 |
+
output_field_names = [field.name for field in model_step.output_fields]
|
20 |
+
outputs_str = ", ".join(output_field_names)
|
21 |
+
return "{}: {} ({}) → ({})".format(model_step.id, name, inputs_str, outputs_str)
|
22 |
+
|
23 |
+
|
24 |
+
class ModelStepComponent(FormComponent):
|
25 |
+
"""
|
26 |
+
A custom Gradio component representing a single Step in a pipeline.
|
27 |
+
It contains:
|
28 |
+
1. Model Provider & System Prompt
|
29 |
+
2. Inputs – fields with name, description, and variable used
|
30 |
+
3. Outputs – fields with name, description, and variable used
|
31 |
+
|
32 |
+
Listens to events:
|
33 |
+
- on_model_step_change
|
34 |
+
- on_ui_change
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
value: ModelStep | gr.State,
|
40 |
+
ui_state: ModelStepUIState | gr.State | None = None,
|
41 |
+
model_options: list[str] | None = None,
|
42 |
+
input_variables: list[str] | None = None,
|
43 |
+
max_input_fields=5,
|
44 |
+
max_output_fields=5,
|
45 |
+
pipeline_state_manager: PipelineStateManager | None = None,
|
46 |
+
**kwargs,
|
47 |
+
):
|
48 |
+
self.max_fields = {
|
49 |
+
"input": max_input_fields,
|
50 |
+
"output": max_output_fields,
|
51 |
+
}
|
52 |
+
self.model_options = model_options
|
53 |
+
self.input_variables = input_variables
|
54 |
+
self.sm = ModelStepStateManager(max_input_fields, max_output_fields)
|
55 |
+
self.pipeline_sm: PipelineStateManager = pipeline_state_manager
|
56 |
+
|
57 |
+
self.model_step_state = gr.State(value)
|
58 |
+
ui_state = ui_state or ModelStepUIState()
|
59 |
+
if not isinstance(ui_state, gr.State):
|
60 |
+
ui_state = gr.State(ui_state)
|
61 |
+
self.ui_state: gr.State = ui_state
|
62 |
+
|
63 |
+
self.inputs_count_state = gr.State(len(value.input_fields))
|
64 |
+
self.outputs_count_state = gr.State(len(value.output_fields))
|
65 |
+
|
66 |
+
# UI components that will be created in render
|
67 |
+
self.accordion = None
|
68 |
+
self.ui = None
|
69 |
+
self.step_name_input = None
|
70 |
+
self.model_selection = None
|
71 |
+
self.system_prompt = None
|
72 |
+
self.input_rows = []
|
73 |
+
self.output_rows = []
|
74 |
+
|
75 |
+
super().__init__(**kwargs)
|
76 |
+
# self.render()
|
77 |
+
self.setup_event_listeners()
|
78 |
+
|
79 |
+
@property
|
80 |
+
def model_step(self) -> ModelStep:
|
81 |
+
return self.model_step_state.value
|
82 |
+
|
83 |
+
@property
|
84 |
+
def step_id(self) -> str:
|
85 |
+
return self.model_step.id
|
86 |
+
|
87 |
+
def get_step_config(self) -> dict:
|
88 |
+
return self.model_step.model_dump()
|
89 |
+
|
90 |
+
# UI state accessors
|
91 |
+
def is_open(self) -> bool:
|
92 |
+
return self.ui_state.value.expanded
|
93 |
+
|
94 |
+
def get_active_tab(self) -> str:
|
95 |
+
"""Get the current active tab."""
|
96 |
+
return self.ui_state.value.active_tab
|
97 |
+
|
98 |
+
def _render_input_row(self, i: int) -> tuple[gr.Row, tuple, tuple]:
|
99 |
+
"""Render a single input row at index i."""
|
100 |
+
inputs = self.model_step.input_fields
|
101 |
+
is_visible = i < len(inputs)
|
102 |
+
label_visible = i == 0
|
103 |
+
initial_name = inputs[i].name if is_visible else ""
|
104 |
+
initial_desc = inputs[i].description if is_visible else ""
|
105 |
+
initial_var = inputs[i].variable if is_visible else "question_text"
|
106 |
+
|
107 |
+
with gr.Row(visible=is_visible, elem_classes="field-row form") as row:
|
108 |
+
button_group = InputRowButtonGroup()
|
109 |
+
|
110 |
+
inp_name = gr.Textbox(
|
111 |
+
label="Input Name",
|
112 |
+
placeholder="Field name",
|
113 |
+
value=initial_name,
|
114 |
+
elem_classes="field-name",
|
115 |
+
scale=1,
|
116 |
+
show_label=label_visible,
|
117 |
+
)
|
118 |
+
|
119 |
+
# Get variable choices safely
|
120 |
+
# variable_choices = []
|
121 |
+
# if self.pipeline_sm is not None:
|
122 |
+
# variable_choices = self.pipeline_sm.get_all_variables(self.step_id)
|
123 |
+
|
124 |
+
inp_var = gr.Dropdown(
|
125 |
+
choices=self.input_variables,
|
126 |
+
label="Variable Used",
|
127 |
+
value=initial_var,
|
128 |
+
elem_classes="field-variable",
|
129 |
+
scale=1,
|
130 |
+
show_label=label_visible,
|
131 |
+
)
|
132 |
+
inp_desc = gr.Textbox(
|
133 |
+
label="Description",
|
134 |
+
placeholder="Field description",
|
135 |
+
value=initial_desc,
|
136 |
+
elem_classes="field-description",
|
137 |
+
scale=3,
|
138 |
+
show_label=label_visible,
|
139 |
+
)
|
140 |
+
fields = (inp_name, inp_var, inp_desc)
|
141 |
+
# buttons = (delete_button, add_button)
|
142 |
+
return row, fields, button_group
|
143 |
+
|
144 |
+
def _render_output_row(self, i: int) -> tuple[gr.Row, tuple, tuple]:
|
145 |
+
"""Render a single output row at index i."""
|
146 |
+
outputs = self.model_step.output_fields
|
147 |
+
is_visible = i < len(outputs)
|
148 |
+
label_visible = i == 0
|
149 |
+
initial_name = outputs[i].name if is_visible else ""
|
150 |
+
initial_desc = outputs[i].description if is_visible else ""
|
151 |
+
initial_type = outputs[i].type if is_visible else "str"
|
152 |
+
with gr.Row(visible=is_visible, elem_classes="field-row") as row:
|
153 |
+
button_group = OutputRowButtonGroup()
|
154 |
+
|
155 |
+
out_name = gr.Textbox(
|
156 |
+
label="Output Field",
|
157 |
+
placeholder="Variable identifier",
|
158 |
+
value=initial_name,
|
159 |
+
elem_classes="field-name",
|
160 |
+
scale=1,
|
161 |
+
show_label=label_visible,
|
162 |
+
)
|
163 |
+
out_type = gr.Dropdown(
|
164 |
+
choices=["str", "int", "float", "bool"],
|
165 |
+
allow_custom_value=True,
|
166 |
+
label="Type",
|
167 |
+
value=initial_type,
|
168 |
+
elem_classes="field-type",
|
169 |
+
scale=0,
|
170 |
+
show_label=label_visible,
|
171 |
+
interactive=True,
|
172 |
+
)
|
173 |
+
out_desc = gr.Textbox(
|
174 |
+
label="Description",
|
175 |
+
placeholder="Field description",
|
176 |
+
value=initial_desc,
|
177 |
+
elem_classes="field-description",
|
178 |
+
scale=3,
|
179 |
+
show_label=label_visible,
|
180 |
+
)
|
181 |
+
|
182 |
+
fields = (out_name, out_type, out_desc)
|
183 |
+
return row, fields, button_group
|
184 |
+
|
185 |
+
def _render_prompt_tab_content(self):
|
186 |
+
self.system_prompt = gr.Textbox(
|
187 |
+
label="System Prompt",
|
188 |
+
placeholder="Enter the system prompt for this step",
|
189 |
+
lines=5,
|
190 |
+
value=self.model_step.system_prompt,
|
191 |
+
elem_classes="system-prompt",
|
192 |
+
)
|
193 |
+
|
194 |
+
def _render_inputs_tab_content(self):
|
195 |
+
with gr.Column(variant="panel", elem_classes="fields-panel") as self.inputs_column:
|
196 |
+
# Render input rows using helper method
|
197 |
+
for i in range(self.max_fields["input"]):
|
198 |
+
row = self._render_input_row(i)
|
199 |
+
self.input_rows.append(row)
|
200 |
+
|
201 |
+
def _render_outputs_tab_content(self):
|
202 |
+
with gr.Column(variant="panel", elem_classes="fields-panel") as self.outputs_column:
|
203 |
+
# Render output rows using helper method
|
204 |
+
for i in range(self.max_fields["output"]):
|
205 |
+
row = self._render_output_row(i)
|
206 |
+
self.output_rows.append(row)
|
207 |
+
|
208 |
+
def _render_tab_content(self, tab_id: str):
|
209 |
+
if tab_id == "model-tab":
|
210 |
+
self._render_prompt_tab_content()
|
211 |
+
elif tab_id == "inputs-tab":
|
212 |
+
self._render_inputs_tab_content()
|
213 |
+
elif tab_id == "outputs-tab":
|
214 |
+
self._render_outputs_tab_content()
|
215 |
+
|
216 |
+
def _render_header(self, model_options: tuple[str]):
|
217 |
+
# Header with step name
|
218 |
+
with gr.Row(elem_classes="step-header-row"):
|
219 |
+
self.step_name_input = gr.Textbox(
|
220 |
+
label="",
|
221 |
+
value=self.model_step.name,
|
222 |
+
elem_classes="step-name",
|
223 |
+
show_label=False,
|
224 |
+
placeholder="Model name...",
|
225 |
+
)
|
226 |
+
unselected_choice = "Select Model..."
|
227 |
+
current_value = (
|
228 |
+
get_full_model_name(self.model_step.model, self.model_step.provider)
|
229 |
+
if self.model_step.model
|
230 |
+
else unselected_choice
|
231 |
+
)
|
232 |
+
self.model_selection = gr.Dropdown(
|
233 |
+
choices=[unselected_choice] + model_options,
|
234 |
+
label="Model Provider",
|
235 |
+
show_label=False,
|
236 |
+
value=current_value,
|
237 |
+
elem_classes="model-dropdown",
|
238 |
+
scale=1,
|
239 |
+
)
|
240 |
+
self.temperature_slider = gr.Slider(
|
241 |
+
value=self.model_step.temperature,
|
242 |
+
minimum=0.0,
|
243 |
+
maximum=5,
|
244 |
+
step=0.05,
|
245 |
+
info="Temperature",
|
246 |
+
show_label=False,
|
247 |
+
)
|
248 |
+
|
249 |
+
def render(self):
|
250 |
+
"""Render the component UI"""
|
251 |
+
# Reset UI component lists
|
252 |
+
self.input_rows = []
|
253 |
+
self.output_rows = []
|
254 |
+
self.tabs = {}
|
255 |
+
|
256 |
+
# Create the accordion for this step
|
257 |
+
accordion_label = _make_accordion_label(self.model_step)
|
258 |
+
self.accordion = gr.Accordion(label=accordion_label, open=self.is_open(), elem_classes="step-accordion")
|
259 |
+
|
260 |
+
# Create the UI content inside the accordion
|
261 |
+
with self.accordion:
|
262 |
+
self._render_header(self.model_options)
|
263 |
+
|
264 |
+
# Configuration tabs
|
265 |
+
selected_tab = self.get_active_tab()
|
266 |
+
with gr.Tabs(elem_classes="step-tabs", selected=selected_tab):
|
267 |
+
tab_ids = ("model-tab", "inputs-tab", "outputs-tab")
|
268 |
+
tab_labels = ("Model", "Inputs", "Outputs")
|
269 |
+
for tab_id, label in zip(tab_ids, tab_labels):
|
270 |
+
with gr.TabItem(label, elem_classes="tab-content", id=tab_id) as tab:
|
271 |
+
self._render_tab_content(tab_id)
|
272 |
+
self.tabs[tab_id] = tab
|
273 |
+
|
274 |
+
return self.accordion
|
275 |
+
|
276 |
+
def _setup_event_listeners_for_view_change(self):
|
277 |
+
for tab_id, tab in self.tabs.items():
|
278 |
+
tab.select(
|
279 |
+
fn=self.sm.update_ui_state,
|
280 |
+
inputs=[self.ui_state, gr.State("active_tab"), gr.State(tab_id)],
|
281 |
+
outputs=[self.ui_state],
|
282 |
+
)
|
283 |
+
self.accordion.collapse(
|
284 |
+
fn=self.sm.update_ui_state,
|
285 |
+
inputs=[self.ui_state, gr.State("expanded"), gr.State(False)],
|
286 |
+
outputs=[self.ui_state],
|
287 |
+
)
|
288 |
+
self.accordion.expand(
|
289 |
+
fn=self.sm.update_ui_state,
|
290 |
+
inputs=[self.ui_state, gr.State("expanded"), gr.State(True)],
|
291 |
+
outputs=[self.ui_state],
|
292 |
+
)
|
293 |
+
|
294 |
+
def _setup_event_listeners_model_tab(self):
|
295 |
+
# Step name change
|
296 |
+
self.step_name_input.blur(
|
297 |
+
fn=self._update_state_and_label,
|
298 |
+
inputs=[self.model_step_state, self.step_name_input],
|
299 |
+
outputs=[self.model_step_state, self.accordion],
|
300 |
+
)
|
301 |
+
|
302 |
+
self.temperature_slider.release(
|
303 |
+
fn=self.sm.update_temperature,
|
304 |
+
inputs=[self.model_step_state, self.temperature_slider],
|
305 |
+
outputs=[self.model_step_state],
|
306 |
+
)
|
307 |
+
|
308 |
+
# Model and system prompt
|
309 |
+
self.model_selection.change(
|
310 |
+
fn=self.sm.update_model_and_provider,
|
311 |
+
inputs=[self.model_step_state, self.model_selection],
|
312 |
+
outputs=[self.model_step_state],
|
313 |
+
)
|
314 |
+
|
315 |
+
self.system_prompt.blur(
|
316 |
+
fn=self.sm.update_system_prompt,
|
317 |
+
inputs=[self.model_step_state, self.system_prompt],
|
318 |
+
outputs=[self.model_step_state],
|
319 |
+
)
|
320 |
+
|
321 |
+
def _setup_event_listeners_inputs_tab(self):
|
322 |
+
# Setup input row events
|
323 |
+
for i, (row, fields, button_group) in enumerate(self.input_rows):
|
324 |
+
inp_name, inp_var, inp_desc = fields
|
325 |
+
row_index = gr.State(i)
|
326 |
+
|
327 |
+
# Field change handlers
|
328 |
+
inp_name.blur(
|
329 |
+
fn=self.sm.update_input_field_name,
|
330 |
+
inputs=[self.model_step_state, inp_name, row_index],
|
331 |
+
outputs=[self.model_step_state],
|
332 |
+
)
|
333 |
+
|
334 |
+
inp_var.change(
|
335 |
+
fn=self.sm.update_input_field_variable,
|
336 |
+
inputs=[self.model_step_state, inp_var, row_index],
|
337 |
+
outputs=[self.model_step_state],
|
338 |
+
)
|
339 |
+
|
340 |
+
inp_desc.blur(
|
341 |
+
fn=self.sm.update_input_field_description,
|
342 |
+
inputs=[self.model_step_state, inp_desc, row_index],
|
343 |
+
outputs=[self.model_step_state],
|
344 |
+
)
|
345 |
+
|
346 |
+
rows = [row for (row, _, _) in self.input_rows]
|
347 |
+
input_fields = [field for (_, fields, _) in self.input_rows for field in fields]
|
348 |
+
|
349 |
+
# Button handlers
|
350 |
+
button_group.delete(
|
351 |
+
fn=self.sm.delete_input_field,
|
352 |
+
inputs=[self.model_step_state, row_index],
|
353 |
+
outputs=[self.model_step_state, self.inputs_count_state] + rows + input_fields,
|
354 |
+
)
|
355 |
+
|
356 |
+
button_group.add(
|
357 |
+
fn=self.sm.add_input_field,
|
358 |
+
inputs=[self.model_step_state, row_index],
|
359 |
+
outputs=[self.model_step_state, self.inputs_count_state] + rows + input_fields,
|
360 |
+
)
|
361 |
+
|
362 |
+
def _setup_event_listeners_outputs_tab(self):
|
363 |
+
# Setup output row events
|
364 |
+
for i, (row, fields, button_group) in enumerate(self.output_rows):
|
365 |
+
out_name, out_type, out_desc = fields
|
366 |
+
|
367 |
+
row_index = gr.State(i)
|
368 |
+
|
369 |
+
# Field change handlers
|
370 |
+
out_name.blur(
|
371 |
+
fn=self.sm.update_output_field_name,
|
372 |
+
inputs=[self.model_step_state, out_name, row_index],
|
373 |
+
outputs=[self.model_step_state],
|
374 |
+
)
|
375 |
+
|
376 |
+
out_type.change(
|
377 |
+
fn=self.sm.update_output_field_type,
|
378 |
+
inputs=[self.model_step_state, out_type, row_index],
|
379 |
+
outputs=[self.model_step_state],
|
380 |
+
)
|
381 |
+
|
382 |
+
out_desc.blur(
|
383 |
+
fn=self.sm.update_output_field_description,
|
384 |
+
inputs=[self.model_step_state, out_desc, row_index],
|
385 |
+
outputs=[self.model_step_state],
|
386 |
+
)
|
387 |
+
|
388 |
+
rows = [row for (row, _, _) in self.output_rows]
|
389 |
+
output_fields = [field for (_, fields, _) in self.output_rows for field in fields]
|
390 |
+
|
391 |
+
# Button handlers
|
392 |
+
button_group.delete(
|
393 |
+
fn=self.sm.delete_output_field,
|
394 |
+
inputs=[self.model_step_state, row_index],
|
395 |
+
outputs=[self.model_step_state, self.outputs_count_state] + rows + output_fields,
|
396 |
+
)
|
397 |
+
|
398 |
+
button_group.add(
|
399 |
+
fn=self.sm.add_output_field,
|
400 |
+
inputs=[self.model_step_state, row_index],
|
401 |
+
outputs=[self.model_step_state, self.outputs_count_state] + rows + output_fields,
|
402 |
+
)
|
403 |
+
|
404 |
+
button_group.up(
|
405 |
+
fn=self.sm.move_output_field,
|
406 |
+
inputs=[self.model_step_state, row_index, gr.State("up")],
|
407 |
+
outputs=[self.model_step_state] + output_fields,
|
408 |
+
)
|
409 |
+
|
410 |
+
button_group.down(
|
411 |
+
fn=self.sm.move_output_field,
|
412 |
+
inputs=[self.model_step_state, row_index, gr.State("down")],
|
413 |
+
outputs=[self.model_step_state] + output_fields,
|
414 |
+
)
|
415 |
+
|
416 |
+
# Function to set up event listeners - call this separately after all components are rendered
|
417 |
+
def setup_event_listeners(self):
|
418 |
+
"""Set up all event listeners for this component"""
|
419 |
+
self._setup_event_listeners_for_view_change()
|
420 |
+
self._setup_event_listeners_model_tab()
|
421 |
+
self._setup_event_listeners_inputs_tab()
|
422 |
+
self._setup_event_listeners_outputs_tab()
|
423 |
+
|
424 |
+
def state_str(x, limited: bool = False):
|
425 |
+
d = x.model_dump()
|
426 |
+
if limited:
|
427 |
+
d = {k: d[k] for k in {"name", "temperature"}}
|
428 |
+
return json.dumps(d, indent=2)
|
429 |
+
|
430 |
+
def log_step_states(x, y, src: str):
|
431 |
+
print(f"{src} triggered! UI:\n{state_str(x)}\n\nData:\n{state_str(y, True)}")
|
432 |
+
print("--------------------------------")
|
433 |
+
print(f"self.model_step_state: \n{self.get_step_config()}")
|
434 |
+
print("--------------------------------")
|
435 |
+
|
436 |
+
# self.model_step_state.change(
|
437 |
+
# log_step_states,
|
438 |
+
# inputs=[self.ui_state, self.model_step_state, gr.State("Model Change")],
|
439 |
+
# )
|
440 |
+
# self.ui_state.change(
|
441 |
+
# log_step_states,
|
442 |
+
# inputs=[self.ui_state, self.model_step_state, gr.State("UI Change")],
|
443 |
+
# )
|
444 |
+
|
445 |
+
def on_model_step_change(self, fn, inputs, outputs):
|
446 |
+
"""Set up an event listener for the model change event."""
|
447 |
+
self.model_step_state.change(fn, inputs, outputs)
|
448 |
+
|
449 |
+
def on_ui_change(self, fn, inputs, outputs):
|
450 |
+
"""Set up an event listener for the UI change event."""
|
451 |
+
self.ui_state.change(fn, inputs, outputs)
|
452 |
+
|
453 |
+
def _update_state_and_label(self, model_step: ModelStep, name: str):
|
454 |
+
"""Update both the state and the accordion label."""
|
455 |
+
new_model_step = self.sm.update_step_name(model_step, name)
|
456 |
+
new_label = _make_accordion_label(new_model_step)
|
457 |
+
return new_model_step, gr.update(label=new_label)
|
458 |
+
|
459 |
+
def refresh_variable_dropdowns(self, pipeline_state: PipelineState):
|
460 |
+
# TODO: Fix this. Not sure why this is needed.
|
461 |
+
"""Refresh the variable dropdown options in all input rows."""
|
462 |
+
variable_choices = []
|
463 |
+
if self.pipeline_sm is not None:
|
464 |
+
variable_choices = self.pipeline_sm.get_all_variables(pipeline_state)
|
465 |
+
|
466 |
+
for _, fields, _ in self.input_rows:
|
467 |
+
_, inp_var, _ = fields
|
468 |
+
inp_var.update(choices=variable_choices)
|
469 |
+
|
470 |
+
def _update_model_and_refresh_ui(self, updated_model_step):
|
471 |
+
"""Update the model step state and refresh UI elements that depend on it."""
|
472 |
+
self.model_step_state.value = updated_model_step
|
473 |
+
# Update accordion label
|
474 |
+
new_label = _make_accordion_label(updated_model_step)
|
475 |
+
if self.accordion:
|
476 |
+
self.accordion.update(label=new_label)
|
477 |
+
return updated_model_step
|
src/components/model_step/state_manager.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Literal, Union
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
from components.model_pipeline.state_manager import ModelStepUIState
|
6 |
+
from components.utils import DIRECTIONS, move_item
|
7 |
+
from utils import get_model_and_provider
|
8 |
+
from workflows.structs import FieldType, ModelStep
|
9 |
+
|
10 |
+
|
11 |
+
class ModelStepStateManager:
|
12 |
+
def __init__(self, max_input_fields: int, max_output_fields: int):
|
13 |
+
self.max_fields = {
|
14 |
+
"input": max_input_fields,
|
15 |
+
"output": max_output_fields,
|
16 |
+
}
|
17 |
+
|
18 |
+
# UI state update functions
|
19 |
+
def update_ui_state(self, ui_state: ModelStepUIState, key: str, value: Any) -> ModelStepUIState:
|
20 |
+
return ui_state.update(key, value)
|
21 |
+
|
22 |
+
# Property update functions
|
23 |
+
def update_step_name(self, model_step: ModelStep, value: str) -> ModelStep:
|
24 |
+
"""Update the step name in state and accordion label."""
|
25 |
+
return model_step.update_property("name", value)
|
26 |
+
|
27 |
+
def update_temperature(self, model_step: ModelStep, value: float) -> ModelStep:
|
28 |
+
return model_step.update_property("temperature", value)
|
29 |
+
|
30 |
+
def update_model_and_provider(self, model_step: ModelStep, value: str) -> ModelStep:
|
31 |
+
"""Update the model provider in the state."""
|
32 |
+
model, provider = get_model_and_provider(value)
|
33 |
+
return model_step.update({"model": model, "provider": provider})
|
34 |
+
|
35 |
+
def update_system_prompt(self, model_step: ModelStep, value: str) -> ModelStep:
|
36 |
+
"""Update the system prompt in the state."""
|
37 |
+
return model_step.update_property("system_prompt", value)
|
38 |
+
|
39 |
+
# Field update functions
|
40 |
+
def update_input_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
|
41 |
+
"""Update a specific field of an input field at the given index."""
|
42 |
+
return model_step.update_field("input", index, "name", value)
|
43 |
+
|
44 |
+
def update_input_field_variable(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
|
45 |
+
"""Update a specific field of an input field at the given index."""
|
46 |
+
return model_step.update_field("input", index, "variable", value)
|
47 |
+
|
48 |
+
def update_input_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
|
49 |
+
"""Update a specific field of an input field at the given index."""
|
50 |
+
return model_step.update_field("input", index, "description", value)
|
51 |
+
|
52 |
+
def update_output_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
|
53 |
+
"""Update a specific field of an output field at the given index."""
|
54 |
+
return model_step.update_field("output", index, "name", value)
|
55 |
+
|
56 |
+
def update_output_field_type(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
|
57 |
+
"""Update a specific field of an output field at the given index."""
|
58 |
+
return model_step.update_field("output", index, "type", value)
|
59 |
+
|
60 |
+
def update_output_field_variable(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
|
61 |
+
"""Update a specific field of an output field at the given index."""
|
62 |
+
return model_step.update_field("output", index, "variable", value)
|
63 |
+
|
64 |
+
def update_output_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
|
65 |
+
"""Update a specific field of an output field at the given index."""
|
66 |
+
return model_step.update_field("output", index, "description", value)
|
67 |
+
|
68 |
+
def make_input_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
|
69 |
+
fields = model_step.input_fields
|
70 |
+
updates = []
|
71 |
+
for i in range(self.max_fields["input"]):
|
72 |
+
if i < len(fields):
|
73 |
+
updates.extend(
|
74 |
+
[
|
75 |
+
gr.update(value=fields[i].name),
|
76 |
+
gr.update(value=fields[i].variable),
|
77 |
+
gr.update(value=fields[i].description),
|
78 |
+
]
|
79 |
+
)
|
80 |
+
else:
|
81 |
+
updates.extend([gr.skip(), gr.skip(), gr.skip()])
|
82 |
+
return updates
|
83 |
+
|
84 |
+
def make_output_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
|
85 |
+
fields = model_step.output_fields
|
86 |
+
updates = []
|
87 |
+
for i in range(self.max_fields["output"]):
|
88 |
+
if i < len(fields):
|
89 |
+
updates.extend(
|
90 |
+
[
|
91 |
+
gr.update(value=fields[i].name),
|
92 |
+
gr.update(value=fields[i].type),
|
93 |
+
gr.update(value=fields[i].description),
|
94 |
+
]
|
95 |
+
)
|
96 |
+
else:
|
97 |
+
updates.extend([gr.skip(), gr.skip(), gr.skip()])
|
98 |
+
return updates
|
99 |
+
|
100 |
+
def _add_field(
|
101 |
+
self, model_step: ModelStep, field_type: FieldType, index: int = -1, input_var: str | None = None
|
102 |
+
) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
|
103 |
+
new_step = model_step.add_field(field_type, index, input_var)
|
104 |
+
fields = new_step.fields(field_type)
|
105 |
+
row_updates = [gr.update(visible=i < len(fields)) for i in range(self.max_fields[field_type])]
|
106 |
+
return new_step, len(fields), *row_updates
|
107 |
+
|
108 |
+
def _delete_field(
|
109 |
+
self, model_step: ModelStep, field_type: FieldType, index: int
|
110 |
+
) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
|
111 |
+
new_step = model_step.delete_field(field_type, index)
|
112 |
+
fields = new_step.fields(field_type)
|
113 |
+
row_updates = [gr.update(visible=i < len(fields)) for i in range(self.max_fields[field_type])]
|
114 |
+
return new_step, len(fields), *row_updates
|
115 |
+
|
116 |
+
# Field add/delete functions
|
117 |
+
def add_input_field(self, model_step: ModelStep, index: int = -1):
|
118 |
+
updates = self._add_field(model_step, "input", index, input_var="question_text")
|
119 |
+
return *updates, *self.make_input_field_updates(model_step)
|
120 |
+
|
121 |
+
def add_output_field(self, model_step: ModelStep, index: int = -1):
|
122 |
+
updates = self._add_field(model_step, "output", index)
|
123 |
+
return *updates, *self.make_output_field_updates(model_step)
|
124 |
+
|
125 |
+
def delete_input_field(self, model_step: ModelStep, index: int):
|
126 |
+
updates = self._delete_field(model_step, "input", index)
|
127 |
+
return *updates, *self.make_input_field_updates(model_step)
|
128 |
+
|
129 |
+
def delete_output_field(self, model_step: ModelStep, index: int):
|
130 |
+
updates = self._delete_field(model_step, "output", index)
|
131 |
+
return *updates, *self.make_output_field_updates(model_step)
|
132 |
+
|
133 |
+
def move_output_field(
|
134 |
+
self, model_step: ModelStep, index: int, direction: DIRECTIONS
|
135 |
+
) -> list[gr.State | dict[str, Any]]:
|
136 |
+
"""
|
137 |
+
Move an output field in the list either up or down.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
index: Index of the output field to move
|
141 |
+
direction: Direction to move the field ('up' or 'down')
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
list: A list containing [updated_state, field_value_updates...]
|
145 |
+
"""
|
146 |
+
new_step = model_step.model_copy()
|
147 |
+
move_item(new_step.output_fields, index, direction)
|
148 |
+
|
149 |
+
# Update all output fields to reflect the new order
|
150 |
+
updates = self.make_output_field_updates(new_step)
|
151 |
+
|
152 |
+
return new_step, *updates
|
src/components/model_step/ui_components.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from gradio.components import FormComponent
|
3 |
+
|
4 |
+
|
5 |
+
class ButtonGroup:
|
6 |
+
"""Base class for button groups with common functionality."""
|
7 |
+
|
8 |
+
def __init__(self, events: list[str], *args, **kwargs):
|
9 |
+
self.buttons = {event: None for event in events}
|
10 |
+
self.click_args = {event: None for event in events}
|
11 |
+
self.render()
|
12 |
+
|
13 |
+
def render(self):
|
14 |
+
"""Render the buttons and set up their event handlers."""
|
15 |
+
for event, button in self.buttons.items():
|
16 |
+
if self.click_args[event]:
|
17 |
+
button.click(*self.click_args[event])
|
18 |
+
|
19 |
+
def _setup_button(self, event, fn, inputs, outputs):
|
20 |
+
"""Set up a button's click event handler."""
|
21 |
+
self.click_args[event] = fn, inputs, outputs
|
22 |
+
if self.buttons[event]:
|
23 |
+
self.buttons[event].click(fn, inputs, outputs)
|
24 |
+
|
25 |
+
def api_info(self):
|
26 |
+
return {
|
27 |
+
"name": self.__class__.__name__,
|
28 |
+
"events": self.EVENTS,
|
29 |
+
"inputs": [],
|
30 |
+
"outputs": [],
|
31 |
+
}
|
32 |
+
|
33 |
+
def example_payload(self):
|
34 |
+
"""Return None since this component doesn't have direct input values."""
|
35 |
+
return None
|
36 |
+
|
37 |
+
def example_value(self):
|
38 |
+
"""Return None since this component doesn't have direct output values."""
|
39 |
+
return None
|
40 |
+
|
41 |
+
|
42 |
+
class InputRowButtonGroup(ButtonGroup):
|
43 |
+
"""Button group for input rows with delete and add buttons."""
|
44 |
+
|
45 |
+
EVENTS = ["delete", "add"]
|
46 |
+
|
47 |
+
def __init__(self, *args, **kwargs):
|
48 |
+
super().__init__(self.EVENTS, *args, **kwargs)
|
49 |
+
|
50 |
+
def render(self):
|
51 |
+
with gr.Column(scale=0, min_width=40, elem_classes="button-column"):
|
52 |
+
self.buttons["delete"] = gr.Button("❌", elem_classes="icon-button delete-button", scale=0)
|
53 |
+
self.buttons["add"] = gr.Button("➕", elem_classes="icon-button add-field-button", scale=0)
|
54 |
+
super().render()
|
55 |
+
|
56 |
+
def delete(self, fn, inputs, outputs):
|
57 |
+
self._setup_button("delete", fn, inputs, outputs)
|
58 |
+
|
59 |
+
def add(self, fn, inputs, outputs):
|
60 |
+
self._setup_button("add", fn, inputs, outputs)
|
61 |
+
|
62 |
+
|
63 |
+
class OutputRowButtonGroup(ButtonGroup):
|
64 |
+
"""Button group for output rows with delete, add, up, and down buttons."""
|
65 |
+
|
66 |
+
EVENTS = ["delete", "add", "up", "down"]
|
67 |
+
|
68 |
+
def __init__(self, *args, **kwargs):
|
69 |
+
super().__init__(self.EVENTS, *args, **kwargs)
|
70 |
+
|
71 |
+
def render(self):
|
72 |
+
with gr.Column(scale=0, elem_classes="button-column", min_width=40):
|
73 |
+
self.buttons["delete"] = gr.Button("❌", elem_classes="icon-button delete-button", scale=0)
|
74 |
+
self.buttons["add"] = gr.Button("➕", elem_classes="icon-button add-field-button", scale=0)
|
75 |
+
|
76 |
+
with gr.Column(scale=0, elem_classes="button-column", min_width=40):
|
77 |
+
self.buttons["up"] = gr.Button("⬆️", elem_classes="icon-button up-button", scale=0)
|
78 |
+
self.buttons["down"] = gr.Button("⬇️", elem_classes="icon-button down-button", scale=0)
|
79 |
+
return super().render()
|
80 |
+
|
81 |
+
def delete(self, fn, inputs, outputs):
|
82 |
+
self._setup_button("delete", fn, inputs, outputs)
|
83 |
+
|
84 |
+
def add(self, fn, inputs, outputs):
|
85 |
+
self._setup_button("add", fn, inputs, outputs)
|
86 |
+
|
87 |
+
def up(self, fn, inputs, outputs):
|
88 |
+
self._setup_button("up", fn, inputs, outputs)
|
89 |
+
|
90 |
+
def down(self, fn, inputs, outputs):
|
91 |
+
self._setup_button("down", fn, inputs, outputs)
|
src/components/quizbowl/__init__.py
ADDED
File without changes
|
src/components/quizbowl/bonus.py
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
from datasets import Dataset
|
10 |
+
|
11 |
+
from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
|
12 |
+
from submission import submit
|
13 |
+
from workflows import factory
|
14 |
+
from workflows.qb.simple_agent import SimpleBonusAgent
|
15 |
+
from workflows.structs import ModelStep, Workflow
|
16 |
+
|
17 |
+
from .plotting import (
|
18 |
+
create_pyplot,
|
19 |
+
create_scatter_pyplot,
|
20 |
+
evaluate_buzz,
|
21 |
+
update_plot,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
def evaluate_bonus_part(prediction: str, clean_answers: list[str]) -> float:
|
26 |
+
"""Evaluate a single bonus part."""
|
27 |
+
return evaluate_buzz(prediction, clean_answers)
|
28 |
+
|
29 |
+
|
30 |
+
def process_bonus_results(results: list[dict]) -> pd.DataFrame:
|
31 |
+
"""Process results from bonus mode and prepare visualization data."""
|
32 |
+
return pd.DataFrame(
|
33 |
+
[
|
34 |
+
{
|
35 |
+
"Part": f"Part {r['part_number']}",
|
36 |
+
"Correct?": "✅" if r["score"] == 1 else "❌",
|
37 |
+
"Confidence": r["confidence"],
|
38 |
+
"Prediction": r["answer"],
|
39 |
+
"Explanation": r["explanation"],
|
40 |
+
}
|
41 |
+
for r in results
|
42 |
+
]
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
def initialize_eval_interface(example: dict, model_outputs: list[dict]):
|
47 |
+
"""Initialize the interface with example text."""
|
48 |
+
try:
|
49 |
+
# Create HTML for leadin and parts
|
50 |
+
leadin_html = f"<div class='leadin'>{example['leadin']}</div>"
|
51 |
+
parts_html = []
|
52 |
+
for i, part in enumerate(example["parts"]):
|
53 |
+
parts_html.append(f"<div class='part'><b>Part {i + 1}:</b> {part['part']}</div>")
|
54 |
+
|
55 |
+
html_content = f"{leadin_html}<div class='parts-container'>{''.join(parts_html)}</div>"
|
56 |
+
|
57 |
+
# Create confidence plot data
|
58 |
+
plot_data = create_bonus_confidence_plot(example["parts"], model_outputs)
|
59 |
+
|
60 |
+
# Store state
|
61 |
+
state = json.dumps({"parts": example["parts"], "outputs": model_outputs})
|
62 |
+
|
63 |
+
return html_content, plot_data, state
|
64 |
+
except Exception as e:
|
65 |
+
logging.error(f"Error initializing interface: {e}", exc_info=True)
|
66 |
+
return f"<div>Error initializing interface: {str(e)}</div>", pd.DataFrame(), "{}"
|
67 |
+
|
68 |
+
|
69 |
+
def create_bonus_confidence_plot(parts: list[dict], model_outputs: list[dict]):
|
70 |
+
"""Create confidence plot for bonus parts."""
|
71 |
+
plt.style.use("ggplot")
|
72 |
+
fig = plt.figure(figsize=(10, 6))
|
73 |
+
ax = fig.add_subplot(111)
|
74 |
+
|
75 |
+
# Plot confidence for each part
|
76 |
+
x = range(1, len(parts) + 1)
|
77 |
+
confidences = [output["confidence"] for output in model_outputs]
|
78 |
+
scores = [output["score"] for output in model_outputs]
|
79 |
+
|
80 |
+
# Plot confidence bars
|
81 |
+
bars = ax.bar(x, confidences, color="#4698cf")
|
82 |
+
|
83 |
+
# Color bars based on correctness
|
84 |
+
for i, score in enumerate(scores):
|
85 |
+
bars[i].set_color("green" if score == 1 else "red")
|
86 |
+
|
87 |
+
ax.set_title("Part Confidence")
|
88 |
+
ax.set_xlabel("Part Number")
|
89 |
+
ax.set_ylabel("Confidence")
|
90 |
+
ax.set_xticks(x)
|
91 |
+
ax.set_xticklabels([f"Part {i}" for i in x])
|
92 |
+
|
93 |
+
return fig
|
94 |
+
|
95 |
+
|
96 |
+
def validate_workflow(workflow: Workflow):
|
97 |
+
"""Validate that a workflow is properly configured for the bonus task."""
|
98 |
+
if not workflow.steps:
|
99 |
+
raise ValueError("Workflow must have at least one step")
|
100 |
+
|
101 |
+
# Ensure all steps are properly configured
|
102 |
+
for step_id, step in workflow.steps.items():
|
103 |
+
validate_model_step(step)
|
104 |
+
|
105 |
+
# Check that the workflow has the correct structure
|
106 |
+
input_vars = set(workflow.inputs)
|
107 |
+
if "leadin" not in input_vars or "part" not in input_vars:
|
108 |
+
raise ValueError("Workflow must have 'leadin' and 'part' as inputs")
|
109 |
+
|
110 |
+
output_vars = set(workflow.outputs)
|
111 |
+
if not all(var in output_vars for var in ["answer", "confidence", "explanation"]):
|
112 |
+
raise ValueError("Workflow must produce 'answer', 'confidence', and 'explanation' as outputs")
|
113 |
+
|
114 |
+
|
115 |
+
def validate_model_step(model_step: ModelStep):
|
116 |
+
"""Validate that a model step is properly configured for the bonus task."""
|
117 |
+
# Check required fields
|
118 |
+
if not model_step.model or not model_step.provider:
|
119 |
+
raise ValueError("Model step must have both model and provider specified")
|
120 |
+
|
121 |
+
if model_step.call_type != "llm":
|
122 |
+
raise ValueError("Model step must have call_type 'llm'")
|
123 |
+
|
124 |
+
# Validate temperature for LLM steps
|
125 |
+
if model_step.temperature is None:
|
126 |
+
raise ValueError("Temperature must be specified for LLM model steps")
|
127 |
+
|
128 |
+
if not (0.0 <= model_step.temperature <= 1.0):
|
129 |
+
raise ValueError(f"Temperature must be between 0.0 and 1.0, got {model_step.temperature}")
|
130 |
+
|
131 |
+
# Validate input fields
|
132 |
+
input_field_names = {field.name for field in model_step.input_fields}
|
133 |
+
if "leadin" not in input_field_names or "part" not in input_field_names:
|
134 |
+
raise ValueError("Model step must have 'leadin' and 'part' input fields")
|
135 |
+
|
136 |
+
# Validate output fields
|
137 |
+
output_field_names = {field.name for field in model_step.output_fields}
|
138 |
+
required_outputs = {"answer", "confidence", "explanation"}
|
139 |
+
if not all(out in output_field_names for out in required_outputs):
|
140 |
+
raise ValueError("Model step must have all required output fields: answer, confidence, explanation")
|
141 |
+
|
142 |
+
# Validate confidence output field is of type float
|
143 |
+
for field in model_step.output_fields:
|
144 |
+
if field.name == "confidence" and field.type != "float":
|
145 |
+
raise ValueError("The 'confidence' output field must be of type 'float'")
|
146 |
+
|
147 |
+
|
148 |
+
class BonusInterface:
|
149 |
+
"""Gradio interface for the Bonus mode."""
|
150 |
+
|
151 |
+
def __init__(self, app: gr.Blocks, dataset: Dataset, model_options: dict, defaults: dict):
|
152 |
+
"""Initialize the Bonus interface."""
|
153 |
+
logging.info(f"Initializing Bonus interface with dataset size: {len(dataset)}")
|
154 |
+
self.ds = dataset
|
155 |
+
self.model_options = model_options
|
156 |
+
self.app = app
|
157 |
+
self.defaults = defaults
|
158 |
+
self.output_state = gr.State(value="{}")
|
159 |
+
self.render()
|
160 |
+
|
161 |
+
def _render_model_interface(self, workflow: Workflow, simple: bool = True):
|
162 |
+
"""Render the model interface."""
|
163 |
+
self.pipeline_interface = PipelineInterface(
|
164 |
+
workflow,
|
165 |
+
simple=simple,
|
166 |
+
model_options=list(self.model_options.keys()),
|
167 |
+
)
|
168 |
+
with gr.Row():
|
169 |
+
self.run_btn = gr.Button("Run Bonus", variant="primary")
|
170 |
+
|
171 |
+
def _render_qb_interface(self):
|
172 |
+
"""Render the quizbowl interface."""
|
173 |
+
with gr.Row():
|
174 |
+
self.qid_selector = gr.Number(
|
175 |
+
label="Question ID", value=1, precision=0, minimum=1, maximum=len(self.ds), show_label=True, scale=0
|
176 |
+
)
|
177 |
+
self.answer_display = gr.Textbox(
|
178 |
+
label="Answers", elem_id="answer-display", elem_classes="answer-box", interactive=False, scale=1
|
179 |
+
)
|
180 |
+
self.clean_answer_display = gr.Textbox(
|
181 |
+
label="Acceptable Answers",
|
182 |
+
elem_id="answer-display-2",
|
183 |
+
elem_classes="answer-box",
|
184 |
+
interactive=False,
|
185 |
+
scale=2,
|
186 |
+
)
|
187 |
+
|
188 |
+
self.question_display = gr.HTML(label="Question", elem_id="question-display")
|
189 |
+
with gr.Row():
|
190 |
+
self.confidence_plot = gr.Plot(
|
191 |
+
label="Part Confidence",
|
192 |
+
format="webp",
|
193 |
+
)
|
194 |
+
|
195 |
+
self.results_table = gr.DataFrame(
|
196 |
+
label="Model Outputs",
|
197 |
+
value=pd.DataFrame(columns=["Part", "Correct?", "Confidence", "Prediction", "Explanation"]),
|
198 |
+
)
|
199 |
+
|
200 |
+
with gr.Row():
|
201 |
+
self.eval_btn = gr.Button("Evaluate")
|
202 |
+
|
203 |
+
with gr.Accordion("Model Submission", elem_classes="model-submission-accordion", open=True):
|
204 |
+
with gr.Row():
|
205 |
+
self.model_name_input = gr.Textbox(label="Model Name")
|
206 |
+
self.description_input = gr.Textbox(label="Description")
|
207 |
+
with gr.Row():
|
208 |
+
gr.LoginButton()
|
209 |
+
self.submit_btn = gr.Button("Submit")
|
210 |
+
self.submit_status = gr.HTML(label="Submission Status")
|
211 |
+
|
212 |
+
def render(self):
|
213 |
+
"""Create the Gradio interface."""
|
214 |
+
self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")
|
215 |
+
workflow = self.defaults["init_workflow"]
|
216 |
+
|
217 |
+
with gr.Row():
|
218 |
+
# Model Panel
|
219 |
+
with gr.Column(scale=1):
|
220 |
+
self._render_model_interface(workflow, simple=self.defaults["simple_workflow"])
|
221 |
+
|
222 |
+
with gr.Column(scale=1):
|
223 |
+
self._render_qb_interface()
|
224 |
+
|
225 |
+
self._setup_event_listeners()
|
226 |
+
|
227 |
+
def get_new_question_html(self, question_id: int):
|
228 |
+
"""Get the HTML for a new question."""
|
229 |
+
example = self.ds[question_id - 1]
|
230 |
+
leadin = example["leadin"]
|
231 |
+
parts = example["parts"]
|
232 |
+
|
233 |
+
# Create HTML for leadin and parts
|
234 |
+
leadin_html = f"<div class='leadin'>{leadin}</div>"
|
235 |
+
parts_html = []
|
236 |
+
for i, part in enumerate(parts):
|
237 |
+
parts_html.append(f"<div class='part'>{part['part']}</div>")
|
238 |
+
|
239 |
+
parts_html_str = "<br>".join(parts_html)
|
240 |
+
|
241 |
+
html_content = (
|
242 |
+
f"<div class='token-container'>{leadin_html}<div class='parts-container'><br>{parts_html_str}</div></div>"
|
243 |
+
)
|
244 |
+
|
245 |
+
# Format answers
|
246 |
+
primary_answers = [f"{i + 1}. {part['answer_primary']}" for i, part in enumerate(parts)]
|
247 |
+
clean_answers = []
|
248 |
+
for i, part in enumerate(parts):
|
249 |
+
part_answers = [a for a in part["clean_answers"] if len(a.split()) <= 6]
|
250 |
+
clean_answers.append(f"{i + 1}. {', '.join(part_answers)}")
|
251 |
+
|
252 |
+
return html_content, "\n".join(primary_answers), "\n".join(clean_answers)
|
253 |
+
|
254 |
+
def get_model_outputs(self, example: dict, pipeline_state: PipelineState):
|
255 |
+
"""Get the model outputs for a given question ID."""
|
256 |
+
outputs = []
|
257 |
+
leadin = example["leadin"]
|
258 |
+
|
259 |
+
for i, part in enumerate(example["parts"]):
|
260 |
+
agent = SimpleBonusAgent(workflow=pipeline_state.workflow)
|
261 |
+
# Run model for each part
|
262 |
+
part_output = agent.run(leadin, part["part"])
|
263 |
+
|
264 |
+
# Add part number and evaluate score
|
265 |
+
part_output["part_number"] = i + 1
|
266 |
+
part_output["score"] = evaluate_bonus_part(part_output["answer"], part["clean_answers"])
|
267 |
+
|
268 |
+
outputs.append(part_output)
|
269 |
+
|
270 |
+
return outputs
|
271 |
+
|
272 |
+
def run_bonus(
|
273 |
+
self,
|
274 |
+
question_id: int,
|
275 |
+
pipeline_state: PipelineState,
|
276 |
+
) -> tuple[str, Any, Any]:
|
277 |
+
"""Run the agent in bonus mode."""
|
278 |
+
try:
|
279 |
+
# Validate inputs
|
280 |
+
question_id = int(question_id - 1)
|
281 |
+
if not self.ds or question_id < 0 or question_id >= len(self.ds):
|
282 |
+
return "Invalid question ID or dataset not loaded", None, None
|
283 |
+
|
284 |
+
example = self.ds[question_id]
|
285 |
+
outputs = self.get_model_outputs(example, pipeline_state)
|
286 |
+
|
287 |
+
# Process results and prepare visualization data
|
288 |
+
html_content, plot_data, output_state = initialize_eval_interface(example, outputs)
|
289 |
+
df = process_bonus_results(outputs)
|
290 |
+
|
291 |
+
return (
|
292 |
+
html_content,
|
293 |
+
gr.update(value=plot_data, label=f"Part Confidence on Question {question_id + 1}"),
|
294 |
+
gr.update(value=output_state),
|
295 |
+
gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}"),
|
296 |
+
)
|
297 |
+
except Exception as e:
|
298 |
+
import traceback
|
299 |
+
|
300 |
+
error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
|
301 |
+
return error_msg, None, None
|
302 |
+
|
303 |
+
def evaluate_bonus(self, pipeline_state: PipelineState, progress: gr.Progress = gr.Progress()):
|
304 |
+
"""Evaluate the bonus questions."""
|
305 |
+
try:
|
306 |
+
# Validate inputs
|
307 |
+
if not self.ds or not self.ds.num_rows:
|
308 |
+
return "No dataset loaded", None, None
|
309 |
+
|
310 |
+
total_correct = 0
|
311 |
+
total_parts = 0
|
312 |
+
part_scores = []
|
313 |
+
part_numbers = []
|
314 |
+
|
315 |
+
for example in progress.tqdm(self.ds, desc="Evaluating bonus questions"):
|
316 |
+
model_outputs = self.get_model_outputs(example, pipeline_state)
|
317 |
+
|
318 |
+
for output in model_outputs:
|
319 |
+
total_parts += 1
|
320 |
+
if output["score"] == 1:
|
321 |
+
total_correct += 1
|
322 |
+
part_scores.append(output["score"])
|
323 |
+
part_numbers.append(output["part_number"])
|
324 |
+
|
325 |
+
accuracy = total_correct / total_parts
|
326 |
+
df = pd.DataFrame(
|
327 |
+
[
|
328 |
+
{
|
329 |
+
"Part Accuracy": f"{accuracy:.2%}",
|
330 |
+
"Total Score": f"{total_correct}/{total_parts}",
|
331 |
+
"Questions Evaluated": len(self.ds),
|
332 |
+
}
|
333 |
+
]
|
334 |
+
)
|
335 |
+
|
336 |
+
plot_data = create_scatter_pyplot(part_numbers, part_scores)
|
337 |
+
return (
|
338 |
+
gr.update(value=df, label="Scores on Sample Set"),
|
339 |
+
gr.update(value=plot_data, label="Part Scores on Sample Set"),
|
340 |
+
)
|
341 |
+
except Exception:
|
342 |
+
import traceback
|
343 |
+
|
344 |
+
logging.error(f"Error evaluating bonus: {traceback.format_exc()}")
|
345 |
+
return "Error evaluating bonus", None, None
|
346 |
+
|
347 |
+
def submit_model(
|
348 |
+
self, model_name: str, description: str, pipeline_state: PipelineState, profile: gr.OAuthProfile = None
|
349 |
+
):
|
350 |
+
"""Submit the model output."""
|
351 |
+
return submit.submit_model(model_name, description, pipeline_state.workflow, "bonus", profile)
|
352 |
+
|
353 |
+
def _setup_event_listeners(self):
|
354 |
+
# Initialize with the default question (ID 0)
|
355 |
+
|
356 |
+
gr.on(
|
357 |
+
triggers=[self.app.load, self.qid_selector.change],
|
358 |
+
fn=self.get_new_question_html,
|
359 |
+
inputs=[self.qid_selector],
|
360 |
+
outputs=[self.question_display, self.answer_display, self.clean_answer_display],
|
361 |
+
)
|
362 |
+
self.run_btn.click(
|
363 |
+
self.pipeline_interface.validate_workflow,
|
364 |
+
inputs=[self.pipeline_interface.pipeline_state],
|
365 |
+
outputs=[self.pipeline_interface.pipeline_state],
|
366 |
+
).success(
|
367 |
+
self.run_bonus,
|
368 |
+
inputs=[
|
369 |
+
self.qid_selector,
|
370 |
+
self.pipeline_interface.pipeline_state,
|
371 |
+
],
|
372 |
+
outputs=[
|
373 |
+
self.question_display,
|
374 |
+
self.confidence_plot,
|
375 |
+
self.output_state,
|
376 |
+
self.results_table,
|
377 |
+
],
|
378 |
+
)
|
379 |
+
|
380 |
+
self.eval_btn.click(
|
381 |
+
fn=self.evaluate_bonus,
|
382 |
+
inputs=[self.pipeline_interface.pipeline_state],
|
383 |
+
outputs=[self.results_table, self.confidence_plot],
|
384 |
+
)
|
385 |
+
|
386 |
+
self.submit_btn.click(
|
387 |
+
fn=self.submit_model_output,
|
388 |
+
inputs=[
|
389 |
+
self.model_name_input,
|
390 |
+
self.description_input,
|
391 |
+
self.pipeline_interface.pipeline_state,
|
392 |
+
],
|
393 |
+
outputs=[self.submit_status],
|
394 |
+
)
|
395 |
+
self.hidden_input.change(
|
396 |
+
fn=update_plot,
|
397 |
+
inputs=[self.hidden_input, self.output_state],
|
398 |
+
outputs=[self.confidence_plot],
|
399 |
+
)
|
src/components/quizbowl/plotting.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import re
|
4 |
+
from collections import Counter
|
5 |
+
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import pandas as pd
|
8 |
+
|
9 |
+
|
10 |
+
def evaluate_buzz(prediction: str, clean_answers: list[str] | str) -> int:
|
11 |
+
"""Evaluate the buzz of a prediction against the clean answers."""
|
12 |
+
if isinstance(clean_answers, str):
|
13 |
+
print("clean_answers is a string")
|
14 |
+
clean_answers = [clean_answers]
|
15 |
+
pred = prediction.lower().strip()
|
16 |
+
if not pred:
|
17 |
+
return 0
|
18 |
+
for answer in clean_answers:
|
19 |
+
answer = answer.strip().lower()
|
20 |
+
if answer and answer in pred:
|
21 |
+
print(f"Found {answer} in {pred}")
|
22 |
+
return 1
|
23 |
+
return 0
|
24 |
+
|
25 |
+
|
26 |
+
def create_answer_html(answer: str):
|
27 |
+
"""Create HTML for the answer."""
|
28 |
+
return f"<div class='answer-header'>Answer:<br>{answer}</div>"
|
29 |
+
|
30 |
+
|
31 |
+
def create_tokens_html(tokens: list[str], eval_points: list[tuple], answer: str, marker_indices: list[int] = None):
|
32 |
+
"""Create HTML for tokens with hover capability and a colored header for the answer."""
|
33 |
+
try:
|
34 |
+
html_parts = []
|
35 |
+
ep = dict(eval_points)
|
36 |
+
marker_indices = set(marker_indices) if isinstance(marker_indices, list) else set()
|
37 |
+
|
38 |
+
# Add a colored header for the answer
|
39 |
+
# html_parts.append(create_answer_html(answer))
|
40 |
+
|
41 |
+
for i, token in enumerate(tokens):
|
42 |
+
# Check if this token is a buzz point
|
43 |
+
values = ep.get(i, (None, 0, 0))
|
44 |
+
confidence, buzz_point, score = values
|
45 |
+
|
46 |
+
# Replace non-word characters for proper display in HTML
|
47 |
+
display_token = token
|
48 |
+
if not re.match(r"\w+", token):
|
49 |
+
display_token = token.replace(" ", " ")
|
50 |
+
|
51 |
+
# Add buzz marker class if it's a buzz point
|
52 |
+
if confidence is None:
|
53 |
+
css_class = ""
|
54 |
+
elif not buzz_point:
|
55 |
+
css_class = " guess-point no-buzz"
|
56 |
+
else:
|
57 |
+
css_class = f" guess-point buzz-{score}"
|
58 |
+
|
59 |
+
token_html = f'<span id="token-{i}" class="token{css_class}" data-index="{i}">{display_token}</span>'
|
60 |
+
if i in marker_indices:
|
61 |
+
token_html += "<span style='color: rgba(0,0,255,0.3);'>|</span>"
|
62 |
+
html_parts.append(token_html)
|
63 |
+
|
64 |
+
return f"<div class='token-container'>{''.join(html_parts)}</div>"
|
65 |
+
except Exception as e:
|
66 |
+
logging.error(f"Error creating token HTML: {e}", exc_info=True)
|
67 |
+
return f"<div class='token-container'>Error creating tokens: {str(e)}</div>"
|
68 |
+
|
69 |
+
|
70 |
+
def create_line_plot(eval_points, highlighted_index=-1):
|
71 |
+
"""Create a Gradio LinePlot of token values with optional highlighting using DataFrame."""
|
72 |
+
try:
|
73 |
+
# Create base confidence data
|
74 |
+
data = []
|
75 |
+
|
76 |
+
# Add buzz points to the plot
|
77 |
+
for i, (v, b) in eval_points:
|
78 |
+
color = "#ff4444" if b == 0 else "#228b22"
|
79 |
+
data.append(
|
80 |
+
{
|
81 |
+
"position": i,
|
82 |
+
"value": v,
|
83 |
+
"type": "buzz",
|
84 |
+
"highlight": True,
|
85 |
+
"color": color,
|
86 |
+
}
|
87 |
+
)
|
88 |
+
|
89 |
+
if highlighted_index >= 0:
|
90 |
+
# Add vertical line for the highlighted token
|
91 |
+
data.extend(
|
92 |
+
[
|
93 |
+
{
|
94 |
+
"position": highlighted_index,
|
95 |
+
"value": 0,
|
96 |
+
"type": "hover-line",
|
97 |
+
"color": "#000000",
|
98 |
+
"highlight": True,
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"position": highlighted_index,
|
102 |
+
"value": 1,
|
103 |
+
"type": "hover-line",
|
104 |
+
"color": "#000000",
|
105 |
+
"highlight": True,
|
106 |
+
},
|
107 |
+
]
|
108 |
+
)
|
109 |
+
|
110 |
+
return pd.DataFrame(data)
|
111 |
+
except Exception as e:
|
112 |
+
logging.error(f"Error creating line plot: {e}", exc_info=True)
|
113 |
+
# Return an empty DataFrame with the expected columns
|
114 |
+
return pd.DataFrame(columns=["position", "value", "type", "highlight", "color"])
|
115 |
+
|
116 |
+
|
117 |
+
def create_pyplot(tokens, eval_points, highlighted_index=-1):
|
118 |
+
"""Create a pyplot of token values with optional highlighting."""
|
119 |
+
plt.style.use("ggplot") # Set theme to grid paper
|
120 |
+
fig = plt.figure(figsize=(10, 6)) # Set figure size
|
121 |
+
ax = fig.add_subplot(111)
|
122 |
+
x = [0]
|
123 |
+
y = [0]
|
124 |
+
for i, (v, b, s) in eval_points:
|
125 |
+
x.append(i + 1)
|
126 |
+
y.append(v)
|
127 |
+
|
128 |
+
ax.plot(x, y, "o--", color="#4698cf")
|
129 |
+
for i, (v, b, s) in eval_points:
|
130 |
+
if not b:
|
131 |
+
continue
|
132 |
+
color = "green" if s else "red"
|
133 |
+
ax.plot(i + 1, v, "o", color=color)
|
134 |
+
if i >= len(tokens):
|
135 |
+
print(f"Token index {i} is out of bounds for n_tokens: {len(tokens)}")
|
136 |
+
ax.annotate(f"{tokens[i]}", (i + 1, v), textcoords="offset points", xytext=(0, 10), ha="center")
|
137 |
+
|
138 |
+
if highlighted_index >= 0:
|
139 |
+
# Add light vertical line for the highlighted token from 0 to 1
|
140 |
+
ax.axvline(x=highlighted_index + 1, color="#ff9900", linestyle="--", ymin=0, ymax=1)
|
141 |
+
|
142 |
+
ax.set_title("Buzz Confidence")
|
143 |
+
ax.set_xlabel("Token Index")
|
144 |
+
ax.set_ylabel("Confidence")
|
145 |
+
ax.set_xticks(x)
|
146 |
+
ax.set_xticklabels(x)
|
147 |
+
return fig
|
148 |
+
|
149 |
+
|
150 |
+
def create_scatter_pyplot(token_positions, scores):
|
151 |
+
"""Create a scatter plot of token positions and scores."""
|
152 |
+
plt.style.use("ggplot")
|
153 |
+
fig = plt.figure(figsize=(10, 6))
|
154 |
+
ax = fig.add_subplot(111)
|
155 |
+
|
156 |
+
counts = Counter(zip(token_positions, scores))
|
157 |
+
X = []
|
158 |
+
Y = []
|
159 |
+
S = []
|
160 |
+
for (pos, score), size in counts.items():
|
161 |
+
X.append(pos)
|
162 |
+
Y.append(score)
|
163 |
+
S.append(size * 20)
|
164 |
+
|
165 |
+
ax.scatter(X, Y, color="#4698cf", s=S)
|
166 |
+
|
167 |
+
return fig
|
168 |
+
|
169 |
+
|
170 |
+
def update_plot(highlighted_index, state):
|
171 |
+
"""Update the plot when a token is hovered; add a vertical line on the plot."""
|
172 |
+
try:
|
173 |
+
if not state or state == "{}":
|
174 |
+
logging.warning("Empty state provided to update_plot")
|
175 |
+
return pd.DataFrame()
|
176 |
+
|
177 |
+
highlighted_index = int(highlighted_index) if highlighted_index else None
|
178 |
+
logging.info(f"Update plot triggered with token index: {highlighted_index}")
|
179 |
+
|
180 |
+
data = json.loads(state)
|
181 |
+
tokens = data.get("tokens", [])
|
182 |
+
values = data.get("values", [])
|
183 |
+
|
184 |
+
if not tokens or not values:
|
185 |
+
logging.warning("No tokens or values found in state")
|
186 |
+
return pd.DataFrame()
|
187 |
+
|
188 |
+
# Create updated plot with highlighting of the token point
|
189 |
+
# plot_data = create_line_plot(values, highlighted_index)
|
190 |
+
plot_data = create_pyplot(tokens, values, highlighted_index)
|
191 |
+
return plot_data
|
192 |
+
except Exception as e:
|
193 |
+
logging.error(f"Error updating plot: {e}")
|
194 |
+
return pd.DataFrame()
|
src/components/quizbowl/tossup.py
ADDED
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
from datasets import Dataset
|
9 |
+
|
10 |
+
from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
|
11 |
+
from submission import submit
|
12 |
+
from workflows.qb.simple_agent import SimpleTossupAgent
|
13 |
+
from workflows.structs import ModelStep, Workflow
|
14 |
+
|
15 |
+
from .plotting import (
|
16 |
+
create_answer_html,
|
17 |
+
create_pyplot,
|
18 |
+
create_scatter_pyplot,
|
19 |
+
create_tokens_html,
|
20 |
+
evaluate_buzz,
|
21 |
+
update_plot,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
def add_model_scores(model_outputs: list[dict], clean_answers: list[str], run_indices: list[int]) -> list[dict]:
|
26 |
+
"""Add model scores to the model outputs."""
|
27 |
+
for output, run_idx in zip(model_outputs, run_indices):
|
28 |
+
output["score"] = evaluate_buzz(output["answer"], clean_answers)
|
29 |
+
output["token_position"] = run_idx + 1
|
30 |
+
return model_outputs
|
31 |
+
|
32 |
+
|
33 |
+
def prepare_buzz_evals(
|
34 |
+
run_indices: list[int], model_outputs: list[dict]
|
35 |
+
) -> tuple[list[str], list[tuple[int, float, bool]]]:
|
36 |
+
"""Process text into tokens and assign random values for demonstration."""
|
37 |
+
if not run_indices:
|
38 |
+
logging.warning("No run indices provided, returning empty results")
|
39 |
+
return [], []
|
40 |
+
eval_points = []
|
41 |
+
for i, v in zip(run_indices, model_outputs):
|
42 |
+
eval_point = v["confidence"], v["buzz"], v["score"]
|
43 |
+
eval_points.append((int(i), eval_point))
|
44 |
+
|
45 |
+
return eval_points
|
46 |
+
|
47 |
+
|
48 |
+
def initialize_eval_interface(example, model_outputs: list[dict]):
|
49 |
+
"""Initialize the interface with example text."""
|
50 |
+
tokens = example["question"].split()
|
51 |
+
run_indices = example["run_indices"]
|
52 |
+
answer = example["answer_primary"]
|
53 |
+
|
54 |
+
try:
|
55 |
+
eval_points = prepare_buzz_evals(run_indices, model_outputs)
|
56 |
+
|
57 |
+
if not tokens:
|
58 |
+
return "<div>No tokens found in the provided text.</div>", pd.DataFrame(), "{}"
|
59 |
+
highlighted_index = next((int(i) for i, (_, b, _) in eval_points if b == 1), -1)
|
60 |
+
html_content = create_tokens_html(tokens, eval_points, answer)
|
61 |
+
plot_data = create_pyplot(tokens, eval_points, highlighted_index)
|
62 |
+
|
63 |
+
# Store tokens, values, and buzzes as JSON for later use
|
64 |
+
state = json.dumps({"tokens": tokens, "values": eval_points})
|
65 |
+
|
66 |
+
return html_content, plot_data, state
|
67 |
+
except Exception as e:
|
68 |
+
logging.error(f"Error initializing interface: {e}", exc_info=True)
|
69 |
+
return f"<div>Error initializing interface: {str(e)}</div>", pd.DataFrame(), "{}"
|
70 |
+
|
71 |
+
|
72 |
+
def process_tossup_results(results: list[dict], top_k_mode: bool = False) -> pd.DataFrame:
|
73 |
+
"""Process results from tossup mode and prepare visualization data."""
|
74 |
+
# Create DataFrame for detailed results
|
75 |
+
if top_k_mode:
|
76 |
+
raise ValueError("Top-k mode not supported for tossup mode")
|
77 |
+
return pd.DataFrame(
|
78 |
+
[
|
79 |
+
{
|
80 |
+
"Token Position": r["token_position"],
|
81 |
+
"Correct?": "✅" if r["score"] == 1 else "❌",
|
82 |
+
"Confidence": r["confidence"],
|
83 |
+
"Prediction": r["answer"],
|
84 |
+
}
|
85 |
+
for r in results
|
86 |
+
]
|
87 |
+
)
|
88 |
+
|
89 |
+
|
90 |
+
def validate_workflow(workflow: Workflow):
|
91 |
+
"""
|
92 |
+
Validate that a workflow is properly configured for the tossup task.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
workflow (Workflow): The workflow to validate
|
96 |
+
|
97 |
+
Raises:
|
98 |
+
ValueError: If the workflow is not properly configured
|
99 |
+
"""
|
100 |
+
if not workflow.steps:
|
101 |
+
raise ValueError("Workflow must have at least one step")
|
102 |
+
|
103 |
+
# Ensure all steps are properly configured
|
104 |
+
for step_id, step in workflow.steps.items():
|
105 |
+
validate_model_step(step)
|
106 |
+
|
107 |
+
# Check that the workflow has the correct structure
|
108 |
+
input_vars = set(workflow.inputs)
|
109 |
+
if "question" not in input_vars:
|
110 |
+
raise ValueError("Workflow must have 'question' as an input")
|
111 |
+
|
112 |
+
output_vars = set(workflow.outputs)
|
113 |
+
if not any("answer" in out_var for out_var in output_vars):
|
114 |
+
raise ValueError("Workflow must produce an 'answer' as output")
|
115 |
+
if not any("confidence" in out_var for out_var in output_vars):
|
116 |
+
raise ValueError("Workflow must produce a 'confidence' score as output")
|
117 |
+
|
118 |
+
|
119 |
+
def validate_model_step(model_step: ModelStep):
|
120 |
+
"""
|
121 |
+
Validate that a model step is properly configured for the tossup task.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
model_step (ModelStep): The model step to validate
|
125 |
+
|
126 |
+
Raises:
|
127 |
+
ValueError: If the model step is not properly configured
|
128 |
+
"""
|
129 |
+
# Check required fields
|
130 |
+
if not model_step.model or not model_step.provider:
|
131 |
+
raise ValueError("Model step must have both model and provider specified")
|
132 |
+
|
133 |
+
if model_step.call_type != "llm":
|
134 |
+
raise ValueError("Model step must have call_type 'llm'")
|
135 |
+
|
136 |
+
# Validate temperature for LLM steps
|
137 |
+
if model_step.temperature is None:
|
138 |
+
raise ValueError("Temperature must be specified for LLM model steps")
|
139 |
+
|
140 |
+
if not (0.0 <= model_step.temperature <= 1.0):
|
141 |
+
raise ValueError(f"Temperature must be between 0.0 and 1.0, got {model_step.temperature}")
|
142 |
+
|
143 |
+
# Validate input fields
|
144 |
+
input_field_names = {field.name for field in model_step.input_fields}
|
145 |
+
if "question" not in input_field_names:
|
146 |
+
raise ValueError("Model step must have a 'question' input field")
|
147 |
+
|
148 |
+
# Validate output fields
|
149 |
+
output_field_names = {field.name for field in model_step.output_fields}
|
150 |
+
if "answer" not in output_field_names:
|
151 |
+
raise ValueError("Model step must have an 'answer' output field")
|
152 |
+
if "confidence" not in output_field_names:
|
153 |
+
raise ValueError("Model step must have a 'confidence' output field")
|
154 |
+
|
155 |
+
# Validate confidence output field is of type float
|
156 |
+
for field in model_step.output_fields:
|
157 |
+
if field.name == "confidence" and field.type != "float":
|
158 |
+
raise ValueError("The 'confidence' output field must be of type 'float'")
|
159 |
+
|
160 |
+
|
161 |
+
class TossupInterface:
|
162 |
+
"""Gradio interface for the Tossup mode."""
|
163 |
+
|
164 |
+
def __init__(self, app: gr.Blocks, dataset: Dataset, model_options: dict, defaults: dict):
|
165 |
+
"""Initialize the Tossup interface."""
|
166 |
+
logging.info(f"Initializing Tossup interface with dataset size: {len(dataset)}")
|
167 |
+
self.ds = dataset
|
168 |
+
self.model_options = model_options
|
169 |
+
self.app = app
|
170 |
+
self.defaults = defaults
|
171 |
+
self.output_state = gr.State(value="{}")
|
172 |
+
self.render()
|
173 |
+
|
174 |
+
def _render_model_interface(self, workflow: Workflow, simple: bool = True):
|
175 |
+
"""Render the model interface."""
|
176 |
+
self.pipeline_interface = PipelineInterface(
|
177 |
+
workflow,
|
178 |
+
simple=simple,
|
179 |
+
model_options=list(self.model_options.keys()),
|
180 |
+
)
|
181 |
+
with gr.Row():
|
182 |
+
self.buzz_t_slider = gr.Slider(
|
183 |
+
minimum=0.5,
|
184 |
+
maximum=1.0,
|
185 |
+
value=self.defaults["buzz_threshold"],
|
186 |
+
step=0.01,
|
187 |
+
label="Buzz Threshold",
|
188 |
+
)
|
189 |
+
self.early_stop_checkbox = gr.Checkbox(
|
190 |
+
value=self.defaults["early_stop"],
|
191 |
+
label="Early Stop",
|
192 |
+
info="Stop early if already buzzed",
|
193 |
+
)
|
194 |
+
self.run_btn = gr.Button("Run Tossup", variant="primary")
|
195 |
+
|
196 |
+
def _render_qb_interface(self):
|
197 |
+
"""Render the quizbowl interface."""
|
198 |
+
with gr.Row():
|
199 |
+
self.qid_selector = gr.Number(
|
200 |
+
label="Question ID", value=1, precision=0, minimum=1, maximum=len(self.ds), show_label=True, scale=0
|
201 |
+
)
|
202 |
+
self.answer_display = gr.Textbox(
|
203 |
+
label="PrimaryAnswer", elem_id="answer-display", elem_classes="answer-box", interactive=False, scale=1
|
204 |
+
)
|
205 |
+
self.clean_answer_display = gr.Textbox(
|
206 |
+
label="Acceptable Answers",
|
207 |
+
elem_id="answer-display-2",
|
208 |
+
elem_classes="answer-box",
|
209 |
+
interactive=False,
|
210 |
+
scale=2,
|
211 |
+
)
|
212 |
+
# self.answer_display = gr.HTML(label="Answer", elem_id="answer-display")
|
213 |
+
self.question_display = gr.HTML(label="Question", elem_id="question-display")
|
214 |
+
with gr.Row():
|
215 |
+
self.confidence_plot = gr.Plot(
|
216 |
+
label="Buzz Confidence",
|
217 |
+
format="webp",
|
218 |
+
)
|
219 |
+
self.results_table = gr.DataFrame(
|
220 |
+
label="Model Outputs",
|
221 |
+
value=pd.DataFrame(columns=["Token Position", "Correct?", "Confidence", "Prediction"]),
|
222 |
+
)
|
223 |
+
with gr.Row():
|
224 |
+
self.eval_btn = gr.Button("Evaluate")
|
225 |
+
|
226 |
+
with gr.Accordion("Model Submission", elem_classes="model-submission-accordion", open=True):
|
227 |
+
with gr.Row():
|
228 |
+
self.model_name_input = gr.Textbox(label="Model Name")
|
229 |
+
self.description_input = gr.Textbox(label="Description")
|
230 |
+
with gr.Row():
|
231 |
+
gr.LoginButton()
|
232 |
+
self.submit_btn = gr.Button("Submit")
|
233 |
+
self.submit_status = gr.HTML(label="Submission Status")
|
234 |
+
|
235 |
+
def render(self):
|
236 |
+
"""Create the Gradio interface."""
|
237 |
+
|
238 |
+
self.hidden_input = gr.Textbox(value="", visible=False, elem_id="hidden-index")
|
239 |
+
|
240 |
+
workflow = self.defaults["init_workflow"]
|
241 |
+
|
242 |
+
with gr.Row():
|
243 |
+
# Model Panel
|
244 |
+
with gr.Column(scale=1):
|
245 |
+
self._render_model_interface(workflow, simple=self.defaults["simple_workflow"])
|
246 |
+
|
247 |
+
with gr.Column(scale=1):
|
248 |
+
self._render_qb_interface()
|
249 |
+
|
250 |
+
self._setup_event_listeners()
|
251 |
+
|
252 |
+
def get_full_question(self, question_id: int) -> str:
|
253 |
+
"""Get the full question text for a given question ID."""
|
254 |
+
try:
|
255 |
+
question_id = int(question_id - 1)
|
256 |
+
if not self.ds or question_id < 0 or question_id >= len(self.ds):
|
257 |
+
return "Invalid question ID or dataset not loaded"
|
258 |
+
|
259 |
+
question_data = self.ds[question_id]
|
260 |
+
# Get the full question text (the last element in question_runs)
|
261 |
+
full_question = question_data["question"]
|
262 |
+
gold_label = question_data["answer_primary"]
|
263 |
+
|
264 |
+
return f"Question: {full_question}\n\nCorrect Answer: {gold_label}"
|
265 |
+
except Exception as e:
|
266 |
+
return f"Error loading question: {str(e)}"
|
267 |
+
|
268 |
+
def validate_workflow(self, pipeline_state: PipelineState):
|
269 |
+
"""Validate the workflow."""
|
270 |
+
try:
|
271 |
+
validate_workflow(pipeline_state.workflow)
|
272 |
+
except Exception as e:
|
273 |
+
raise gr.Error(f"Error validating workflow: {str(e)}")
|
274 |
+
|
275 |
+
def get_new_question_html(self, question_id: int):
|
276 |
+
"""Get the HTML for a new question."""
|
277 |
+
example = self.ds[question_id - 1]
|
278 |
+
question = example["question"]
|
279 |
+
gold_label = example["answer_primary"]
|
280 |
+
marker_indices = example["run_indices"]
|
281 |
+
tokens = question.split()
|
282 |
+
question_html = create_tokens_html(tokens, [], gold_label, marker_indices)
|
283 |
+
clean_answers = [a for a in example["clean_answers"] if len(a.split()) <= 6]
|
284 |
+
clean_answers = ", ".join(clean_answers)
|
285 |
+
return question_html, gold_label, clean_answers
|
286 |
+
|
287 |
+
def get_model_outputs(self, example: dict, pipeline_state: PipelineState, buzz_threshold: float, early_stop: bool):
|
288 |
+
"""Get the model outputs for a given question ID."""
|
289 |
+
question_runs = []
|
290 |
+
tokens = example["question"].split()
|
291 |
+
for run_idx in example["run_indices"]:
|
292 |
+
question_runs.append(" ".join(tokens[: run_idx + 1]))
|
293 |
+
|
294 |
+
agent = SimpleTossupAgent(workflow=pipeline_state.workflow, buzz_threshold=buzz_threshold)
|
295 |
+
outputs = list(agent.run(question_runs, early_stop=early_stop))
|
296 |
+
outputs = add_model_scores(outputs, example["clean_answers"], example["run_indices"])
|
297 |
+
return outputs
|
298 |
+
|
299 |
+
def run_tossup(
|
300 |
+
self,
|
301 |
+
question_id: int,
|
302 |
+
pipeline_state: PipelineState,
|
303 |
+
buzz_threshold: float,
|
304 |
+
early_stop: bool = True,
|
305 |
+
) -> tuple[str, Any, Any]:
|
306 |
+
"""Run the agent in tossup mode with a system prompt."""
|
307 |
+
try:
|
308 |
+
# Validate inputs
|
309 |
+
question_id = int(question_id - 1)
|
310 |
+
if not self.ds or question_id < 0 or question_id >= len(self.ds):
|
311 |
+
return "Invalid question ID or dataset not loaded", None, None
|
312 |
+
example = self.ds[question_id]
|
313 |
+
outputs = self.get_model_outputs(example, pipeline_state, buzz_threshold, early_stop)
|
314 |
+
|
315 |
+
# Process results and prepare visualization data
|
316 |
+
tokens_html, plot_data, output_state = initialize_eval_interface(example, outputs)
|
317 |
+
df = process_tossup_results(outputs)
|
318 |
+
return (
|
319 |
+
tokens_html,
|
320 |
+
gr.update(value=plot_data, label=f"Buzz Confidence on Question {question_id + 1}"),
|
321 |
+
gr.update(value=output_state),
|
322 |
+
gr.update(value=df, label=f"Model Outputs for Question {question_id + 1}"),
|
323 |
+
)
|
324 |
+
except Exception as e:
|
325 |
+
import traceback
|
326 |
+
|
327 |
+
error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
|
328 |
+
return error_msg, None, None
|
329 |
+
|
330 |
+
def evaluate_tossups(
|
331 |
+
self, pipeline_state: PipelineState, buzz_threshold: float, progress: gr.Progress = gr.Progress()
|
332 |
+
):
|
333 |
+
"""Evaluate the tossup."""
|
334 |
+
try:
|
335 |
+
# Validate inputs
|
336 |
+
if not self.ds or not self.ds.num_rows:
|
337 |
+
return "No dataset loaded", None, None
|
338 |
+
|
339 |
+
buzz_counts = 0
|
340 |
+
correct_buzzes = 0
|
341 |
+
token_positions = []
|
342 |
+
correctness = []
|
343 |
+
for example in progress.tqdm(self.ds, desc="Evaluating tossup questions"):
|
344 |
+
model_outputs = self.get_model_outputs(example, pipeline_state, buzz_threshold, early_stop=True)
|
345 |
+
if model_outputs[-1]["buzz"]:
|
346 |
+
buzz_counts += 1
|
347 |
+
if model_outputs[-1]["score"] == 1:
|
348 |
+
correct_buzzes += 1
|
349 |
+
token_positions.append(model_outputs[-1]["token_position"])
|
350 |
+
correctness.append(model_outputs[-1]["score"])
|
351 |
+
buzz_accuracy = correct_buzzes / buzz_counts
|
352 |
+
df = pd.DataFrame(
|
353 |
+
[
|
354 |
+
{
|
355 |
+
"Avg Buzz Position": f"{np.mean(token_positions):.2f}",
|
356 |
+
"Buzz Accuracy": f"{buzz_accuracy:.2%}",
|
357 |
+
"Total Score": f"{correct_buzzes}/{len(self.ds)}",
|
358 |
+
}
|
359 |
+
]
|
360 |
+
)
|
361 |
+
plot_data = create_scatter_pyplot(token_positions, correctness)
|
362 |
+
return (
|
363 |
+
gr.update(value=df, label="Scores on Sample Set"),
|
364 |
+
gr.update(value=plot_data, label="Buzz Positions on Sample Set"),
|
365 |
+
)
|
366 |
+
except Exception:
|
367 |
+
import traceback
|
368 |
+
|
369 |
+
logging.error(f"Error evaluating tossups: {traceback.format_exc()}")
|
370 |
+
return "Error evaluating tossups", None, None
|
371 |
+
|
372 |
+
def submit_model(
|
373 |
+
self, model_name: str, description: str, pipeline_state: PipelineState, profile: gr.OAuthProfile = None
|
374 |
+
):
|
375 |
+
"""Submit the model output."""
|
376 |
+
return submit.submit_model(model_name, description, pipeline_state.workflow, "tossup", profile)
|
377 |
+
|
378 |
+
def _setup_event_listeners(self):
|
379 |
+
gr.on(
|
380 |
+
triggers=[self.app.load, self.qid_selector.change],
|
381 |
+
fn=self.get_new_question_html,
|
382 |
+
inputs=[self.qid_selector],
|
383 |
+
outputs=[self.question_display, self.answer_display, self.clean_answer_display],
|
384 |
+
)
|
385 |
+
|
386 |
+
self.run_btn.click(
|
387 |
+
self.pipeline_interface.validate_workflow,
|
388 |
+
inputs=[self.pipeline_interface.pipeline_state],
|
389 |
+
outputs=[self.pipeline_interface.pipeline_state],
|
390 |
+
).success(
|
391 |
+
self.run_tossup,
|
392 |
+
inputs=[
|
393 |
+
self.qid_selector,
|
394 |
+
self.pipeline_interface.pipeline_state,
|
395 |
+
self.buzz_t_slider,
|
396 |
+
self.early_stop_checkbox,
|
397 |
+
],
|
398 |
+
outputs=[
|
399 |
+
self.question_display,
|
400 |
+
self.confidence_plot,
|
401 |
+
self.output_state,
|
402 |
+
self.results_table,
|
403 |
+
],
|
404 |
+
)
|
405 |
+
|
406 |
+
self.eval_btn.click(
|
407 |
+
fn=self.evaluate_tossups,
|
408 |
+
inputs=[self.pipeline_interface.pipeline_state, self.buzz_t_slider],
|
409 |
+
outputs=[self.results_table, self.confidence_plot],
|
410 |
+
)
|
411 |
+
|
412 |
+
self.submit_btn.click(
|
413 |
+
fn=self.submit_model,
|
414 |
+
inputs=[
|
415 |
+
self.model_name_input,
|
416 |
+
self.description_input,
|
417 |
+
self.pipeline_interface.pipeline_state,
|
418 |
+
],
|
419 |
+
outputs=[self.submit_status],
|
420 |
+
)
|
421 |
+
|
422 |
+
self.hidden_input.change(
|
423 |
+
fn=update_plot,
|
424 |
+
inputs=[self.hidden_input, self.output_state],
|
425 |
+
outputs=[self.confidence_plot],
|
426 |
+
)
|
src/components/quizbowl/utils.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
|
6 |
+
def _create_confidence_plot_data(results: List[Dict], top_k_mode: bool = False) -> pd.DataFrame:
|
7 |
+
"""Create a DataFrame for the confidence plot."""
|
8 |
+
if not top_k_mode:
|
9 |
+
return pd.DataFrame(
|
10 |
+
{
|
11 |
+
"position": [r["position"] for r in results],
|
12 |
+
"confidence": [r["confidence"] for r in results],
|
13 |
+
"answer": [r["answer"] for r in results],
|
14 |
+
}
|
15 |
+
)
|
16 |
+
|
17 |
+
# For top-k mode, extract and plot top answers
|
18 |
+
return _create_top_k_plot_data(results)
|
19 |
+
|
20 |
+
|
21 |
+
def _create_top_k_plot_data(results: List[Dict]) -> pd.DataFrame:
|
22 |
+
"""Create plot data for top-k mode."""
|
23 |
+
# Find top answers across all positions (limited to top 5)
|
24 |
+
top_answers = set()
|
25 |
+
for r in results:
|
26 |
+
for g in r.get("guesses", [])[:3]: # Get top 3 from each position
|
27 |
+
if g.get("answer"):
|
28 |
+
top_answers.add(g.get("answer"))
|
29 |
+
|
30 |
+
top_answers = list(top_answers)[:5] # Limit to 5 total answers
|
31 |
+
|
32 |
+
# Create plot data for each answer
|
33 |
+
all_data = []
|
34 |
+
for position_idx, result in enumerate(results):
|
35 |
+
position = result["position"]
|
36 |
+
for answer in top_answers:
|
37 |
+
confidence = 0
|
38 |
+
for guess in result.get("guesses", []):
|
39 |
+
if guess.get("answer") == answer:
|
40 |
+
confidence = guess.get("confidence", 0)
|
41 |
+
break
|
42 |
+
all_data.append({"position": position, "confidence": confidence, "answer": answer})
|
43 |
+
|
44 |
+
return pd.DataFrame(all_data)
|
45 |
+
|
46 |
+
|
47 |
+
def _create_top_k_dataframe(results: List[Dict]) -> pd.DataFrame:
|
48 |
+
"""Create a DataFrame for top-k results."""
|
49 |
+
df_rows = []
|
50 |
+
for result in results:
|
51 |
+
position = result["position"]
|
52 |
+
for i, guess in enumerate(result.get("guesses", [])):
|
53 |
+
df_rows.append(
|
54 |
+
{
|
55 |
+
"position": position,
|
56 |
+
"answer": guess.get("answer", ""),
|
57 |
+
"confidence": guess.get("confidence", 0),
|
58 |
+
"rank": i + 1,
|
59 |
+
}
|
60 |
+
)
|
61 |
+
return pd.DataFrame(df_rows)
|
62 |
+
|
63 |
+
|
64 |
+
def _format_buzz_result(buzzed: bool, results: List[Dict], gold_label: str, top_k_mode: bool) -> tuple[str, str, bool]:
|
65 |
+
"""Format the result text based on whether the agent buzzed."""
|
66 |
+
if not buzzed:
|
67 |
+
return f"Did not buzz. Correct answer was: {gold_label}", "No buzz", False
|
68 |
+
|
69 |
+
buzz_position = next(i for i, r in enumerate(results) if r.get("buzz", False))
|
70 |
+
buzz_result = results[buzz_position]
|
71 |
+
|
72 |
+
if top_k_mode:
|
73 |
+
# For top-k, check if any of the top guesses match
|
74 |
+
top_answers = [g.get("answer", "").lower() for g in buzz_result.get("guesses", [])]
|
75 |
+
correct = gold_label.lower() in [a.lower() for a in top_answers]
|
76 |
+
final_answer = top_answers[0] if top_answers else "No answer"
|
77 |
+
else:
|
78 |
+
# For regular mode
|
79 |
+
final_answer = buzz_result["answer"]
|
80 |
+
correct = final_answer.lower() == gold_label.lower()
|
81 |
+
|
82 |
+
result_text = f"BUZZED at position {buzz_position + 1} with answer: {final_answer}\n"
|
83 |
+
result_text += f"Correct answer: {gold_label}\n"
|
84 |
+
result_text += f"Result: {'CORRECT' if correct else 'INCORRECT'}"
|
85 |
+
|
86 |
+
return result_text, final_answer, correct
|
src/components/utils.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Literal
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
DIRECTIONS = Literal["up", "down"]
|
6 |
+
|
7 |
+
|
8 |
+
def make_state(value: Any) -> gr.State:
|
9 |
+
"""Make a state from a value."""
|
10 |
+
if isinstance(value, gr.State):
|
11 |
+
return value
|
12 |
+
else:
|
13 |
+
return gr.State(value)
|
14 |
+
|
15 |
+
|
16 |
+
# List utilities
|
17 |
+
def move_item(items: list, position: int, direction: DIRECTIONS):
|
18 |
+
"""Move an item up or down in a list."""
|
19 |
+
if not isinstance(items, list):
|
20 |
+
raise ValueError("items must be a list")
|
21 |
+
if not isinstance(position, int) or not (0 <= position < len(items)):
|
22 |
+
raise ValueError("position must be a valid index in the list")
|
23 |
+
if direction not in ["up", "down"]:
|
24 |
+
raise ValueError("direction must be 'up' or 'down'")
|
25 |
+
|
26 |
+
if direction == "up" and position > 0:
|
27 |
+
items[position], items[position - 1] = items[position - 1], items[position]
|
28 |
+
elif direction == "down" and position < len(items) - 1:
|
29 |
+
items[position], items[position + 1] = items[position + 1], items[position]
|
src/display/__init__.py
ADDED
File without changes
|
src/display/css_html_js.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
custom_css = """
|
2 |
+
|
3 |
+
.markdown-text {
|
4 |
+
font-size: 16px !important;
|
5 |
+
}
|
6 |
+
|
7 |
+
#models-to-add-text {
|
8 |
+
font-size: 18px !important;
|
9 |
+
}
|
10 |
+
|
11 |
+
#citation-button span {
|
12 |
+
font-size: 16px !important;
|
13 |
+
}
|
14 |
+
|
15 |
+
#citation-button textarea {
|
16 |
+
font-size: 16px !important;
|
17 |
+
}
|
18 |
+
|
19 |
+
#citation-button > label > button {
|
20 |
+
margin: 6px;
|
21 |
+
transform: scale(1.3);
|
22 |
+
}
|
23 |
+
|
24 |
+
#leaderboard-table {
|
25 |
+
margin-top: 15px
|
26 |
+
}
|
27 |
+
|
28 |
+
#leaderboard-table-lite {
|
29 |
+
margin-top: 15px
|
30 |
+
}
|
31 |
+
|
32 |
+
#search-bar-table-box > div:first-child {
|
33 |
+
background: none;
|
34 |
+
border: none;
|
35 |
+
}
|
36 |
+
|
37 |
+
#search-bar {
|
38 |
+
padding: 0px;
|
39 |
+
}
|
40 |
+
|
41 |
+
/* Limit the width of the first AutoEvalColumn so that names don't expand too much */
|
42 |
+
#leaderboard-table td:nth-child(2),
|
43 |
+
#leaderboard-table th:nth-child(2) {
|
44 |
+
max-width: 400px;
|
45 |
+
overflow: auto;
|
46 |
+
white-space: nowrap;
|
47 |
+
}
|
48 |
+
|
49 |
+
/* Workflow JSON styling */
|
50 |
+
.workflow-json-container {
|
51 |
+
margin-top: 20px;
|
52 |
+
margin-bottom: 30px;
|
53 |
+
}
|
54 |
+
|
55 |
+
.workflow-json {
|
56 |
+
border: 1px solid #ddd;
|
57 |
+
border-radius: 8px;
|
58 |
+
box-shadow: 0 2px 5px rgba(0,0,0,0.1);
|
59 |
+
}
|
60 |
+
|
61 |
+
.workflow-json pre {
|
62 |
+
max-height: 500px;
|
63 |
+
overflow-y: auto;
|
64 |
+
}
|
65 |
+
|
66 |
+
.tab-buttons button {
|
67 |
+
font-size: 20px;
|
68 |
+
}
|
69 |
+
|
70 |
+
#scale-logo {
|
71 |
+
border-style: none !important;
|
72 |
+
box-shadow: none;
|
73 |
+
display: block;
|
74 |
+
margin-left: auto;
|
75 |
+
margin-right: auto;
|
76 |
+
max-width: 600px;
|
77 |
+
}
|
78 |
+
|
79 |
+
#scale-logo .download {
|
80 |
+
display: none;
|
81 |
+
}
|
82 |
+
#filter_type{
|
83 |
+
border: 0;
|
84 |
+
padding-left: 0;
|
85 |
+
padding-top: 0;
|
86 |
+
}
|
87 |
+
#filter_type label {
|
88 |
+
display: flex;
|
89 |
+
}
|
90 |
+
#filter_type label > span{
|
91 |
+
margin-top: var(--spacing-lg);
|
92 |
+
margin-right: 0.5em;
|
93 |
+
}
|
94 |
+
#filter_type label > .wrap{
|
95 |
+
width: 103px;
|
96 |
+
}
|
97 |
+
#filter_type label > .wrap .wrap-inner{
|
98 |
+
padding: 2px;
|
99 |
+
}
|
100 |
+
#filter_type label > .wrap .wrap-inner input{
|
101 |
+
width: 1px
|
102 |
+
}
|
103 |
+
#filter-columns-type{
|
104 |
+
border:0;
|
105 |
+
padding:0.5;
|
106 |
+
}
|
107 |
+
#filter-columns-size{
|
108 |
+
border:0;
|
109 |
+
padding:0.5;
|
110 |
+
}
|
111 |
+
#box-filter > .form{
|
112 |
+
border: 0
|
113 |
+
}
|
114 |
+
"""
|
115 |
+
|
116 |
+
get_window_url_params = """
|
117 |
+
function(url_params) {
|
118 |
+
const params = new URLSearchParams(window.location.search);
|
119 |
+
url_params = Object.fromEntries(params);
|
120 |
+
return url_params;
|
121 |
+
}
|
122 |
+
"""
|
src/display/custom_css.py
ADDED
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
css_pipeline = """
|
2 |
+
:root {
|
3 |
+
color-scheme: light !important;
|
4 |
+
--block-border-width: 0;
|
5 |
+
--section-header-text-weight: 600;
|
6 |
+
--section-header-text-size: 14px;
|
7 |
+
--mono-font-family: "Roboto Mono", monospace;
|
8 |
+
--body-text-size: 14px !important;
|
9 |
+
|
10 |
+
--card-bg-color: #fcecd4;
|
11 |
+
--card-btn-color: #D4E4FC;
|
12 |
+
--card-btn-color-hover: #7DAEF6;
|
13 |
+
--answer-bg-color: #f0f8ff;
|
14 |
+
--hover-border-color: #121212;
|
15 |
+
}
|
16 |
+
|
17 |
+
.dark {
|
18 |
+
--block-border-width: 0;
|
19 |
+
--card-bg-color: #383127;
|
20 |
+
--answer-bg-color: #1a2b3c;
|
21 |
+
--hover-border-color: #ffffff;
|
22 |
+
}
|
23 |
+
|
24 |
+
.gradio-app {
|
25 |
+
// font-family: Arial, sans-serif;
|
26 |
+
}
|
27 |
+
|
28 |
+
.form {
|
29 |
+
box-shadow: 0 0 0 0 !important;
|
30 |
+
}
|
31 |
+
|
32 |
+
.head {
|
33 |
+
margin-bottom: 0px;
|
34 |
+
}
|
35 |
+
|
36 |
+
.gradio-container {
|
37 |
+
max-width: 1500px;
|
38 |
+
margin: 0 auto;
|
39 |
+
padding: 0 8px;
|
40 |
+
}
|
41 |
+
|
42 |
+
.html-container {
|
43 |
+
padding: 0px 0px;
|
44 |
+
margin: 4px 0px;
|
45 |
+
border-radius: 12px;
|
46 |
+
gap: 0px
|
47 |
+
}
|
48 |
+
|
49 |
+
.step-container {
|
50 |
+
background-color: var(--card-bg-color);
|
51 |
+
padding: 0px 0px;
|
52 |
+
margin: 4px 0px;
|
53 |
+
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
|
54 |
+
border-radius: 12px;
|
55 |
+
gap: 0px
|
56 |
+
}
|
57 |
+
|
58 |
+
.step-container:hover {
|
59 |
+
border-color: var(--hover-border-color);
|
60 |
+
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.2);
|
61 |
+
}
|
62 |
+
|
63 |
+
.step-accordion {
|
64 |
+
background-color: var(--card-bg-color);
|
65 |
+
border: 0px solid #e0e0e0 !important;
|
66 |
+
border-radius: 12px;
|
67 |
+
overflow: hidden;
|
68 |
+
// transition: box-shadow 0.3s ease, border-color 0.3s ease;
|
69 |
+
padding: 8px 8px;
|
70 |
+
font-size: 12px;
|
71 |
+
}
|
72 |
+
|
73 |
+
.output-fields-panel {
|
74 |
+
background-color: var(--card-bg-color);
|
75 |
+
border: 0px solid #e0e0e0 !important;
|
76 |
+
border-radius: 12px;
|
77 |
+
overflow: hidden;
|
78 |
+
transition: box-shadow 0.3s ease, border-color 0.3s ease;
|
79 |
+
padding: 8px 8px;
|
80 |
+
font-size: 12px;
|
81 |
+
}
|
82 |
+
|
83 |
+
.model-submission-accordion {
|
84 |
+
background-color: var(--card-bg-color);
|
85 |
+
border: 0px solid #e0e0e0 !important;
|
86 |
+
border-radius: 12px;
|
87 |
+
overflow: hidden;
|
88 |
+
transition: box-shadow 0.3s ease, border-color 0.3s ease;
|
89 |
+
font-size: 14px;
|
90 |
+
}
|
91 |
+
|
92 |
+
.model-submission-accordion > label-wrap {
|
93 |
+
font-size: 16px;
|
94 |
+
font-weight: bold !important;
|
95 |
+
}
|
96 |
+
|
97 |
+
.step-accordion:hover .step-name-input input {
|
98 |
+
font-weight: bold;
|
99 |
+
}
|
100 |
+
|
101 |
+
.step-accordion > label-wrap {
|
102 |
+
font-size: 14px;
|
103 |
+
font-weight: bold !important;
|
104 |
+
}
|
105 |
+
|
106 |
+
.step-header-row {
|
107 |
+
margin: 0px 0px;
|
108 |
+
padding: 0px 0px;
|
109 |
+
border: 0px !important;
|
110 |
+
}
|
111 |
+
|
112 |
+
.step-header-row form {
|
113 |
+
margin: 0px 0px;
|
114 |
+
padding: 0px 0px;
|
115 |
+
border: 0px !important;
|
116 |
+
}
|
117 |
+
|
118 |
+
.step-name {
|
119 |
+
margin: 0px
|
120 |
+
padding: 0px 0px;
|
121 |
+
// border-radius: 8px;
|
122 |
+
border: 0px !important
|
123 |
+
}
|
124 |
+
|
125 |
+
.model-dropdown {
|
126 |
+
margin: 0px
|
127 |
+
padding: 0px 8px;
|
128 |
+
}
|
129 |
+
|
130 |
+
.model-dropdown input {
|
131 |
+
font-size: 14px;
|
132 |
+
padding-bottom: 2px;
|
133 |
+
padding-top: 2px;
|
134 |
+
}
|
135 |
+
|
136 |
+
.step-name input {
|
137 |
+
font-size: 14px;
|
138 |
+
font-weight: bold;
|
139 |
+
padding-bottom: 8px;
|
140 |
+
margin-bottom: 4px;
|
141 |
+
border-radius: 12px !important;
|
142 |
+
}
|
143 |
+
|
144 |
+
.step-controls {
|
145 |
+
display: flex;
|
146 |
+
justify-content: flex-end;
|
147 |
+
gap: 12px;
|
148 |
+
background-color: var(--card-bg-color);
|
149 |
+
border-radius: 12px;
|
150 |
+
padding: 0px
|
151 |
+
border: 1px solid black;
|
152 |
+
}
|
153 |
+
|
154 |
+
.step-control-btn {
|
155 |
+
background-color: var(--card-btn-color);
|
156 |
+
font-size: 12px !important;
|
157 |
+
color: var(--body-text-color);
|
158 |
+
min-width: 36px !important;
|
159 |
+
min-height: 24px !important;
|
160 |
+
padding: 4px !important;
|
161 |
+
margin: 8px !important;
|
162 |
+
border-radius: 12px;
|
163 |
+
}
|
164 |
+
|
165 |
+
.step-control-btn:hover {
|
166 |
+
background-color: var(--card-btn-color-hover);
|
167 |
+
}
|
168 |
+
|
169 |
+
.step-tabs {
|
170 |
+
margin-top: 0px;
|
171 |
+
padding: 0px 0px;
|
172 |
+
border-radius: 0px;
|
173 |
+
border: 0px
|
174 |
+
background-color: transparent;
|
175 |
+
}
|
176 |
+
|
177 |
+
.tab-content {
|
178 |
+
padding: 0px 0px;
|
179 |
+
margin-bottom: 0px;
|
180 |
+
border-radius: 4px;
|
181 |
+
border: 0px solid #eee;
|
182 |
+
background-color: transparent !important;
|
183 |
+
}
|
184 |
+
|
185 |
+
.fields-panel {
|
186 |
+
background-color: transparent !important;
|
187 |
+
gap: 5px !important;
|
188 |
+
border-radius: 4px;
|
189 |
+
padding: 2px;
|
190 |
+
}
|
191 |
+
|
192 |
+
.field-row {
|
193 |
+
margin-bottom: 1px;
|
194 |
+
margin-top: 1px;
|
195 |
+
padding: 2px;
|
196 |
+
border-radius: 8px;
|
197 |
+
background-color: var(--block-background-fill) !important;
|
198 |
+
border: 0px solid #eee;
|
199 |
+
gap: 0px !important;
|
200 |
+
}
|
201 |
+
|
202 |
+
.output-field-row {
|
203 |
+
margin-bottom: 1px;
|
204 |
+
margin-top: 1px;
|
205 |
+
padding: 2px;
|
206 |
+
border-radius: 4px;
|
207 |
+
background-color: var(--block-background-fill) !important;
|
208 |
+
border: 0px solid #eee;
|
209 |
+
gap: 0px !important;
|
210 |
+
}
|
211 |
+
|
212 |
+
.output-fields-header {
|
213 |
+
padding: 0px 8px;
|
214 |
+
}
|
215 |
+
|
216 |
+
.output-fields-panel {
|
217 |
+
background-color: var(--block-background-fill) !important;
|
218 |
+
padding: 8px 8px;
|
219 |
+
}
|
220 |
+
|
221 |
+
.output-field-variable {
|
222 |
+
font-family: var(--mono-font-family) !important;
|
223 |
+
font-weight: 300 !important;
|
224 |
+
font-size: 12px !important;
|
225 |
+
padding: 8px 8px;
|
226 |
+
border-radius: 4px;
|
227 |
+
border: 0px solid #eee !important;
|
228 |
+
}
|
229 |
+
|
230 |
+
.output-field-variable span {
|
231 |
+
font-size: 12px !important;
|
232 |
+
}
|
233 |
+
|
234 |
+
.field-type {
|
235 |
+
min-width: 100px !important;
|
236 |
+
}
|
237 |
+
|
238 |
+
.field-name > label, .field-variable > label, .field-description > label, .field-type > label {
|
239 |
+
font-size: 12px !important;
|
240 |
+
}
|
241 |
+
|
242 |
+
.field-name input, .field-description input, .field-type input {
|
243 |
+
font-family: var(--mono-font-family) !important;
|
244 |
+
font-size: 12px !important;
|
245 |
+
}
|
246 |
+
|
247 |
+
.field-variable input, .field-type input, .output-field-variable input {
|
248 |
+
font-family: var(--mono-font-family) !important;
|
249 |
+
font-size: 12px !important;
|
250 |
+
padding-top: 3px;
|
251 |
+
padding-bottom: 3px;
|
252 |
+
}
|
253 |
+
|
254 |
+
.field-name listbox, .field-variable listbox, .field-type listbox {
|
255 |
+
font-family: var(--mono-font-family) !important;
|
256 |
+
padding-top: 2px;
|
257 |
+
padding-bottom: 2px;
|
258 |
+
font-size: 12px !important;
|
259 |
+
}
|
260 |
+
|
261 |
+
.field-description {
|
262 |
+
font-size: 12px !important;
|
263 |
+
}
|
264 |
+
|
265 |
+
/* Accordion button labels */
|
266 |
+
.step-accordion button.label-wrap {
|
267 |
+
font-size: 14px;
|
268 |
+
font-weight: bold !important;
|
269 |
+
font-family: var(--mono-font-family) !important;
|
270 |
+
}
|
271 |
+
|
272 |
+
.step-accordion button.label-wrap.open {
|
273 |
+
font-size: 14px;
|
274 |
+
font-weight: bold !important;
|
275 |
+
font-family: var(--mono-font-family) !important;
|
276 |
+
}
|
277 |
+
|
278 |
+
.button-column {
|
279 |
+
margin-top: 2px;
|
280 |
+
margin-bottom: 2px;
|
281 |
+
padding-top: 2px;
|
282 |
+
padding-bottom: 2px;
|
283 |
+
display: flex;
|
284 |
+
flex-direction: column;
|
285 |
+
justify-content: center;
|
286 |
+
align-items: center;
|
287 |
+
gap: 4x !important;
|
288 |
+
height: 100%;
|
289 |
+
}
|
290 |
+
|
291 |
+
.icon-button {
|
292 |
+
min-width: 28px !important;
|
293 |
+
max-width: 42px !important;
|
294 |
+
height: 28px !important;
|
295 |
+
max-height: 42px !important;
|
296 |
+
padding: 0 !important;
|
297 |
+
border-radius: 4px !important;
|
298 |
+
transition: background-color 0.2s ease, color 0.2s ease;
|
299 |
+
}
|
300 |
+
|
301 |
+
.delete-button {
|
302 |
+
background-color: #ffebee !important;
|
303 |
+
color: #d32f2f !important;
|
304 |
+
}
|
305 |
+
|
306 |
+
.delete-button:hover {
|
307 |
+
background-color: #ffcdd2 !important;
|
308 |
+
}
|
309 |
+
|
310 |
+
.up-button, .down-button {
|
311 |
+
background-color: #e3f2fd !important;
|
312 |
+
color: #1976d2 !important;
|
313 |
+
}
|
314 |
+
|
315 |
+
.up-button:hover, .down-button:hover {
|
316 |
+
background-color: #bbdefb !important;
|
317 |
+
}
|
318 |
+
|
319 |
+
.add-field-button {
|
320 |
+
background-color: #e8f5e9 !important;
|
321 |
+
color: #2e7d32 !important;
|
322 |
+
}
|
323 |
+
|
324 |
+
.add-field-button:hover, .add-step-button:hover {
|
325 |
+
background-color: #c8e6c9 !important;
|
326 |
+
}
|
327 |
+
|
328 |
+
.pipeline-controls {
|
329 |
+
border-top: 1px solid #eee;
|
330 |
+
padding-top: 8px;
|
331 |
+
}
|
332 |
+
|
333 |
+
.pipeline-header {
|
334 |
+
border-bottom: 1px solid #eee;
|
335 |
+
padding: 8px 0px;
|
336 |
+
}
|
337 |
+
|
338 |
+
.pipeline-footer {
|
339 |
+
border-top: 1px solid #eee;
|
340 |
+
padding: 8px 0px;
|
341 |
+
}
|
342 |
+
|
343 |
+
.add-step-button {
|
344 |
+
background-color: #e8f5e9 !important;
|
345 |
+
color: #2e7d32 !important;
|
346 |
+
border-radius: 12px;
|
347 |
+
}
|
348 |
+
|
349 |
+
.export-button {
|
350 |
+
background-color: #e0f7f5 !important;
|
351 |
+
color: #00796b !important;
|
352 |
+
border-radius: 12px;
|
353 |
+
}
|
354 |
+
|
355 |
+
.export-button:hover {
|
356 |
+
background-color: #b2dfdb !important;
|
357 |
+
}
|
358 |
+
|
359 |
+
.pipeline-preview {
|
360 |
+
background-color: var(--card-bg-color);
|
361 |
+
border-radius: 12px;
|
362 |
+
box-shadow: 0 0 0 0 !important;
|
363 |
+
}
|
364 |
+
"""
|
365 |
+
|
366 |
+
|
367 |
+
css_tossup = """
|
368 |
+
.token {
|
369 |
+
display: inline-block;
|
370 |
+
padding: 1px 3px;
|
371 |
+
margin: 1px;
|
372 |
+
border-radius: 4px;
|
373 |
+
cursor: pointer;
|
374 |
+
transition: background-color 0.2s;
|
375 |
+
}
|
376 |
+
.answer-header {
|
377 |
+
font-weight: 900;
|
378 |
+
font-size: 16px;
|
379 |
+
padding: 8px;
|
380 |
+
border-radius: 8px;
|
381 |
+
background-color: var(--answer-bg-color) !important;
|
382 |
+
}
|
383 |
+
.answer-box textarea {
|
384 |
+
font-size: 16px;
|
385 |
+
padding: 8px;
|
386 |
+
border-radius: 8px;
|
387 |
+
background-color: var(--answer-bg-color) !important;
|
388 |
+
}
|
389 |
+
.token:hover, .token.highlighted {
|
390 |
+
background-color: #ffcc00;
|
391 |
+
}
|
392 |
+
.token.guess-point {
|
393 |
+
border-bottom: 3px solid;
|
394 |
+
}
|
395 |
+
.token.no-buzz {
|
396 |
+
border-color: #6b96b3;
|
397 |
+
}
|
398 |
+
.token.buzz-0 {
|
399 |
+
border-color: #ff4444;
|
400 |
+
}
|
401 |
+
.token.buzz-1 {
|
402 |
+
border-color: #228b22; /* Darker and slightly muted green */
|
403 |
+
}
|
404 |
+
.token-container {
|
405 |
+
line-height: 1.7;
|
406 |
+
padding: 5px;
|
407 |
+
margin-left: 4px;
|
408 |
+
margin-right: 4px;
|
409 |
+
background-color: var(--answer-bg-color) !important;
|
410 |
+
border-radius: 8px;
|
411 |
+
margin-bottom: 10px;
|
412 |
+
}
|
413 |
+
"""
|
src/display/formatting.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def model_hyperlink(link, model_name):
|
2 |
+
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
3 |
+
|
4 |
+
|
5 |
+
def make_clickable_model(model_name):
|
6 |
+
link = f"https://huggingface.co/{model_name}"
|
7 |
+
return model_hyperlink(link, model_name)
|
8 |
+
|
9 |
+
|
10 |
+
def styled_error(error):
|
11 |
+
return f"<p style='color: red; font-size: 20px; text-align: center;'>{error}</p>"
|
12 |
+
|
13 |
+
|
14 |
+
def styled_warning(warn):
|
15 |
+
return f"<p style='color: orange; font-size: 20px; text-align: center;'>{warn}</p>"
|
16 |
+
|
17 |
+
|
18 |
+
def styled_message(message):
|
19 |
+
return f"<p style='color: green; font-size: 20px; text-align: center;'>{message}</p>"
|
20 |
+
|
21 |
+
|
22 |
+
def has_no_nan_values(df, columns):
|
23 |
+
return df[columns].notna().all(axis=1)
|
24 |
+
|
25 |
+
|
26 |
+
def has_nan_values(df, columns):
|
27 |
+
return df[columns].isna().any(axis=1)
|
src/display/utils.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, make_dataclass
|
2 |
+
from enum import Enum
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
from src.about import Tasks
|
7 |
+
|
8 |
+
def fields(raw_class):
|
9 |
+
return [v for k, v in raw_class.__dict__.items() if k[:2] != "__" and k[-2:] != "__"]
|
10 |
+
|
11 |
+
|
12 |
+
# These classes are for user facing column names,
|
13 |
+
# to avoid having to change them all around the code
|
14 |
+
# when a modif is needed
|
15 |
+
@dataclass
|
16 |
+
class ColumnContent:
|
17 |
+
name: str
|
18 |
+
type: str
|
19 |
+
displayed_by_default: bool
|
20 |
+
hidden: bool = False
|
21 |
+
never_hidden: bool = False
|
22 |
+
|
23 |
+
## Leaderboard columns
|
24 |
+
auto_eval_column_dict = []
|
25 |
+
# Init
|
26 |
+
auto_eval_column_dict.append(["model_type_symbol", ColumnContent, ColumnContent("T", "str", True, never_hidden=True)])
|
27 |
+
auto_eval_column_dict.append(["model", ColumnContent, ColumnContent("Model", "markdown", True, never_hidden=True)])
|
28 |
+
#Scores
|
29 |
+
auto_eval_column_dict.append(["average", ColumnContent, ColumnContent("Average ⬆️", "number", True)])
|
30 |
+
for task in Tasks:
|
31 |
+
auto_eval_column_dict.append([task.name, ColumnContent, ColumnContent(task.value.col_name, "number", True)])
|
32 |
+
# Model information
|
33 |
+
auto_eval_column_dict.append(["model_type", ColumnContent, ColumnContent("Type", "str", False)])
|
34 |
+
auto_eval_column_dict.append(["architecture", ColumnContent, ColumnContent("Architecture", "str", False)])
|
35 |
+
auto_eval_column_dict.append(["weight_type", ColumnContent, ColumnContent("Weight type", "str", False, True)])
|
36 |
+
auto_eval_column_dict.append(["precision", ColumnContent, ColumnContent("Precision", "str", False)])
|
37 |
+
auto_eval_column_dict.append(["license", ColumnContent, ColumnContent("Hub License", "str", False)])
|
38 |
+
auto_eval_column_dict.append(["params", ColumnContent, ColumnContent("#Params (B)", "number", False)])
|
39 |
+
auto_eval_column_dict.append(["likes", ColumnContent, ColumnContent("Hub ❤️", "number", False)])
|
40 |
+
auto_eval_column_dict.append(["still_on_hub", ColumnContent, ColumnContent("Available on the hub", "bool", False)])
|
41 |
+
auto_eval_column_dict.append(["revision", ColumnContent, ColumnContent("Model sha", "str", False, False)])
|
42 |
+
|
43 |
+
# We use make dataclass to dynamically fill the scores from Tasks
|
44 |
+
AutoEvalColumn = make_dataclass("AutoEvalColumn", auto_eval_column_dict, frozen=True)
|
45 |
+
|
46 |
+
## For the queue columns in the submission tab
|
47 |
+
@dataclass(frozen=True)
|
48 |
+
class EvalQueueColumn: # Queue column
|
49 |
+
model = ColumnContent("model", "markdown", True)
|
50 |
+
revision = ColumnContent("revision", "str", True)
|
51 |
+
private = ColumnContent("private", "bool", True)
|
52 |
+
precision = ColumnContent("precision", "str", True)
|
53 |
+
weight_type = ColumnContent("weight_type", "str", "Original")
|
54 |
+
status = ColumnContent("status", "str", True)
|
55 |
+
|
56 |
+
## All the model information that we might need
|
57 |
+
@dataclass
|
58 |
+
class ModelDetails:
|
59 |
+
name: str
|
60 |
+
display_name: str = ""
|
61 |
+
symbol: str = "" # emoji
|
62 |
+
|
63 |
+
|
64 |
+
class ModelType(Enum):
|
65 |
+
PT = ModelDetails(name="pretrained", symbol="🟢")
|
66 |
+
FT = ModelDetails(name="fine-tuned", symbol="🔶")
|
67 |
+
IFT = ModelDetails(name="instruction-tuned", symbol="⭕")
|
68 |
+
RL = ModelDetails(name="RL-tuned", symbol="🟦")
|
69 |
+
Unknown = ModelDetails(name="", symbol="?")
|
70 |
+
|
71 |
+
def to_str(self, separator=" "):
|
72 |
+
return f"{self.value.symbol}{separator}{self.value.name}"
|
73 |
+
|
74 |
+
@staticmethod
|
75 |
+
def from_str(type):
|
76 |
+
if "fine-tuned" in type or "🔶" in type:
|
77 |
+
return ModelType.FT
|
78 |
+
if "pretrained" in type or "🟢" in type:
|
79 |
+
return ModelType.PT
|
80 |
+
if "RL-tuned" in type or "🟦" in type:
|
81 |
+
return ModelType.RL
|
82 |
+
if "instruction-tuned" in type or "⭕" in type:
|
83 |
+
return ModelType.IFT
|
84 |
+
return ModelType.Unknown
|
85 |
+
|
86 |
+
class WeightType(Enum):
|
87 |
+
Adapter = ModelDetails("Adapter")
|
88 |
+
Original = ModelDetails("Original")
|
89 |
+
Delta = ModelDetails("Delta")
|
90 |
+
|
91 |
+
class Precision(Enum):
|
92 |
+
float16 = ModelDetails("float16")
|
93 |
+
bfloat16 = ModelDetails("bfloat16")
|
94 |
+
Unknown = ModelDetails("?")
|
95 |
+
|
96 |
+
def from_str(precision):
|
97 |
+
if precision in ["torch.float16", "float16"]:
|
98 |
+
return Precision.float16
|
99 |
+
if precision in ["torch.bfloat16", "bfloat16"]:
|
100 |
+
return Precision.bfloat16
|
101 |
+
return Precision.Unknown
|
102 |
+
|
103 |
+
# Column selection
|
104 |
+
COLS = [c.name for c in fields(AutoEvalColumn) if not c.hidden]
|
105 |
+
|
106 |
+
EVAL_COLS = [c.name for c in fields(EvalQueueColumn)]
|
107 |
+
EVAL_TYPES = [c.type for c in fields(EvalQueueColumn)]
|
108 |
+
|
109 |
+
BENCHMARK_COLS = [t.value.col_name for t in Tasks]
|
110 |
+
|
src/envs.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from huggingface_hub import HfApi
|
4 |
+
|
5 |
+
# Info to change for your repository
|
6 |
+
# ----------------------------------
|
7 |
+
TOKEN = os.environ.get("HF_TOKEN") # A read/write token for your org
|
8 |
+
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
|
9 |
+
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY")
|
10 |
+
COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
|
11 |
+
|
12 |
+
OWNER = (
|
13 |
+
"umdclip" # Change to your org - don't forget to create a results and request dataset, with the correct format!
|
14 |
+
)
|
15 |
+
# ----------------------------------
|
16 |
+
|
17 |
+
REPO_ID = f"{OWNER}/advcal-leaderboard"
|
18 |
+
QUEUE_REPO = f"{OWNER}/advcal-requests"
|
19 |
+
RESULTS_REPO = f"{OWNER}/advcal-results"
|
20 |
+
|
21 |
+
PLAYGROUND_DATASET_NAMES = {
|
22 |
+
"tossup": "umdclip/acf-co24-tossups",
|
23 |
+
"bonus": "umdclip/acf-co24-bonuses",
|
24 |
+
}
|
25 |
+
|
26 |
+
# If you setup a cache later, just change HF_HOME
|
27 |
+
CACHE_PATH = os.getenv("HF_HOME", ".")
|
28 |
+
|
29 |
+
# Local caches
|
30 |
+
EVAL_REQUESTS_PATH = os.path.join(CACHE_PATH, "eval-queue")
|
31 |
+
EVAL_RESULTS_PATH = os.path.join(CACHE_PATH, "eval-results")
|
32 |
+
EVAL_REQUESTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-queue-bk")
|
33 |
+
EVAL_RESULTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-results-bk")
|
34 |
+
|
35 |
+
THEME = "gstaff/xkcd"
|
36 |
+
UNSELECTED_VAR_NAME = "Select Variable..."
|
37 |
+
UNSELECTED_MODEL_NAME = "Select Model..."
|
38 |
+
AVAILABLE_MODELS = {
|
39 |
+
"OpenAI/gpt-4o": {
|
40 |
+
"model": "gpt-4o-2024-11-20",
|
41 |
+
},
|
42 |
+
"OpenAI/gpt-4o-mini": {
|
43 |
+
"model": "gpt-4o-mini-2024-07-18",
|
44 |
+
},
|
45 |
+
"OpenAI/gpt-3.5-turbo": {
|
46 |
+
"model": "gpt-3.5-turbo-0125",
|
47 |
+
},
|
48 |
+
"Anthropic/claude-3-7-sonnet": {
|
49 |
+
"model": "claude-3-7-sonnet-20250219",
|
50 |
+
},
|
51 |
+
"Anthropic/claude-3-5-sonnet": {
|
52 |
+
"model": "claude-3-5-sonnet-20241022",
|
53 |
+
},
|
54 |
+
"Anthropic/claude-3-5-haiku": {
|
55 |
+
"model": "claude-3-5-haiku-20241022",
|
56 |
+
},
|
57 |
+
"Cohere/command-r": {
|
58 |
+
"model": "command-r-08-2024",
|
59 |
+
},
|
60 |
+
"Cohere/command-r-plus": {
|
61 |
+
"model": "command-r-plus-08-2024",
|
62 |
+
},
|
63 |
+
"Cohere/command-r7b": {
|
64 |
+
"model": "command-r7b-12-2024",
|
65 |
+
},
|
66 |
+
}
|
67 |
+
|
68 |
+
DEFAULT_SELECTIONS = {
|
69 |
+
"tossup": {
|
70 |
+
"simple_workflow": False,
|
71 |
+
"model": "OpenAI/gpt-4o-mini",
|
72 |
+
"temperature": 0.2,
|
73 |
+
"buzz_threshold": 0.85,
|
74 |
+
"early_stop": True,
|
75 |
+
},
|
76 |
+
"bonus": {
|
77 |
+
"simple_workflow": False,
|
78 |
+
"model": "OpenAI/gpt-4o-mini",
|
79 |
+
"temperature": 0.2,
|
80 |
+
"buzz_threshold": 0.85,
|
81 |
+
"early_stop": True,
|
82 |
+
},
|
83 |
+
}
|
84 |
+
|
85 |
+
DAILY_SUBMISSION_LIMIT_PER_USER = 5
|
86 |
+
API = HfApi(token=TOKEN)
|
src/submission/structs.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from typing import Dict, List, Literal, Optional
|
3 |
+
|
4 |
+
from pydantic import BaseModel, Field
|
5 |
+
|
6 |
+
from workflows.structs import Workflow
|
7 |
+
|
8 |
+
CompetitionType = Literal["tossup", "bonus"]
|
9 |
+
SubmissionType = Literal["python_file", "simple_workflow", "complex_workflow"]
|
10 |
+
SubmissionStatus = Literal["submitted", "in_progress", "completed", "failed"]
|
11 |
+
|
12 |
+
|
13 |
+
class Submission(BaseModel):
|
14 |
+
"""
|
15 |
+
Represents a submission in the competition system, formatted for HuggingFace datasets.
|
16 |
+
|
17 |
+
This model is designed to be easily serializable to/from HuggingFace dataset format
|
18 |
+
while maintaining type safety and validation through Pydantic.
|
19 |
+
|
20 |
+
Attributes:
|
21 |
+
id: Unique identifier for the submission
|
22 |
+
name: Display name of the submission
|
23 |
+
description: Detailed description of what the submission does
|
24 |
+
user_email: Email of the user who created the submission
|
25 |
+
competition_type: Type of competition (Tossup or Bonus)
|
26 |
+
submission_type: Format of the submission (python file or workflow)
|
27 |
+
workflow: Optional workflow definition for workflow submissions, stored as JSON
|
28 |
+
code: Optional code content for python file submissions
|
29 |
+
status: Current status of the submission
|
30 |
+
created_at: ISO format timestamp of creation
|
31 |
+
updated_at: ISO format timestamp of last update
|
32 |
+
"""
|
33 |
+
|
34 |
+
id: str = Field(description="Unique identifier for the submission")
|
35 |
+
model_name: str = Field(description="Display name of the submission")
|
36 |
+
username: str = Field(description="HuggingFace username of the user who created the submission")
|
37 |
+
description: str = Field(description="Detailed description of what the submission does")
|
38 |
+
competition_type: CompetitionType = Field(description="Type of competition (tossup or bonus)")
|
39 |
+
submission_type: SubmissionType = Field(description="Format of the submission (python file or workflow)")
|
40 |
+
workflow: Optional[Workflow] = Field(default=None, description="Optional workflow definition stored as JSON dict")
|
41 |
+
code: Optional[str] = Field(default=None, description="Optional code content for python file submissions")
|
42 |
+
status: SubmissionStatus = Field(description="Current status of the submission")
|
43 |
+
created_at: str = Field(description="ISO format timestamp of creation")
|
44 |
+
updated_at: str = Field(description="ISO format timestamp of last update")
|
45 |
+
|
46 |
+
def to_dict(self) -> Dict:
|
47 |
+
"""Convert to dictionary format suitable for HF datasets"""
|
48 |
+
data = self.model_dump()
|
49 |
+
if self.workflow:
|
50 |
+
data["workflow"] = self.workflow.model_dump(exclude_defaults=True)
|
51 |
+
return data
|
52 |
+
|
53 |
+
@classmethod
|
54 |
+
def from_dict(cls, data: Dict) -> "Submission":
|
55 |
+
"""Create instance from dictionary format used in HF datasets"""
|
56 |
+
if data.get("workflow"):
|
57 |
+
data["workflow"] = Workflow.model_validate(data["workflow"])
|
58 |
+
return cls.model_validate(data)
|
src/submission/submit.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import traceback
|
4 |
+
from datetime import datetime, timedelta, timezone
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import yaml
|
9 |
+
|
10 |
+
from src.display.formatting import styled_error, styled_message
|
11 |
+
from src.envs import API, DAILY_SUBMISSION_LIMIT_PER_USER, EVAL_REQUESTS_PATH, QUEUE_REPO
|
12 |
+
from src.submission.structs import CompetitionType, Submission, SubmissionStatus
|
13 |
+
from workflows.structs import Workflow
|
14 |
+
|
15 |
+
|
16 |
+
def get_user_submissions_today(username: str, competition_type: str) -> list[Submission]:
|
17 |
+
today = datetime.now(timezone.utc).strftime("%Y%m%d")
|
18 |
+
if username is None:
|
19 |
+
raise gr.Error("Authentication required. Please log in to view your submissions.")
|
20 |
+
out_dir = f"{EVAL_REQUESTS_PATH}/{username}"
|
21 |
+
submissions = []
|
22 |
+
if not os.path.exists(out_dir):
|
23 |
+
return submissions
|
24 |
+
for file in os.listdir(out_dir):
|
25 |
+
if not file.startswith(f"{competition_type}_"):
|
26 |
+
continue
|
27 |
+
with open(os.path.join(out_dir, file), "r") as f:
|
28 |
+
submission = Submission.from_dict(json.load(f))
|
29 |
+
if submission.created_at.startswith(today):
|
30 |
+
submissions.append(submission)
|
31 |
+
return submissions
|
32 |
+
|
33 |
+
|
34 |
+
def get_time_until_next_submission(tz: timezone = timezone.utc) -> str:
|
35 |
+
next_day_00 = datetime.now(tz) + timedelta(days=1)
|
36 |
+
next_day_00 = next_day_00.replace(hour=0, minute=0, second=0, microsecond=0)
|
37 |
+
remaining_time = next_day_00 - datetime.now(tz)
|
38 |
+
hours = remaining_time.seconds // 3600
|
39 |
+
minutes = (remaining_time.seconds % 3600) // 60
|
40 |
+
remaining_time_str = f"{hours} hours {minutes} mins"
|
41 |
+
return remaining_time_str
|
42 |
+
|
43 |
+
|
44 |
+
def create_submission(
|
45 |
+
username: str,
|
46 |
+
model_name: str,
|
47 |
+
description: str,
|
48 |
+
workflow: Workflow,
|
49 |
+
competition_type: CompetitionType,
|
50 |
+
) -> Submission:
|
51 |
+
"""
|
52 |
+
Create a submission for a tossup model.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
name: Display name of the submission
|
56 |
+
description: Detailed description of what the submission does
|
57 |
+
user_email: Email of the user who created the submission
|
58 |
+
workflow: The workflow configuration for the tossup model
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
Submission object if successful, None if validation fails
|
62 |
+
"""
|
63 |
+
# Create the submission
|
64 |
+
dt = datetime.now(timezone.utc)
|
65 |
+
submission = Submission(
|
66 |
+
id=f"{competition_type}_{dt.strftime('%Y%m%d_%H%M%S')}_{model_name.lower().replace(' ', '_')}",
|
67 |
+
model_name=model_name,
|
68 |
+
username=username,
|
69 |
+
description=description,
|
70 |
+
competition_type=competition_type,
|
71 |
+
submission_type="simple_workflow",
|
72 |
+
workflow=workflow,
|
73 |
+
status="submitted",
|
74 |
+
created_at=dt.isoformat(),
|
75 |
+
updated_at=dt.isoformat(),
|
76 |
+
)
|
77 |
+
|
78 |
+
return submission
|
79 |
+
|
80 |
+
|
81 |
+
def submit_model(
|
82 |
+
model_name: str,
|
83 |
+
description: str,
|
84 |
+
workflow: Workflow,
|
85 |
+
competition_type: CompetitionType,
|
86 |
+
profile: gr.OAuthProfile | None,
|
87 |
+
) -> str:
|
88 |
+
"""
|
89 |
+
Submit a tossup model for evaluation.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
name: Display name of the submission
|
93 |
+
description: Detailed description of what the submission does
|
94 |
+
user_email: Email of the user who created the submission
|
95 |
+
workflow: The workflow configuration for the tossup model
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
Status message
|
99 |
+
"""
|
100 |
+
|
101 |
+
if profile is None:
|
102 |
+
return styled_error("Authentication required. Please log in first to submit your model.")
|
103 |
+
|
104 |
+
username = profile.username
|
105 |
+
|
106 |
+
if len(get_user_submissions_today(username)) >= DAILY_SUBMISSION_LIMIT_PER_USER:
|
107 |
+
time_str = get_time_until_next_submission()
|
108 |
+
return styled_error(
|
109 |
+
f"Daily submission limit of {DAILY_SUBMISSION_LIMIT_PER_USER} reached. Please try again in \n {time_str}."
|
110 |
+
)
|
111 |
+
try:
|
112 |
+
submission = create_submission(
|
113 |
+
username=username,
|
114 |
+
model_name=model_name,
|
115 |
+
description=description,
|
116 |
+
workflow=workflow,
|
117 |
+
competition_type=competition_type,
|
118 |
+
)
|
119 |
+
# Convert to dictionary format
|
120 |
+
submission_dict = submission.to_dict()
|
121 |
+
|
122 |
+
# Create output directory path
|
123 |
+
out_dir = f"{EVAL_REQUESTS_PATH}/{username}"
|
124 |
+
out_path = f"{out_dir}/{submission.id}.json"
|
125 |
+
|
126 |
+
# Upload to HuggingFace dataset
|
127 |
+
API.upload_file(
|
128 |
+
path_or_fileobj=json.dumps(submission_dict, indent=2).encode(),
|
129 |
+
path_in_repo=out_path.split("eval-queue/")[1],
|
130 |
+
repo_id=QUEUE_REPO,
|
131 |
+
repo_type="dataset",
|
132 |
+
commit_message=f"Add tossup submission {submission.id}",
|
133 |
+
)
|
134 |
+
|
135 |
+
return styled_message(
|
136 |
+
f"Successfully submitted tossup model!\n"
|
137 |
+
f"Submission ID: {submission.id}\n"
|
138 |
+
f"Name: {username}/{model_name}\n"
|
139 |
+
f"Please wait for up to an hour for the model to show in the PENDING list."
|
140 |
+
)
|
141 |
+
|
142 |
+
except Exception as e:
|
143 |
+
traceback.print_exc()
|
144 |
+
return styled_error(f"Error submitting model: {str(e)}")
|
145 |
+
|
146 |
+
|
147 |
+
if __name__ == "__main__":
|
148 |
+
# Example usage
|
149 |
+
from workflows.factory import create_quizbowl_simple_step_initial_setup
|
150 |
+
|
151 |
+
# Create workflow
|
152 |
+
model_step = create_quizbowl_simple_step_initial_setup()
|
153 |
+
model_step.model = "gpt-4"
|
154 |
+
model_step.provider = "openai"
|
155 |
+
model_step.temperature = 0.7
|
156 |
+
|
157 |
+
workflow = Workflow(
|
158 |
+
inputs=["question_text"],
|
159 |
+
outputs={"answer": "A.answer", "confidence": "A.confidence"},
|
160 |
+
steps={"A": model_step},
|
161 |
+
)
|
162 |
+
|
163 |
+
# Submit model
|
164 |
+
result = submit_model(
|
165 |
+
model_name="GPT-4 Tossup",
|
166 |
+
description="A simple GPT-4 model for tossup questions",
|
167 |
+
workflow=workflow,
|
168 |
+
competition_type="tossup",
|
169 |
+
)
|
170 |
+
print(result)
|
src/utils.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Description: Utility functions for the model_step component.
|
2 |
+
|
3 |
+
from envs import AVAILABLE_MODELS, UNSELECTED_MODEL_NAME
|
4 |
+
|
5 |
+
|
6 |
+
def guess_model_provider(model_name: str):
|
7 |
+
"""Guess the provider of a model name."""
|
8 |
+
model_name = model_name.lower()
|
9 |
+
if model_name.startswith("gpt-"):
|
10 |
+
return "OpenAI"
|
11 |
+
if "sonnet" in model_name or "claude" in model_name:
|
12 |
+
return "Anthropic"
|
13 |
+
raise ValueError(f"Model `{model_name}` not yet supported")
|
14 |
+
|
15 |
+
|
16 |
+
def get_model_and_provider(model_name: str):
|
17 |
+
"""Get the model and provider from a model name."""
|
18 |
+
if model_name == UNSELECTED_MODEL_NAME:
|
19 |
+
return "", ""
|
20 |
+
splits = model_name.split("/", maxsplit=1)
|
21 |
+
if len(splits) == 1:
|
22 |
+
full_model_name = AVAILABLE_MODELS.get(model_name, model_name)
|
23 |
+
provider = guess_model_provider(full_model_name)
|
24 |
+
return full_model_name, provider
|
25 |
+
if len(splits) == 2:
|
26 |
+
provider, model_name = splits
|
27 |
+
full_model_name = AVAILABLE_MODELS.get(model_name, model_name)
|
28 |
+
return full_model_name, provider
|
29 |
+
raise ValueError(f"Model `{model_name}` not yet supported")
|
30 |
+
|
31 |
+
|
32 |
+
def get_full_model_name(model_name: str, provider: str = ""):
|
33 |
+
"""Get the full model name from a model name."""
|
34 |
+
if model_name == "":
|
35 |
+
return UNSELECTED_MODEL_NAME
|
36 |
+
if not provider:
|
37 |
+
provider = guess_model_provider(model_name)
|
38 |
+
return f"{provider}/{model_name}"
|
src/workflows/README.md
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Workflows Subpackage
|
2 |
+
|
3 |
+
This subpackage provides a framework for defining, validating, and executing workflows composed of interconnected model steps with dependency management.
|
4 |
+
|
5 |
+
## Overview
|
6 |
+
|
7 |
+
The workflows subpackage enables the creation and execution of workflows where multiple model steps can be combined, with outputs from earlier steps feeding into inputs of later steps. The package handles dependency resolution, execution order, and error handling.
|
8 |
+
|
9 |
+
## Components
|
10 |
+
|
11 |
+
### `structs.py`
|
12 |
+
|
13 |
+
Contains the core data structures used throughout the workflow system:
|
14 |
+
|
15 |
+
- `Field`: Represents an input or output field with name and type information
|
16 |
+
- `ModelStep`: Represents a single step in a workflow with input fields, output fields, and model details
|
17 |
+
- `Workflow`: A collection of ModelSteps with their identifiers
|
18 |
+
|
19 |
+
### `utils.py`
|
20 |
+
|
21 |
+
Provides utility functions for workflow operations:
|
22 |
+
|
23 |
+
- `_create_variable_step_mapping`: Maps variables to the steps that produce them
|
24 |
+
- `create_dependency_graph`: Builds a dependency graph representing the execution order constraints
|
25 |
+
- `topological_sort`: Sorts steps in execution order based on their dependencies
|
26 |
+
|
27 |
+
### `workflow_executor.py`
|
28 |
+
|
29 |
+
Handles the execution of workflows:
|
30 |
+
|
31 |
+
- Processes inputs and outputs between steps
|
32 |
+
- Coordinates the execution of model steps in the correct order
|
33 |
+
- Integrates with external model providers (e.g., via litellm)
|
34 |
+
|
35 |
+
### `errors.py`
|
36 |
+
|
37 |
+
Defines custom exceptions for workflow-related errors:
|
38 |
+
|
39 |
+
- `WorkflowError`: Base class for workflow errors
|
40 |
+
- `CyclicDependencyError`: Raised when detecting cycles in the workflow graph
|
41 |
+
- `UnknownVariableError`: Raised when a step requires a variable that's not provided or produced
|
42 |
+
|
43 |
+
## Usage Example
|
44 |
+
|
45 |
+
```python
|
46 |
+
from workflows.structs import Field, ModelStep, Workflow
|
47 |
+
|
48 |
+
# Define a workflow with two steps
|
49 |
+
step1 = ModelStep(
|
50 |
+
input_fields=[Field(name="query", type="string")],
|
51 |
+
output_fields=[Field(name="summary", type="string")],
|
52 |
+
model="gpt-3.5-turbo",
|
53 |
+
system_prompt="Summarize the following text"
|
54 |
+
)
|
55 |
+
|
56 |
+
step2 = ModelStep(
|
57 |
+
input_fields=[Field(name="summary", type="string", variable="step1.summary")],
|
58 |
+
output_fields=[Field(name="key_points", type="array")],
|
59 |
+
model="gpt-4",
|
60 |
+
system_prompt="Extract key points from the summary"
|
61 |
+
)
|
62 |
+
|
63 |
+
workflow = Workflow(steps={"step1": step1, "step2": step2})
|
64 |
+
|
65 |
+
# Execute the workflow
|
66 |
+
from workflows.workflow_executor import execute_workflow
|
67 |
+
|
68 |
+
result = execute_workflow(
|
69 |
+
workflow=workflow,
|
70 |
+
input_values={"query": "Long text to summarize..."}
|
71 |
+
)
|
72 |
+
|
73 |
+
# Access results
|
74 |
+
summary = result["step1.summary"]
|
75 |
+
key_points = result["step2.key_points"]
|
76 |
+
```
|
77 |
+
|
78 |
+
## Error Handling
|
79 |
+
|
80 |
+
The workflows system provides robust error handling:
|
81 |
+
|
82 |
+
- Detects cyclic dependencies in workflow definitions
|
83 |
+
- Validates input/output variable references
|
84 |
+
- Ensures all required inputs are provided
|
85 |
+
|
86 |
+
## Extending the Workflows System
|
87 |
+
|
88 |
+
To extend the workflows system:
|
89 |
+
|
90 |
+
1. Add new model step types by extending the `ModelStep` class
|
91 |
+
2. Create custom field types by extending validation in the execution logic
|
92 |
+
3. Implement additional error types in `errors.py` for specialized error handling
|
src/workflows/errors.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Custom exceptions for workflow validation and execution errors.
|
3 |
+
|
4 |
+
This module defines the exception hierarchy for the workflows package, enabling
|
5 |
+
specific error types to be raised and caught during workflow validation and execution.
|
6 |
+
Each exception provides detailed error messages to help diagnose and fix issues in
|
7 |
+
workflow definitions or execution.
|
8 |
+
|
9 |
+
Exception hierarchy:
|
10 |
+
- WorkflowError (base class)
|
11 |
+
- UnknownVariableError (missing variable reference)
|
12 |
+
- CyclicDependencyError (circular dependencies)
|
13 |
+
- FunctionNotFoundError (missing function reference)
|
14 |
+
"""
|
15 |
+
|
16 |
+
|
17 |
+
# Define custom exceptions for workflow errors
|
18 |
+
class WorkflowError(Exception):
|
19 |
+
"""
|
20 |
+
Base exception class for all workflow-related errors.
|
21 |
+
|
22 |
+
This is the parent class for all workflow-specific exceptions and can be used
|
23 |
+
to catch any error from the workflows package.
|
24 |
+
"""
|
25 |
+
|
26 |
+
pass
|
27 |
+
|
28 |
+
|
29 |
+
class UnknownVariableError(WorkflowError):
|
30 |
+
"""
|
31 |
+
Raised when a workflow step references a variable that doesn't exist.
|
32 |
+
|
33 |
+
This typically occurs when a step's input field references a variable that is neither
|
34 |
+
provided as an external input nor produced as an output by any previous step.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, var: str):
|
38 |
+
super().__init__(f"Unknown variable referenced: {var}")
|
39 |
+
|
40 |
+
|
41 |
+
class CyclicDependencyError(WorkflowError):
|
42 |
+
"""
|
43 |
+
Raised when a cyclic dependency is detected in a workflow.
|
44 |
+
|
45 |
+
A cyclic dependency occurs when there is a circular reference in the workflow graph,
|
46 |
+
such as step A depending on step B, which depends on step A. Such workflows cannot
|
47 |
+
be executed because there's no valid order to process the steps.
|
48 |
+
"""
|
49 |
+
|
50 |
+
def __init__(self):
|
51 |
+
super().__init__("Cyclic dependency detected in workflow")
|
52 |
+
|
53 |
+
|
54 |
+
class FunctionNotFoundError(WorkflowError):
|
55 |
+
"""
|
56 |
+
Raised when a referenced function cannot be found during workflow execution.
|
57 |
+
|
58 |
+
This typically occurs when a step references a function that doesn't exist in
|
59 |
+
the available function registry or namespace.
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __init__(self, func_name: str):
|
63 |
+
super().__init__(f"Function not found: {func_name}")
|
src/workflows/executors.py
ADDED
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %%
|
2 |
+
import json
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
import pydantic
|
6 |
+
|
7 |
+
from llms import completion
|
8 |
+
from workflows.errors import WorkflowError
|
9 |
+
from workflows.structs import InputField, ModelStep, OutputField, Workflow
|
10 |
+
from workflows.utils import create_dependency_graph, topological_sort
|
11 |
+
|
12 |
+
"""
|
13 |
+
Core workflow execution functionality.
|
14 |
+
|
15 |
+
This module handles the execution of defined workflows, including input processing,
|
16 |
+
dependency-based execution order, model calling, and output collection. It integrates
|
17 |
+
with the litellm library to handle model interactions.
|
18 |
+
|
19 |
+
Key components:
|
20 |
+
- Utility functions for input/output transformation
|
21 |
+
- Input processing and validation
|
22 |
+
- Model step execution
|
23 |
+
- Complete workflow execution with dependency resolution
|
24 |
+
|
25 |
+
The module orchestrates the execution of steps in the correct order based on their
|
26 |
+
dependencies and manages the flow of data between steps.
|
27 |
+
"""
|
28 |
+
|
29 |
+
|
30 |
+
def upper(x):
|
31 |
+
if isinstance(x, str):
|
32 |
+
return x.upper()
|
33 |
+
return x
|
34 |
+
|
35 |
+
|
36 |
+
def lower(x):
|
37 |
+
if isinstance(x, str):
|
38 |
+
return x.lower()
|
39 |
+
return x
|
40 |
+
|
41 |
+
|
42 |
+
TYPE_MAP = {
|
43 |
+
"str": str,
|
44 |
+
"int": int,
|
45 |
+
"float": float,
|
46 |
+
"bool": bool,
|
47 |
+
}
|
48 |
+
|
49 |
+
FUNCTION_MAP = {
|
50 |
+
"upper": upper,
|
51 |
+
"lower": lower,
|
52 |
+
"len": len,
|
53 |
+
"split": str.split,
|
54 |
+
}
|
55 |
+
|
56 |
+
|
57 |
+
def get_type(type_str: str) -> type:
|
58 |
+
return TYPE_MAP.get(type_str, eval(type_str))
|
59 |
+
|
60 |
+
|
61 |
+
def create_processed_inputs(model_step: ModelStep, available_vars: dict[str, Any]) -> dict[str, Any]:
|
62 |
+
"""
|
63 |
+
Creates processed inputs for a model step.
|
64 |
+
|
65 |
+
This function extracts and processes the required inputs for a model step based on
|
66 |
+
its input field definitions. It retrieves values from the available variables dictionary
|
67 |
+
and applies any specified transformations.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
model_step (ModelStep): The model step for which to create processed inputs.
|
71 |
+
available_vars (dict[str, Any]): Dictionary of variables available for use as inputs.
|
72 |
+
Keys are variable names, values are the variable values.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
dict[str, Any]: A dictionary of processed inputs ready for use by the model step.
|
76 |
+
Keys are input field names, values are the processed input values.
|
77 |
+
|
78 |
+
Raises:
|
79 |
+
WorkflowError: If a required variable is not found in available_vars,
|
80 |
+
or if a specified transformation function is not available.
|
81 |
+
|
82 |
+
Example:
|
83 |
+
>>> available_vars = {"step1.output": "Hello World"}
|
84 |
+
>>> create_processed_inputs(model_step, available_vars)
|
85 |
+
{"input_field_name": "HELLO WORLD"} # If upper transformation was specified
|
86 |
+
"""
|
87 |
+
processed_inputs: dict[str, Any] = {}
|
88 |
+
for input_field in model_step.input_fields:
|
89 |
+
var = input_field.variable
|
90 |
+
value = available_vars[var]
|
91 |
+
if input_field.func is not None:
|
92 |
+
func = FUNCTION_MAP.get(input_field.func)
|
93 |
+
func = func or eval(input_field.func)
|
94 |
+
value = func(value)
|
95 |
+
processed_inputs[input_field.name] = value
|
96 |
+
return processed_inputs
|
97 |
+
|
98 |
+
|
99 |
+
# %%
|
100 |
+
def execute_model_step(
|
101 |
+
model_step: ModelStep, available_vars: dict[str, Any], return_full_content: bool = False
|
102 |
+
) -> dict[str, Any] | tuple[dict[str, Any], str]:
|
103 |
+
"""
|
104 |
+
Executes a model step using the provided available variables.
|
105 |
+
|
106 |
+
This function handles the complete execution of a model step, including:
|
107 |
+
1. Processing inputs using variable references and transformations
|
108 |
+
2. Constructing the appropriate prompt for the model
|
109 |
+
3. Calling the model via litellm with structured output
|
110 |
+
4. Processing and validating the model's response
|
111 |
+
5. Applying any output transformations
|
112 |
+
|
113 |
+
The function supports different providers and model types through the litellm
|
114 |
+
integration, allowing for a consistent interface regardless of the underlying model.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
model_step (ModelStep): The model step to execute, containing model details,
|
118 |
+
input/output specifications, and system prompt.
|
119 |
+
available_vars (dict[str, Any]): A dictionary of all variables available to this step,
|
120 |
+
including outputs from previous steps and external inputs.
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
dict[str, Any]: A dictionary of processed outputs from the model step,
|
124 |
+
with keys matching the output field names.
|
125 |
+
|
126 |
+
Raises:
|
127 |
+
WorkflowError: If there's an error in input processing, model execution,
|
128 |
+
or output validation.
|
129 |
+
|
130 |
+
Example:
|
131 |
+
>>> step = ModelStep(
|
132 |
+
... id="summarize",
|
133 |
+
... model="gpt-3.5-turbo",
|
134 |
+
... provider="openai",
|
135 |
+
... call_type="llm",
|
136 |
+
... system_prompt="Summarize the text",
|
137 |
+
... input_fields=[InputField(name="text", variable="input_text", description="Text to summarize")],
|
138 |
+
... output_fields=[OutputField(name="summary", type="str", description="Summary of the text")]
|
139 |
+
... )
|
140 |
+
>>> execute_model_step(step, {"input_text": "Long text to be summarized..."})
|
141 |
+
{"summary": "A concise summary of the text."}
|
142 |
+
"""
|
143 |
+
# Ensure inputs are processed using the specified functions in input_fields.
|
144 |
+
processed_inputs = create_processed_inputs(model_step, available_vars)
|
145 |
+
|
146 |
+
# Construct the input prompt for the model
|
147 |
+
input_str = ", ".join(f"{k}={v}" for k, v in processed_inputs.items())
|
148 |
+
step_result = f"{model_step.system_prompt} | Inputs: {input_str}"
|
149 |
+
|
150 |
+
# Define the expected output fields and their types
|
151 |
+
fields = {
|
152 |
+
field.name: (get_type(field.type), pydantic.Field(..., description=field.description))
|
153 |
+
for field in model_step.output_fields
|
154 |
+
}
|
155 |
+
ModelResponse = pydantic.create_model("ModelResponse", **fields)
|
156 |
+
|
157 |
+
# Execute the model step using litellm
|
158 |
+
api_response = completion(
|
159 |
+
model=f"{model_step.provider}/{model_step.model}",
|
160 |
+
system=model_step.system_prompt,
|
161 |
+
prompt=step_result,
|
162 |
+
response_format=ModelResponse,
|
163 |
+
)
|
164 |
+
# api_response = litellm.completion(
|
165 |
+
# model=model_step.model,
|
166 |
+
# messages=[{"role": "user", "content": step_result}],
|
167 |
+
# response_format=ModelResponse,
|
168 |
+
# )
|
169 |
+
|
170 |
+
# Extract and parse the model response
|
171 |
+
# model_response_content = api_response["choices"][0]["message"]["content"]
|
172 |
+
# model_response = json.loads(model_response_content)
|
173 |
+
model_response = api_response["output"]
|
174 |
+
# Map the parsed response to the output fields
|
175 |
+
outputs = {field.name: model_response[field.name] for field in model_step.output_fields}
|
176 |
+
if return_full_content:
|
177 |
+
return outputs, api_response["content"]
|
178 |
+
return outputs
|
179 |
+
|
180 |
+
|
181 |
+
# Example usage
|
182 |
+
if __name__ == "__main__":
|
183 |
+
# Define a simple model step
|
184 |
+
model_step = ModelStep(
|
185 |
+
id="step1",
|
186 |
+
model="gpt-4o-mini",
|
187 |
+
provider="OpenAI",
|
188 |
+
call_type="llm",
|
189 |
+
system_prompt="You are a simple NLP tool that takes a string, and a number N, and return the first N entities in the string, and the total count of entities in the string.",
|
190 |
+
input_fields=[
|
191 |
+
InputField(name="sentence", description="The sentence to process", variable="sentence", func=None),
|
192 |
+
InputField(name="n", description="The number of entities to return", variable="n", func=None),
|
193 |
+
],
|
194 |
+
output_fields=[
|
195 |
+
OutputField(
|
196 |
+
name="entities",
|
197 |
+
description="The first N entities in the string as a list of strings",
|
198 |
+
type="list[str]",
|
199 |
+
func=None,
|
200 |
+
),
|
201 |
+
OutputField(name="count", description="The total count of entities in the string", type="int", func=None),
|
202 |
+
],
|
203 |
+
)
|
204 |
+
|
205 |
+
# Define processed inputs
|
206 |
+
processed_inputs = {"sentence": "Abdul Akbar is a good person, but Jesus is the son of God.", "n": 3}
|
207 |
+
|
208 |
+
# Execute the model step
|
209 |
+
outputs = execute_model_step(model_step, processed_inputs)
|
210 |
+
print(outputs)
|
211 |
+
|
212 |
+
|
213 |
+
# %%
|
214 |
+
def execute_workflow(
|
215 |
+
workflow: Workflow, input_values: dict[str, Any], return_full_content: bool = False
|
216 |
+
) -> dict[str, Any] | tuple[dict[str, Any], str]:
|
217 |
+
"""
|
218 |
+
Execute the given workflow as a computational graph.
|
219 |
+
|
220 |
+
This function orchestrates the complete execution of a workflow by:
|
221 |
+
|
222 |
+
1. Validating and populating initial values using the provided external inputs
|
223 |
+
2. Building a dependency graph between workflow steps
|
224 |
+
3. Determining a valid execution order using topological sorting
|
225 |
+
4. Executing each step in the correct order, with inputs from previous steps
|
226 |
+
5. Collecting and returning the final outputs
|
227 |
+
|
228 |
+
The execution process ensures that all dependencies are satisfied before a step
|
229 |
+
is executed, and that the data flows correctly between steps according to the
|
230 |
+
variable references defined in each step's input fields.
|
231 |
+
|
232 |
+
Args:
|
233 |
+
workflow (Workflow): The workflow to execute, containing steps, their
|
234 |
+
dependencies, and input/output specifications.
|
235 |
+
input_values (dict[str, Any]): External input values to be used by the workflow.
|
236 |
+
Keys should match the required workflow.inputs.
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
dict[str, Any]: A dictionary of the workflow's outputs, with keys matching
|
240 |
+
the variables defined in workflow.outputs.
|
241 |
+
|
242 |
+
Raises:
|
243 |
+
UnknownVariableError: If an input_field references a variable that is not
|
244 |
+
provided externally nor produced by any step.
|
245 |
+
CyclicDependencyError: If the workflow contains a circular dependency that
|
246 |
+
prevents a valid execution order.
|
247 |
+
FunctionNotFoundError: If a transformation function specified in input_fields.func
|
248 |
+
or output_fields.func is not available.
|
249 |
+
WorkflowError: For any other workflow-related errors, such as missing required inputs.
|
250 |
+
|
251 |
+
Example:
|
252 |
+
>>> workflow = Workflow(
|
253 |
+
... steps={
|
254 |
+
... "extract": ModelStep(...), # A step that extracts entities
|
255 |
+
... "analyze": ModelStep(...) # A step that analyzes the entities
|
256 |
+
... },
|
257 |
+
... inputs=["text"],
|
258 |
+
... outputs=["analyze.sentiment", "extract.entities"]
|
259 |
+
... )
|
260 |
+
>>> result = execute_workflow(workflow, {"text": "Apple is launching a new product tomorrow."})
|
261 |
+
>>> print(result["analyze.sentiment"])
|
262 |
+
"positive"
|
263 |
+
>>> print(result["extract.entities"])
|
264 |
+
["Apple", "product"]
|
265 |
+
"""
|
266 |
+
# Step 1: Pre-populate computed values with external workflow inputs.
|
267 |
+
computed_values: dict[str, Any] = {}
|
268 |
+
for var in workflow.inputs:
|
269 |
+
if var not in input_values:
|
270 |
+
raise WorkflowError(f"Missing required workflow input: {var}")
|
271 |
+
computed_values[var] = input_values[var]
|
272 |
+
|
273 |
+
# Step 2: Build dependency graph among model steps.
|
274 |
+
# For each step, examine its input_fields. If an input is not in the pre-populated external inputs,
|
275 |
+
# then it is expected to be produced by some step. Otherwise, raise an error.
|
276 |
+
dependencies = create_dependency_graph(workflow, input_values)
|
277 |
+
|
278 |
+
# Step 3: Determine the execution order of the steps using topological sort.
|
279 |
+
# Raises an error if a cycle is detected.
|
280 |
+
execution_order = topological_sort(dependencies)
|
281 |
+
|
282 |
+
# Step 4: Execute steps in topological order.
|
283 |
+
for step_id in execution_order:
|
284 |
+
step = workflow.steps[step_id]
|
285 |
+
|
286 |
+
# Execute the step
|
287 |
+
outputs = execute_model_step(step, computed_values)
|
288 |
+
outputs = {f"{step_id}.{k}": v for k, v in outputs.items()}
|
289 |
+
computed_values.update(outputs)
|
290 |
+
|
291 |
+
# Step 5: Gather and return workflow outputs.
|
292 |
+
final_outputs: dict[str, Any] = {}
|
293 |
+
for target, var in workflow.outputs.items():
|
294 |
+
if var not in computed_values:
|
295 |
+
raise WorkflowError(f"Workflow output variable {var} was not produced")
|
296 |
+
final_outputs[target] = computed_values[var]
|
297 |
+
|
298 |
+
return final_outputs
|
299 |
+
|
300 |
+
|
301 |
+
def run_examples():
|
302 |
+
"""
|
303 |
+
Runs three example workflows demonstrating:
|
304 |
+
1. A successful (linear) workflow execution.
|
305 |
+
2. A cyclic dependency error.
|
306 |
+
3. An unknown variable dependency error.
|
307 |
+
"""
|
308 |
+
print("Example 1: Successful Workflow Execution")
|
309 |
+
# Example 1: Simple linear workflow.
|
310 |
+
# External input "input.value" is provided. Two steps:
|
311 |
+
# - step1 takes "input.value" and produces "step1.result".
|
312 |
+
# - step2 uses "step1.result" and produces "step2.final".
|
313 |
+
from workflows.structs import ModelStep, Workflow
|
314 |
+
|
315 |
+
workflow_success = Workflow(
|
316 |
+
steps={
|
317 |
+
"step1": ModelStep(
|
318 |
+
id="step1",
|
319 |
+
model="gpt-4o-mini",
|
320 |
+
provider="OpenAI",
|
321 |
+
call_type="llm",
|
322 |
+
system_prompt="Step1 processing",
|
323 |
+
input_fields=[InputField(name="value", description="Input value", variable="input.value")],
|
324 |
+
output_fields=[OutputField(name="result", description="Processed result", type="str", func="upper")],
|
325 |
+
),
|
326 |
+
"step2": ModelStep(
|
327 |
+
id="step2",
|
328 |
+
model="gpt-4o-mini",
|
329 |
+
provider="OpenAI",
|
330 |
+
call_type="llm",
|
331 |
+
system_prompt="Step2 processing",
|
332 |
+
input_fields=[InputField(name="result", description="Result from step1", variable="step1.result")],
|
333 |
+
output_fields=[OutputField(name="final", description="Final output", type="str", func="lower")],
|
334 |
+
),
|
335 |
+
},
|
336 |
+
inputs=["input.value"],
|
337 |
+
outputs={"final": "step2.final"},
|
338 |
+
)
|
339 |
+
input_values_success = {"input.value": "Hello, World!"}
|
340 |
+
try:
|
341 |
+
outputs = execute_workflow(workflow_success, input_values_success)
|
342 |
+
print("Workflow outputs:", outputs)
|
343 |
+
except WorkflowError as e:
|
344 |
+
print("Workflow failed with error:", e)
|
345 |
+
|
346 |
+
print("\nExample 2: Cyclic Dependency Workflow")
|
347 |
+
# Example 2: Cyclic dependency.
|
348 |
+
# stepA depends on an output from stepB and vice versa.
|
349 |
+
workflow_cycle = Workflow(
|
350 |
+
steps={
|
351 |
+
"stepA": ModelStep(
|
352 |
+
id="stepA",
|
353 |
+
model="gpt-4o-mini",
|
354 |
+
provider="OpenAI",
|
355 |
+
call_type="llm",
|
356 |
+
system_prompt="StepA processing",
|
357 |
+
input_fields=[
|
358 |
+
InputField(name="input", description="Input from stepB", variable="stepB.output", func="identity")
|
359 |
+
],
|
360 |
+
output_fields=[OutputField(name="output", description="Output from A", type="str", func="upper")],
|
361 |
+
),
|
362 |
+
"stepB": ModelStep(
|
363 |
+
id="stepB",
|
364 |
+
model="gpt-4o-mini",
|
365 |
+
provider="OpenAI",
|
366 |
+
call_type="llm",
|
367 |
+
system_prompt="StepB processing",
|
368 |
+
input_fields=[
|
369 |
+
InputField(name="input", description="Input from stepA", variable="stepA.output", func="identity")
|
370 |
+
],
|
371 |
+
output_fields=[OutputField(name="output", description="Output from B", type="str", func="upper")],
|
372 |
+
),
|
373 |
+
},
|
374 |
+
inputs=[], # no external inputs
|
375 |
+
outputs={"output": "stepB.output"},
|
376 |
+
)
|
377 |
+
try:
|
378 |
+
outputs = execute_workflow(workflow_cycle, {})
|
379 |
+
print("Workflow outputs:", outputs)
|
380 |
+
except WorkflowError as e:
|
381 |
+
print("Workflow failed with error:", e)
|
382 |
+
|
383 |
+
print("\nExample 3: Unknown Variable Dependency Workflow")
|
384 |
+
# Example 3: A workflow that references a variable not provided as an input or produced by any step.
|
385 |
+
workflow_unknown = Workflow(
|
386 |
+
steps={
|
387 |
+
"stepX": ModelStep(
|
388 |
+
id="stepX",
|
389 |
+
model="gpt-4o-mini",
|
390 |
+
provider="OpenAI",
|
391 |
+
call_type="llm",
|
392 |
+
system_prompt="StepX processing",
|
393 |
+
input_fields=[
|
394 |
+
InputField(
|
395 |
+
name="input", description="Non-existent input", variable="nonexistent.value", func="identity"
|
396 |
+
)
|
397 |
+
],
|
398 |
+
output_fields=[OutputField(name="output", description="Output from X", type="str", func="upper")],
|
399 |
+
)
|
400 |
+
},
|
401 |
+
inputs=[], # no external inputs
|
402 |
+
outputs={"output": "stepX.output"},
|
403 |
+
)
|
404 |
+
try:
|
405 |
+
outputs = execute_workflow(workflow_unknown, {})
|
406 |
+
print("Workflow outputs:", outputs)
|
407 |
+
except WorkflowError as e:
|
408 |
+
print("Workflow failed with error:", e)
|
409 |
+
|
410 |
+
|
411 |
+
if __name__ == "__main__":
|
412 |
+
# create example of model_step
|
413 |
+
model_step = ModelStep(
|
414 |
+
id="step1",
|
415 |
+
model="gpt-4o-mini",
|
416 |
+
provider="OpenAI",
|
417 |
+
call_type="llm",
|
418 |
+
system_prompt="You are a simple NLP tool that takes a string, and a number N, and return the first N entities in the string, and the total count of entities in the string.",
|
419 |
+
input_fields=[
|
420 |
+
InputField(name="sentence", description="The sentence to process", variable="sentence", func=None),
|
421 |
+
InputField(name="n", description="The number of entities to return", variable="n", func=None),
|
422 |
+
],
|
423 |
+
output_fields=[
|
424 |
+
OutputField(
|
425 |
+
name="entities",
|
426 |
+
description="The first N entities in the string as a list of strings",
|
427 |
+
type="list[str]",
|
428 |
+
func=None,
|
429 |
+
),
|
430 |
+
OutputField(name="count", description="The total count of entities in the string", type="int", func=None),
|
431 |
+
],
|
432 |
+
)
|
433 |
+
|
434 |
+
processed_inputs = {"sentence": "Abdul Akbar is a good person, but Jesus is the son of God.", "n": 3}
|
435 |
+
processed_inputs = create_processed_inputs(model_step, processed_inputs)
|
436 |
+
print(processed_inputs)
|
437 |
+
|
438 |
+
run_examples()
|
439 |
+
|
440 |
+
# %%
|
src/workflows/factory.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %%
|
2 |
+
from workflows.structs import Field, InputField, ModelStep, OutputField, Workflow
|
3 |
+
|
4 |
+
INITIAL_SYS_PROMPT = """You are a helpful performant question answering bot.
|
5 |
+
Given a question clue, output your most likely guess in a couple words with a calibrated confidence for the guess.
|
6 |
+
"""
|
7 |
+
|
8 |
+
|
9 |
+
def create_simple_workflow():
|
10 |
+
pass
|
11 |
+
|
12 |
+
|
13 |
+
def create_first_step_input_fields() -> list[InputField]:
|
14 |
+
return [
|
15 |
+
InputField(
|
16 |
+
name="question",
|
17 |
+
description="The question text progressively revealed to the agent so far.",
|
18 |
+
variable="question_text",
|
19 |
+
)
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
def create_empty_input_field() -> list[InputField]:
|
24 |
+
return [InputField(name="", description="", variable="question_text")]
|
25 |
+
|
26 |
+
|
27 |
+
def create_quizbowl_simple_step_initial_setup():
|
28 |
+
return ModelStep(
|
29 |
+
id="simple_step",
|
30 |
+
name="Quizbowl Simple Step",
|
31 |
+
model="",
|
32 |
+
provider="",
|
33 |
+
temperature=0.7,
|
34 |
+
call_type="llm",
|
35 |
+
system_prompt=INITIAL_SYS_PROMPT,
|
36 |
+
input_fields=[
|
37 |
+
InputField(name="question", description="The question to answer", variable="question"),
|
38 |
+
],
|
39 |
+
output_fields=[
|
40 |
+
OutputField(name="answer", description="The most likely answer", type="str"),
|
41 |
+
OutputField(name="confidence", description="The confidence of the answer", type="float"),
|
42 |
+
],
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
def create_new_llm_step(step_id: str, name: str) -> ModelStep:
|
47 |
+
return ModelStep(
|
48 |
+
id=step_id,
|
49 |
+
name=name,
|
50 |
+
model="gpt-4o",
|
51 |
+
provider="OpenAI",
|
52 |
+
call_type="llm",
|
53 |
+
temperature=0.7,
|
54 |
+
system_prompt="",
|
55 |
+
input_fields=create_empty_input_field(),
|
56 |
+
output_fields=[OutputField(name="", description="")],
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
def create_first_llm_step() -> ModelStep:
|
61 |
+
return ModelStep(
|
62 |
+
id="A",
|
63 |
+
name="",
|
64 |
+
model="gpt-4o",
|
65 |
+
provider="OpenAI",
|
66 |
+
call_type="llm",
|
67 |
+
temperature=0.7,
|
68 |
+
system_prompt="",
|
69 |
+
input_fields=[create_first_step_input_fields()],
|
70 |
+
output_fields=[OutputField(name="", description="")],
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
def create_quizbowl_simple_workflow():
|
75 |
+
return Workflow(
|
76 |
+
inputs=["question_text"],
|
77 |
+
outputs={"answer": "A.answer", "confidence": "A.confidence"},
|
78 |
+
steps={
|
79 |
+
"A": ModelStep(
|
80 |
+
id="A",
|
81 |
+
name="Tossup Agent",
|
82 |
+
model="gpt-4o-mini",
|
83 |
+
provider="OpenAI",
|
84 |
+
call_type="llm",
|
85 |
+
temperature=0.3,
|
86 |
+
system_prompt="You are a helpful assistant that can answer questions.",
|
87 |
+
input_fields=[InputField(name="question", description="The question text", variable="question_text")],
|
88 |
+
output_fields=[
|
89 |
+
OutputField(
|
90 |
+
name="answer",
|
91 |
+
description="The best guess at the answer to the question",
|
92 |
+
type="str",
|
93 |
+
),
|
94 |
+
OutputField(
|
95 |
+
name="confidence",
|
96 |
+
description="The confidence in the answer, ranging from 0 to 1 in increments of 0.05.",
|
97 |
+
type="float",
|
98 |
+
),
|
99 |
+
],
|
100 |
+
)
|
101 |
+
},
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
BONUS_SYS_PROMPT = """You are a quizbowl player answering bonus questions. For each part:
|
106 |
+
1. Read the leadin and part carefully
|
107 |
+
2. Provide a concise answer
|
108 |
+
3. Rate your confidence (0-1)
|
109 |
+
4. Explain your reasoning
|
110 |
+
|
111 |
+
Format your response as:
|
112 |
+
ANSWER: <your answer>
|
113 |
+
CONFIDENCE: <0-1>
|
114 |
+
EXPLANATION: <your reasoning>"""
|
115 |
+
|
116 |
+
|
117 |
+
def create_quizbowl_bonus_simple_workflow() -> Workflow:
|
118 |
+
"""Create a simple model step for bonus questions."""
|
119 |
+
return Workflow(
|
120 |
+
inputs=["leadin", "part"],
|
121 |
+
outputs={"answer": "A.answer", "confidence": "A.confidence", "explanation": "A.explanation"},
|
122 |
+
steps={
|
123 |
+
"A": ModelStep(
|
124 |
+
id="A",
|
125 |
+
name="Bonus Agent",
|
126 |
+
model="gpt-4o-mini",
|
127 |
+
provider="OpenAI",
|
128 |
+
temperature=0.3,
|
129 |
+
call_type="llm",
|
130 |
+
system_prompt=BONUS_SYS_PROMPT,
|
131 |
+
input_fields=[
|
132 |
+
InputField(
|
133 |
+
name="question_leadin",
|
134 |
+
description="The leadin text for the bonus question",
|
135 |
+
variable="leadin",
|
136 |
+
),
|
137 |
+
InputField(
|
138 |
+
name="question_part",
|
139 |
+
description="The specific part text to answer",
|
140 |
+
variable="part",
|
141 |
+
),
|
142 |
+
],
|
143 |
+
output_fields=[
|
144 |
+
OutputField(name="answer", description="The predicted answer", type="str"),
|
145 |
+
OutputField(name="confidence", description="Confidence in the answer (0-1)", type="float"),
|
146 |
+
OutputField(name="explanation", description="Short explanation for the answer", type="str"),
|
147 |
+
],
|
148 |
+
)
|
149 |
+
},
|
150 |
+
)
|
src/workflows/qb/__init__.py
ADDED
File without changes
|
src/workflows/qb/simple_agent.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from typing import Any, Iterable
|
3 |
+
|
4 |
+
# from litellm import completion
|
5 |
+
from llms import completion
|
6 |
+
from workflows.executors import execute_model_step, execute_workflow
|
7 |
+
from workflows.structs import ModelStep, Workflow
|
8 |
+
|
9 |
+
|
10 |
+
def _get_agent_response(self, prompt: str, system_prompt: str) -> dict:
|
11 |
+
"""Get response from the LLM model."""
|
12 |
+
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
|
13 |
+
|
14 |
+
start_time = time.time()
|
15 |
+
response = completion(
|
16 |
+
model=self.model,
|
17 |
+
messages=messages,
|
18 |
+
temperature=self.temperature,
|
19 |
+
max_tokens=150, # Limit token usage for faster responses
|
20 |
+
)
|
21 |
+
response_time = time.time() - start_time
|
22 |
+
|
23 |
+
return response, response_time
|
24 |
+
|
25 |
+
|
26 |
+
def _get_model_step_response(
|
27 |
+
model_step: ModelStep, available_vars: dict[str, Any]
|
28 |
+
) -> tuple[dict[str, Any], str, float]:
|
29 |
+
"""Get response from the LLM model."""
|
30 |
+
start_time = time.time()
|
31 |
+
response, content = execute_model_step(model_step, available_vars, return_full_content=True)
|
32 |
+
response_time = time.time() - start_time
|
33 |
+
return response, content, response_time
|
34 |
+
|
35 |
+
|
36 |
+
def _get_workflow_response(workflow: Workflow, available_vars: dict[str, Any]) -> tuple[dict[str, Any], str, float]:
|
37 |
+
"""Get response from the LLM model."""
|
38 |
+
start_time = time.time()
|
39 |
+
response, content = execute_workflow(workflow, available_vars, return_full_content=True)
|
40 |
+
response_time = time.time() - start_time
|
41 |
+
return response, content, response_time
|
42 |
+
|
43 |
+
|
44 |
+
class SimpleTossupAgent:
|
45 |
+
external_input_variable = "question_text"
|
46 |
+
output_variables = ["answer", "confidence"]
|
47 |
+
|
48 |
+
def __init__(self, workflow: Workflow, buzz_threshold: float):
|
49 |
+
steps = list(workflow.steps.values())
|
50 |
+
assert len(steps) == 1, "Only one step is allowed in a simple workflow"
|
51 |
+
self.model_step = steps[0]
|
52 |
+
self.buzz_threshold = buzz_threshold
|
53 |
+
self.output_variables = list(workflow.outputs.keys())
|
54 |
+
|
55 |
+
if self.external_input_variable not in workflow.inputs:
|
56 |
+
raise ValueError(f"External input variable {self.external_input_variable} not found in model step inputs")
|
57 |
+
|
58 |
+
for out_var in self.output_variables:
|
59 |
+
if out_var not in workflow.outputs:
|
60 |
+
raise ValueError(f"Output variable {out_var} not found in the workflow outputs")
|
61 |
+
|
62 |
+
def run(self, question_runs: list[str], early_stop: bool = True) -> Iterable[dict]:
|
63 |
+
"""
|
64 |
+
Process a tossup question and decide when to buzz based on confidence.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
question_runs: Progressive reveals of the question text
|
68 |
+
early_stop: Whether to stop after the first buzz
|
69 |
+
|
70 |
+
Yields:
|
71 |
+
Dict with answer, confidence, and whether to buzz
|
72 |
+
"""
|
73 |
+
|
74 |
+
for i, question_text in enumerate(question_runs):
|
75 |
+
response, content, response_time = _get_model_step_response(
|
76 |
+
self.model_step, {self.external_input_variable: question_text}
|
77 |
+
)
|
78 |
+
buzz = response["confidence"] >= self.buzz_threshold
|
79 |
+
result = {
|
80 |
+
"answer": response["answer"],
|
81 |
+
"confidence": response["confidence"],
|
82 |
+
"buzz": buzz,
|
83 |
+
"question_fragment": question_text,
|
84 |
+
"position": i + 1,
|
85 |
+
"full_response": content,
|
86 |
+
"response_time": response_time,
|
87 |
+
}
|
88 |
+
|
89 |
+
yield result
|
90 |
+
|
91 |
+
# If we've reached the confidence threshold, buzz and stop
|
92 |
+
if early_stop and buzz:
|
93 |
+
return
|
94 |
+
|
95 |
+
|
96 |
+
class SimpleBonusAgent:
|
97 |
+
external_input_variables = ["leadin", "part"]
|
98 |
+
output_variables = ["answer", "confidence", "explanation"]
|
99 |
+
|
100 |
+
def __init__(self, workflow: Workflow):
|
101 |
+
steps = list(workflow.steps.values())
|
102 |
+
assert len(steps) == 1, "Only one step is allowed in a simple workflow"
|
103 |
+
self.model_step = steps[0]
|
104 |
+
self.output_variables = list(workflow.outputs.keys())
|
105 |
+
|
106 |
+
# Validate input variables
|
107 |
+
for input_var in self.external_input_variables:
|
108 |
+
if input_var not in workflow.inputs:
|
109 |
+
raise ValueError(f"External input variable {input_var} not found in model step inputs")
|
110 |
+
|
111 |
+
# Validate output variables
|
112 |
+
for out_var in self.output_variables:
|
113 |
+
if out_var not in workflow.outputs:
|
114 |
+
raise ValueError(f"Output variable {out_var} not found in the workflow outputs")
|
115 |
+
|
116 |
+
def run(self, leadin: str, part: str) -> dict:
|
117 |
+
"""
|
118 |
+
Process a bonus part with the given leadin.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
leadin: The leadin text for the bonus question
|
122 |
+
part: The specific part text to answer
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
Dict with answer, confidence, and explanation
|
126 |
+
"""
|
127 |
+
response, content, response_time = _get_model_step_response(
|
128 |
+
self.model_step,
|
129 |
+
{
|
130 |
+
"leadin": leadin,
|
131 |
+
"part": part,
|
132 |
+
},
|
133 |
+
)
|
134 |
+
|
135 |
+
return {
|
136 |
+
"answer": response["answer"],
|
137 |
+
"confidence": response["confidence"],
|
138 |
+
"explanation": response["explanation"],
|
139 |
+
"full_response": content,
|
140 |
+
"response_time": response_time,
|
141 |
+
}
|
142 |
+
|
143 |
+
|
144 |
+
# Example usage
|
145 |
+
if __name__ == "__main__":
|
146 |
+
# Load the Quizbowl dataset
|
147 |
+
from datasets import load_dataset
|
148 |
+
|
149 |
+
from workflows.factory import create_quizbowl_bonus_step_initial_setup, create_quizbowl_simple_step_initial_setup
|
150 |
+
|
151 |
+
ds_name = "umdclip/leaderboard_co_set"
|
152 |
+
ds = load_dataset(ds_name, split="train")
|
153 |
+
|
154 |
+
# Create the agents
|
155 |
+
tossup_step = create_quizbowl_simple_step_initial_setup()
|
156 |
+
tossup_step.model = "gpt-4"
|
157 |
+
tossup_step.provider = "openai"
|
158 |
+
tossup_agent = SimpleTossupAgent(workflow=tossup_step, buzz_threshold=0.9)
|
159 |
+
|
160 |
+
bonus_step = create_quizbowl_bonus_step_initial_setup()
|
161 |
+
bonus_step.model = "gpt-4"
|
162 |
+
bonus_step.provider = "openai"
|
163 |
+
bonus_agent = SimpleBonusAgent(workflow=bonus_step)
|
164 |
+
|
165 |
+
# Example for tossup mode
|
166 |
+
print("\n=== TOSSUP MODE EXAMPLE ===")
|
167 |
+
sample_question = ds[30]
|
168 |
+
print(sample_question["question_runs"][-1])
|
169 |
+
print(sample_question["gold_label"])
|
170 |
+
print()
|
171 |
+
question_runs = sample_question["question_runs"]
|
172 |
+
|
173 |
+
results = tossup_agent.run(question_runs, early_stop=True)
|
174 |
+
for result in results:
|
175 |
+
print(result["full_response"])
|
176 |
+
print(f"Guess at position {result['position']}: {result['answer']}")
|
177 |
+
print(f"Confidence: {result['confidence']}")
|
178 |
+
if result["buzz"]:
|
179 |
+
print("Buzzed!\n")
|
180 |
+
|
181 |
+
# Example for bonus mode
|
182 |
+
print("\n=== BONUS MODE EXAMPLE ===")
|
183 |
+
sample_bonus = ds[31] # Assuming this is a bonus question
|
184 |
+
leadin = sample_bonus["leadin"]
|
185 |
+
parts = sample_bonus["parts"]
|
186 |
+
|
187 |
+
print(f"Leadin: {leadin}")
|
188 |
+
for i, part in enumerate(parts):
|
189 |
+
print(f"\nPart {i + 1}: {part['part']}")
|
190 |
+
result = bonus_agent.run(leadin, part["part"])
|
191 |
+
print(f"Answer: {result['answer']}")
|
192 |
+
print(f"Confidence: {result['confidence']}")
|
193 |
+
print(f"Explanation: {result['explanation']}")
|
194 |
+
print(f"Response time: {result['response_time']:.2f}s")
|
src/workflows/quizbowl_agent.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %%
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import litellm
|
8 |
+
from datasets import load_dataset
|
9 |
+
from litellm import completion
|
10 |
+
|
11 |
+
litellm.drop_params = True
|
12 |
+
|
13 |
+
# Set your API key - you can replace this with your actual key or use environment variables
|
14 |
+
os.environ["OPENAI_API_KEY"] = (
|
15 |
+
"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"
|
16 |
+
)
|
17 |
+
|
18 |
+
DEFAULT_SYS_PROMPT = """
|
19 |
+
You are a Quizbowl expert. You will be given a question that's progressively revealed.
|
20 |
+
Your goal is to identify the answer as quickly as possible with high confidence.
|
21 |
+
Respond with a JSON object with two fields:
|
22 |
+
1. "answer": Your best guess for the answer
|
23 |
+
2. "confidence": Your confidence in your answer from 0.0 to 1.0
|
24 |
+
|
25 |
+
DO NOT include any explanation. ONLY return the JSON object.
|
26 |
+
"""
|
27 |
+
|
28 |
+
|
29 |
+
class QuizbowlAgent:
|
30 |
+
"""
|
31 |
+
An agent for playing Quizbowl with two modes:
|
32 |
+
1. Tossup mode: Fast and direct with confidence calibration for buzzing
|
33 |
+
2. Bonus round mode: Provides guess, rationale, and confidence
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
model: str = "gpt-4o-mini",
|
39 |
+
buzz_threshold: float = 0.85,
|
40 |
+
temperature: float = 0.2,
|
41 |
+
system_prompt: str = DEFAULT_SYS_PROMPT,
|
42 |
+
):
|
43 |
+
"""
|
44 |
+
Initialize the QuizbowlAgent.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
model: The LLM model to use for answering
|
48 |
+
buzz_threshold: Confidence threshold for buzzing in tossup mode (0-1)
|
49 |
+
temperature: Temperature for model sampling
|
50 |
+
"""
|
51 |
+
self.model = model
|
52 |
+
self.buzz_threshold = buzz_threshold
|
53 |
+
self.temperature = temperature
|
54 |
+
self.system_prompt = system_prompt
|
55 |
+
|
56 |
+
def _process_question_runs(self, question_runs: List[str]) -> List[str]:
|
57 |
+
"""Process question runs to extract increasing amounts of text."""
|
58 |
+
# For simpler testing, just return the runs as they are in the dataset
|
59 |
+
return question_runs
|
60 |
+
|
61 |
+
def _get_agent_response(self, prompt: str, system_prompt: str) -> Dict:
|
62 |
+
"""Get response from the LLM model."""
|
63 |
+
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
|
64 |
+
|
65 |
+
start_time = time.time()
|
66 |
+
response = completion(
|
67 |
+
model=self.model,
|
68 |
+
messages=messages,
|
69 |
+
temperature=self.temperature,
|
70 |
+
max_tokens=150, # Limit token usage for faster responses
|
71 |
+
)
|
72 |
+
response_time = time.time() - start_time
|
73 |
+
|
74 |
+
return response, response_time
|
75 |
+
|
76 |
+
def _extract_confidence_and_answer(self, content: str) -> Tuple[str, float]:
|
77 |
+
"""Extract the answer and confidence score from the model response."""
|
78 |
+
try:
|
79 |
+
# Try to parse JSON from the response
|
80 |
+
data = json.loads(content)
|
81 |
+
answer = data.get("answer", "")
|
82 |
+
confidence = float(data.get("confidence", 0.0))
|
83 |
+
return answer, confidence
|
84 |
+
except (json.JSONDecodeError, ValueError):
|
85 |
+
# Fallback if parsing fails
|
86 |
+
lines = content.strip().split("\n")
|
87 |
+
answer = lines[0] if lines else ""
|
88 |
+
confidence = 0.5 # Default confidence
|
89 |
+
|
90 |
+
# Try to extract confidence from text
|
91 |
+
for line in lines:
|
92 |
+
if "confidence:" in line.lower():
|
93 |
+
try:
|
94 |
+
confidence = float(line.lower().split("confidence:")[1].strip())
|
95 |
+
except (ValueError, IndexError):
|
96 |
+
pass
|
97 |
+
|
98 |
+
return answer, confidence
|
99 |
+
|
100 |
+
def tossup_mode(self, question_runs: List[str]) -> Iterable[Dict]:
|
101 |
+
"""
|
102 |
+
Process a tossup question and decide when to buzz based on confidence.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
question_runs: Progressive reveals of the question text
|
106 |
+
|
107 |
+
Yields:
|
108 |
+
Dict with answer, confidence, and whether to buzz
|
109 |
+
"""
|
110 |
+
|
111 |
+
for i, question_text in enumerate(question_runs):
|
112 |
+
prompt = f"Question: {question_text}\n\nProvide your answer and confidence level:"
|
113 |
+
|
114 |
+
response, response_time = self._get_agent_response(prompt, DEFAULT_SYS_PROMPT)
|
115 |
+
content = response.choices[0].message.content
|
116 |
+
|
117 |
+
answer, confidence = self._extract_confidence_and_answer(content)
|
118 |
+
|
119 |
+
result = {
|
120 |
+
"answer": answer,
|
121 |
+
"confidence": confidence,
|
122 |
+
"buzz": confidence >= self.buzz_threshold,
|
123 |
+
"question_fragment": question_text,
|
124 |
+
"position": i + 1,
|
125 |
+
"full_response": content,
|
126 |
+
"response_time": response_time,
|
127 |
+
}
|
128 |
+
|
129 |
+
yield result
|
130 |
+
|
131 |
+
# If we've reached the confidence threshold, buzz and stop
|
132 |
+
if confidence >= self.buzz_threshold:
|
133 |
+
return
|
134 |
+
|
135 |
+
def tossup_mode_top5(self, question_runs: List[str]) -> Iterable[Dict]:
|
136 |
+
"""
|
137 |
+
Process a tossup question and provide the top 5 guesses with confidence levels.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
question_runs: Progressive reveals of the question text
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
Dict with top 5 answers, their confidences, and whether to buzz
|
144 |
+
"""
|
145 |
+
|
146 |
+
for i, question_text in enumerate(question_runs):
|
147 |
+
prompt = f"Question: {question_text}\n\nProvide your top 5 answers and confidence levels."
|
148 |
+
|
149 |
+
response, response_time = self._get_agent_response(prompt, self.system_prompt)
|
150 |
+
content = response.choices[0].message.content
|
151 |
+
|
152 |
+
try:
|
153 |
+
# Try to parse JSON from the response
|
154 |
+
data = json.loads(content)
|
155 |
+
guesses = data.get("guesses", [])
|
156 |
+
except (json.JSONDecodeError, ValueError):
|
157 |
+
# Fallback if parsing fails
|
158 |
+
guesses = []
|
159 |
+
|
160 |
+
result = {
|
161 |
+
"guesses": guesses,
|
162 |
+
"buzz": any(guess["confidence"] >= self.buzz_threshold for guess in guesses),
|
163 |
+
"question_fragment": question_text,
|
164 |
+
"position": i + 1,
|
165 |
+
"full_response": content,
|
166 |
+
"response_time": response_time,
|
167 |
+
}
|
168 |
+
|
169 |
+
yield result
|
170 |
+
|
171 |
+
# If any guess reaches the confidence threshold, buzz and stop
|
172 |
+
if result["buzz"]:
|
173 |
+
return
|
174 |
+
|
175 |
+
def bonus_round_mode(self, question: str) -> Dict:
|
176 |
+
"""
|
177 |
+
Process a bonus round question with detailed analysis.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
question: The bonus question text
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
Dict with answer, rationale, and confidence
|
184 |
+
"""
|
185 |
+
system_prompt = """
|
186 |
+
You are a Quizbowl expert answering a bonus question. Provide:
|
187 |
+
1. Your direct answer
|
188 |
+
2. A very brief and crisp one line rationale for your answer (key clues that led to it)
|
189 |
+
3. Your confidence level (0.0-1.0)
|
190 |
+
|
191 |
+
Respond with a JSON object with these three fields:
|
192 |
+
{
|
193 |
+
"answer": "Your answer here",
|
194 |
+
"rationale": "Your reasoning here",
|
195 |
+
"confidence": 0.XX
|
196 |
+
}
|
197 |
+
"""
|
198 |
+
|
199 |
+
prompt = f"Bonus Question: {question}\n\nProvide your answer, rationale, and confidence:"
|
200 |
+
|
201 |
+
response = self._get_agent_response(prompt, system_prompt)
|
202 |
+
content = response.choices[0].message.content
|
203 |
+
|
204 |
+
try:
|
205 |
+
# Try to parse JSON
|
206 |
+
result = json.loads(content)
|
207 |
+
# Ensure all fields are present
|
208 |
+
if not all(k in result for k in ["answer", "rationale", "confidence"]):
|
209 |
+
raise ValueError("Missing fields in response")
|
210 |
+
except (json.JSONDecodeError, ValueError):
|
211 |
+
# If parsing fails, extract manually
|
212 |
+
lines = content.strip().split("\n")
|
213 |
+
result = {"answer": "", "rationale": "", "confidence": 0.5}
|
214 |
+
|
215 |
+
for line in lines:
|
216 |
+
if line.lower().startswith("answer:"):
|
217 |
+
result["answer"] = line[7:].strip()
|
218 |
+
elif line.lower().startswith("rationale:"):
|
219 |
+
result["rationale"] = line[10:].strip()
|
220 |
+
elif line.lower().startswith("confidence:"):
|
221 |
+
try:
|
222 |
+
result["confidence"] = float(line[11:].strip())
|
223 |
+
except ValueError:
|
224 |
+
pass
|
225 |
+
|
226 |
+
return result
|
227 |
+
|
228 |
+
|
229 |
+
# %%
|
230 |
+
# Example usage
|
231 |
+
if __name__ == "__main__":
|
232 |
+
# Load the Quizbowl dataset
|
233 |
+
ds_name = "umdclip/leaderboard_co_set"
|
234 |
+
ds = load_dataset(ds_name, split="train")
|
235 |
+
|
236 |
+
# Create the agent
|
237 |
+
agent = QuizbowlAgent(model="gpt-4-turbo", buzz_threshold=0.85)
|
238 |
+
|
239 |
+
# Example for tossup mode
|
240 |
+
print("\n=== TOSSUP MODE EXAMPLE ===")
|
241 |
+
sample_question = ds[0]
|
242 |
+
print(sample_question["question_runs"][-1])
|
243 |
+
print(sample_question["gold_label"])
|
244 |
+
question_runs = sample_question["question_runs"]
|
245 |
+
|
246 |
+
results = agent.tossup_mode(question_runs)
|
247 |
+
for result in results:
|
248 |
+
print(f"Guess at position {result['position']}: {result['answer']}")
|
249 |
+
print(f"Confidence: {result['confidence']}")
|
250 |
+
if result["buzz"]:
|
251 |
+
print("Buzzed!\n")
|
252 |
+
|
253 |
+
results = agent.tossup_mode_top5(question_runs)
|
254 |
+
for result in results:
|
255 |
+
guesses = [f"{guess['answer']} ({guess['confidence']})" for guess in result["guesses"]]
|
256 |
+
print(f"Guesses at position {result['position']}: {', '.join(guesses)}")
|
257 |
+
if result["buzz"]:
|
258 |
+
print("Buzzed!")
|
259 |
+
|
260 |
+
# Example for bonus round mode
|
261 |
+
print("\n=== BONUS ROUND MODE EXAMPLE ===")
|
262 |
+
bonus_question = sample_question["question_runs"][-1]
|
263 |
+
|
264 |
+
bonus_result = agent.bonus_round_mode(bonus_question)
|
265 |
+
print(f"Answer: {bonus_result['answer']}")
|
266 |
+
print(f"Rationale: {bonus_result['rationale']}")
|
267 |
+
print(f"Confidence: {bonus_result['confidence']}")
|
268 |
+
|
269 |
+
# %%
|
src/workflows/structs.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %%
|
2 |
+
from typing import Any, Literal, Optional
|
3 |
+
|
4 |
+
from pydantic import BaseModel, Field, model_validator
|
5 |
+
|
6 |
+
"""
|
7 |
+
Core data structures for defining workflows and their components.
|
8 |
+
|
9 |
+
This module defines the primary classes used to model workflows, steps, and their
|
10 |
+
input/output fields. These data structures serve as the foundation for workflow
|
11 |
+
definition, validation, and execution throughout the workflows package.
|
12 |
+
|
13 |
+
The primary components are:
|
14 |
+
- InputField: Represents an input to a model step with name and source variable
|
15 |
+
- OutputField: Represents an output from a model step with name and type
|
16 |
+
- ModelStep: Represents a single step in a workflow with inputs and outputs
|
17 |
+
- Workflow: A collection of interconnected steps with defined inputs and outputs
|
18 |
+
|
19 |
+
All classes use Pydantic's BaseModel for validation and serialization support.
|
20 |
+
"""
|
21 |
+
FieldType = Literal["input", "output"]
|
22 |
+
|
23 |
+
SUPPORTED_TYPES = Literal["str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"]
|
24 |
+
"""Supported field types for input and output fields"""
|
25 |
+
|
26 |
+
|
27 |
+
class InputField(BaseModel):
|
28 |
+
"""
|
29 |
+
Defines an input field for a model step.
|
30 |
+
|
31 |
+
An input field specifies what data a step requires, where it comes from,
|
32 |
+
and optional pre-processing to apply before use.
|
33 |
+
|
34 |
+
Attributes:
|
35 |
+
name: The name of the input field within the step's context
|
36 |
+
description: Human-readable description of the input's purpose
|
37 |
+
variable: Reference to the source variable (format: "{step_id}.{field_name}" or external input name)
|
38 |
+
func: Optional function name to transform the input value before use
|
39 |
+
"""
|
40 |
+
|
41 |
+
name: str
|
42 |
+
description: str
|
43 |
+
variable: str
|
44 |
+
|
45 |
+
# function to call on the input before passing it to the model
|
46 |
+
func: str | None = None
|
47 |
+
|
48 |
+
|
49 |
+
class OutputField(BaseModel):
|
50 |
+
"""
|
51 |
+
Defines an output field produced by a model step.
|
52 |
+
|
53 |
+
An output field specifies a value that the step will produce, including
|
54 |
+
its data type and optional post-processing.
|
55 |
+
|
56 |
+
Attributes:
|
57 |
+
name: The name of the output field within the step's context
|
58 |
+
description: Human-readable description of the output's purpose
|
59 |
+
type: The data type of the output (one of SUPPORTED_TYPES)
|
60 |
+
func: Optional function name to transform the raw output value
|
61 |
+
"""
|
62 |
+
|
63 |
+
name: str
|
64 |
+
type: SUPPORTED_TYPES = Field(default="str")
|
65 |
+
description: str
|
66 |
+
|
67 |
+
# function to call on the output string from the model
|
68 |
+
func: str | None = None
|
69 |
+
|
70 |
+
|
71 |
+
class ModelStep(BaseModel):
|
72 |
+
"""
|
73 |
+
Represents a single step in a workflow.
|
74 |
+
|
75 |
+
A model step encapsulates the details of a specific operation within a workflow,
|
76 |
+
including what model to use, what inputs it requires, and what outputs it produces.
|
77 |
+
|
78 |
+
Attributes:
|
79 |
+
id: Unique identifier for this step within a workflow
|
80 |
+
model: The model to use for this step (e.g., "gpt-4")
|
81 |
+
provider: The provider of the model (e.g., "openai")
|
82 |
+
call_type: The type of operation (e.g., "llm", "search")
|
83 |
+
system_prompt: Instructions for the model
|
84 |
+
input_fields: List of input fields required by this step
|
85 |
+
output_fields: List of output fields produced by this step
|
86 |
+
"""
|
87 |
+
|
88 |
+
id: str
|
89 |
+
name: str
|
90 |
+
model: str
|
91 |
+
provider: str
|
92 |
+
call_type: str = "llm" # llm, search, etc # TODO: make this enum or provide explicit options using Literal
|
93 |
+
|
94 |
+
# TODO: Validate that this is not None for call_type = llm
|
95 |
+
temperature: Optional[float] = None
|
96 |
+
|
97 |
+
system_prompt: str
|
98 |
+
input_fields: list[InputField]
|
99 |
+
output_fields: list[OutputField]
|
100 |
+
|
101 |
+
def fields(self, field_type: FieldType) -> list[InputField | OutputField]:
|
102 |
+
return self.input_fields if field_type == "input" else self.output_fields
|
103 |
+
|
104 |
+
def get_full_model_name(self):
|
105 |
+
return f"{self.provider} {self.model}"
|
106 |
+
|
107 |
+
def get_produced_variables(self) -> list[str]:
|
108 |
+
return [f"{self.id}.{field.name}" for field in self.output_fields if field.name]
|
109 |
+
|
110 |
+
def update(self, update: dict[str, Any]) -> "ModelStep":
|
111 |
+
return self.model_copy(update=update)
|
112 |
+
|
113 |
+
def update_property(self, field: str, value: Any) -> "ModelStep":
|
114 |
+
"Update the `field` key of the model step with `value`."
|
115 |
+
return self.update({field: value})
|
116 |
+
|
117 |
+
def update_field(self, field_type: FieldType, index: int, key: str, value: str) -> "ModelStep":
|
118 |
+
"""Update a specific field of an input or output field at the given index."""
|
119 |
+
if field_type == "input":
|
120 |
+
fields = self.input_fields
|
121 |
+
elif field_type == "output":
|
122 |
+
fields = self.output_fields
|
123 |
+
else:
|
124 |
+
raise ValueError(f"Invalid field type: {field_type}")
|
125 |
+
|
126 |
+
if index < len(fields):
|
127 |
+
fields[index] = fields[index].model_copy(update={key: value})
|
128 |
+
return self.model_copy()
|
129 |
+
|
130 |
+
@staticmethod
|
131 |
+
def create_new_field(field_type: FieldType, input_var: str | None = None) -> InputField | OutputField:
|
132 |
+
if field_type == "input":
|
133 |
+
return InputField(name="", description="", variable=input_var)
|
134 |
+
elif field_type == "output":
|
135 |
+
return OutputField(name="", description="")
|
136 |
+
else:
|
137 |
+
raise ValueError(f"Invalid field type: {field_type}")
|
138 |
+
|
139 |
+
def add_field(self, field_type: FieldType, index: int = -1, input_var: str | None = None) -> "ModelStep":
|
140 |
+
"""Add a new field to the state and update visibility.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
field_type: Type of field to add ('input' or 'output').
|
144 |
+
index: Position to insert the new field (-1 to append).
|
145 |
+
Returns:
|
146 |
+
A new ModelStep with the updated fields.
|
147 |
+
"""
|
148 |
+
new_step = self.model_copy()
|
149 |
+
fields = new_step.input_fields if field_type == "input" else new_step.output_fields
|
150 |
+
new_field = ModelStep.create_new_field(field_type, input_var)
|
151 |
+
fields.insert(index + 1, new_field) if index != -1 else fields.append(new_field)
|
152 |
+
return new_step
|
153 |
+
|
154 |
+
def delete_field(self, field_type: FieldType, index: int) -> "ModelStep":
|
155 |
+
"""
|
156 |
+
Delete an input or output field from the state and update visibility.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
field_type: Type of field to delete ('input' or 'output').
|
160 |
+
index: Index of the field to delete. [-1 to delete the last field]
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
A new ModelStep with the updated fields.
|
164 |
+
"""
|
165 |
+
new_step = self.model_copy()
|
166 |
+
fields = new_step.input_fields if field_type == "input" else new_step.output_fields
|
167 |
+
fields.pop(index)
|
168 |
+
return new_step
|
169 |
+
|
170 |
+
|
171 |
+
class Workflow(BaseModel):
|
172 |
+
"""
|
173 |
+
Represents a complete workflow composed of interconnected steps.
|
174 |
+
|
175 |
+
A workflow defines a directed acyclic graph of model steps, where outputs
|
176 |
+
from earlier steps can be used as inputs to later steps.
|
177 |
+
|
178 |
+
Attributes:
|
179 |
+
inputs: List of input variables required by the workflow
|
180 |
+
outputs: List of output variables produced by the workflow
|
181 |
+
steps: Dictionary mapping step IDs to ModelStep instances
|
182 |
+
|
183 |
+
The inputs and outputs lists use the format "{step_id}.{field_name}"
|
184 |
+
to uniquely identify variables within the workflow.
|
185 |
+
"""
|
186 |
+
|
187 |
+
# variables of form {node}.{field}
|
188 |
+
inputs: list[str] = Field(default_factory=list)
|
189 |
+
|
190 |
+
# variables of form {node}.{field}
|
191 |
+
outputs: dict[str, str | None] = Field(default_factory=dict)
|
192 |
+
steps: dict[str, ModelStep] = Field(default_factory=dict)
|
193 |
+
|
194 |
+
def model_dump(self, *args, **kwargs):
|
195 |
+
data = super().model_dump(*args, **kwargs)
|
196 |
+
data["steps"] = list(data["steps"].values())
|
197 |
+
return data
|
198 |
+
|
199 |
+
@model_validator(mode="before")
|
200 |
+
def dictify_steps(cls, data):
|
201 |
+
if "steps" in data and isinstance(data["steps"], list):
|
202 |
+
steps_dict = {}
|
203 |
+
for step in data["steps"]:
|
204 |
+
if step["id"] in steps_dict:
|
205 |
+
raise ValueError(f"Duplicate step ID: {step['id']}")
|
206 |
+
steps_dict[step["id"]] = step
|
207 |
+
data["steps"] = steps_dict
|
208 |
+
return data
|
209 |
+
|
210 |
+
def get_step_variables(self, step_id: str) -> list[str]:
|
211 |
+
"""Get all variables from a specific step."""
|
212 |
+
step = self.steps[step_id]
|
213 |
+
variables = []
|
214 |
+
for output in step.output_fields:
|
215 |
+
if output.name == "":
|
216 |
+
continue
|
217 |
+
output_var = f"{step.id}.{output.name}"
|
218 |
+
variables.append(output_var)
|
219 |
+
return variables
|
220 |
+
|
221 |
+
def get_available_variables(self) -> list[str]:
|
222 |
+
"""Get all output variables from all steps."""
|
223 |
+
variables = set(self.inputs)
|
224 |
+
for step in self.steps.values():
|
225 |
+
variables.update(self.get_step_variables(step.id))
|
226 |
+
return list(variables)
|
227 |
+
|
228 |
+
|
229 |
+
# %%
|
src/workflows/utils.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import deque
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
from workflows.errors import CyclicDependencyError, UnknownVariableError, WorkflowError
|
5 |
+
from workflows.structs import ModelStep, Workflow
|
6 |
+
|
7 |
+
"""
|
8 |
+
Utilities for workflow dependency management and execution order determination.
|
9 |
+
|
10 |
+
This module provides functions for analyzing workflows, determining dependencies between steps,
|
11 |
+
and calculating the correct execution order to ensure all dependencies are satisfied.
|
12 |
+
Key functionality includes:
|
13 |
+
|
14 |
+
- Variable to step mapping: Identifying which step produces each variable
|
15 |
+
- Dependency graph creation: Building a graph representing dependencies between steps
|
16 |
+
- Topological sorting: Determining a valid execution order based on dependencies
|
17 |
+
- Cycle detection: Identifying cyclic dependencies that would prevent execution
|
18 |
+
|
19 |
+
These utilities form the foundation for workflow validation and execution in the
|
20 |
+
workflow_executor module.
|
21 |
+
"""
|
22 |
+
|
23 |
+
|
24 |
+
def _create_variable_step_mapping(workflow: Workflow) -> dict[str, str]:
|
25 |
+
"""
|
26 |
+
Creates a mapping from produced variable names to the model step that produces them.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
workflow (Workflow): The workflow containing steps and their input/output fields.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
dict[str, str]: A dictionary where keys are variable names (formatted as "{step_id}.{output name}")
|
33 |
+
and values are the step IDs that produce them.
|
34 |
+
|
35 |
+
Raises:
|
36 |
+
WorkflowError: If there are duplicate step IDs or if a variable is produced by multiple steps.
|
37 |
+
|
38 |
+
Example:
|
39 |
+
For a workflow with steps "extract" and "summarize" each producing outputs:
|
40 |
+
>>> _create_variable_step_mapping(workflow)
|
41 |
+
{'extract.keywords': 'extract', 'summarize.summary': 'summarize'}
|
42 |
+
"""
|
43 |
+
variable_step_map: dict[str, str] = {} # variable name -> step id
|
44 |
+
for step_id, step in workflow.steps.items():
|
45 |
+
for output in step.output_fields:
|
46 |
+
var_name = f"{step_id}.{output.name}"
|
47 |
+
if var_name in variable_step_map:
|
48 |
+
raise WorkflowError(f"Variable '{output.name}' has duplicate entry in step {step_id}")
|
49 |
+
variable_step_map[var_name] = step_id
|
50 |
+
return variable_step_map
|
51 |
+
|
52 |
+
|
53 |
+
def create_dependency_graph(workflow: Workflow, input_values: dict[str, Any]) -> dict[str, set[str]]:
|
54 |
+
"""
|
55 |
+
Creates a dependency graph from a workflow.
|
56 |
+
|
57 |
+
This function analyzes the workflow and determines which steps depend on others
|
58 |
+
based on their input/output relationships. A step depends on another if it requires
|
59 |
+
a variable that is produced by the other step. External inputs provided through
|
60 |
+
input_values don't create dependencies.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
workflow (Workflow): The workflow containing steps and their input/output fields.
|
64 |
+
input_values (dict[str, Any]): A dictionary of external input values provided to the workflow.
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
dict[str, set[str]]: A dictionary where keys are step IDs and values are sets of step IDs
|
68 |
+
that the key step depends on.
|
69 |
+
|
70 |
+
Raises:
|
71 |
+
UnknownVariableError: If an input field references a variable that is not provided
|
72 |
+
externally nor produced by any step.
|
73 |
+
|
74 |
+
Example:
|
75 |
+
For a workflow where step "classify" depends on output from "extract":
|
76 |
+
>>> create_dependency_graph(workflow, {})
|
77 |
+
{'extract': set(), 'classify': {'extract'}}
|
78 |
+
|
79 |
+
With external input provided for "text" variable:
|
80 |
+
>>> create_dependency_graph(workflow, {'text': 'Sample text'})
|
81 |
+
{'extract': set(), 'classify': {'extract'}}
|
82 |
+
"""
|
83 |
+
produced_by = _create_variable_step_mapping(workflow)
|
84 |
+
dependencies: dict[str, set[str]] = {step_id: set() for step_id in workflow.steps}
|
85 |
+
for step_id, step in workflow.steps.items():
|
86 |
+
for input_field in step.input_fields:
|
87 |
+
var = input_field.variable
|
88 |
+
# If the variable was provided externally, then no dependency is needed.
|
89 |
+
if var in input_values:
|
90 |
+
continue
|
91 |
+
# Otherwise, check if the variable is produced by a step.
|
92 |
+
if var in produced_by:
|
93 |
+
producer_step_id = produced_by[var]
|
94 |
+
if producer_step_id != step_id: # Avoid self-dependency
|
95 |
+
dependencies[step_id].add(producer_step_id)
|
96 |
+
else:
|
97 |
+
raise UnknownVariableError(f"Variable '{var}' is not provided externally nor produced by any step")
|
98 |
+
return dependencies
|
99 |
+
|
100 |
+
|
101 |
+
def topological_sort(dependencies: dict[str, set[str]]) -> list[str]:
|
102 |
+
"""
|
103 |
+
Performs a topological sort on a dependency graph and detects cycles using Kahn's algorithm.
|
104 |
+
|
105 |
+
A topological sort orders the steps such that for every dependency from step A to step B,
|
106 |
+
step A comes before step B in the ordering. This ensures that all dependencies are satisfied
|
107 |
+
when executing steps in the returned order.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
dependencies (dict[str, set[str]]): A dictionary where each key is a node identifier and
|
111 |
+
each value is a set of nodes that the key node depends on.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
list[str]: A list representing the nodes in topological order if no cycle is detected.
|
115 |
+
|
116 |
+
Raises:
|
117 |
+
CyclicDependencyError: If a cycle is detected in the graph.
|
118 |
+
|
119 |
+
Example:
|
120 |
+
>>> topological_sort({'A': set(), 'B': {'A'}, 'C': {'B'}})
|
121 |
+
['A', 'B', 'C']
|
122 |
+
|
123 |
+
>>> topological_sort({'A': {'B'}, 'B': {'A'}}) # Cyclic dependency
|
124 |
+
CyclicDependencyError
|
125 |
+
|
126 |
+
Algorithm:
|
127 |
+
This implementation uses Kahn's algorithm:
|
128 |
+
1. Calculate in-degree for all nodes (number of dependencies)
|
129 |
+
2. Start with nodes having 0 in-degree (no dependencies)
|
130 |
+
3. Process each node by removing its outgoing edges
|
131 |
+
4. Add newly dependency-free nodes to the processing queue
|
132 |
+
5. If not all nodes are processed, a cycle exists
|
133 |
+
"""
|
134 |
+
|
135 |
+
nodes = list(dependencies.keys())
|
136 |
+
dependents: dict[str, list[str]] = {node: [] for node in nodes}
|
137 |
+
in_degree: dict[str, int] = {node: 0 for node in nodes}
|
138 |
+
|
139 |
+
# Calculate in-degrees and build dependents list
|
140 |
+
for node, deps in dependencies.items():
|
141 |
+
in_degree[node] = len(deps)
|
142 |
+
for dep in deps:
|
143 |
+
dependents[dep].append(node)
|
144 |
+
|
145 |
+
# Initialize queue with nodes having zero in-degree
|
146 |
+
queue = deque([node for node, deg in in_degree.items() if deg == 0])
|
147 |
+
execution_order: list[str] = []
|
148 |
+
|
149 |
+
# Process nodes in topological order
|
150 |
+
while queue:
|
151 |
+
current = queue.popleft()
|
152 |
+
execution_order.append(current)
|
153 |
+
for dep in dependents[current]:
|
154 |
+
in_degree[dep] -= 1
|
155 |
+
if in_degree[dep] == 0:
|
156 |
+
queue.append(dep)
|
157 |
+
|
158 |
+
# If execution order includes all nodes, no cycle exists
|
159 |
+
if len(execution_order) != len(nodes):
|
160 |
+
raise CyclicDependencyError()
|
161 |
+
return execution_order
|
src/workflows/validators.py
ADDED
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import keyword
|
2 |
+
import re
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from enum import Enum
|
5 |
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
6 |
+
|
7 |
+
from .structs import InputField, ModelStep, OutputField, Workflow
|
8 |
+
|
9 |
+
SUPPORTED_TYPES = {"str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"}
|
10 |
+
|
11 |
+
# Constants for validation
|
12 |
+
MAX_FIELD_NAME_LENGTH = 50
|
13 |
+
MAX_DESCRIPTION_LENGTH = 200
|
14 |
+
MAX_SYSTEM_PROMPT_LENGTH = 4000
|
15 |
+
MIN_TEMPERATURE = 0.0
|
16 |
+
MAX_TEMPERATURE = 1.0
|
17 |
+
|
18 |
+
|
19 |
+
class ValidationErrorType(Enum):
|
20 |
+
"""Types of validation errors that can occur"""
|
21 |
+
|
22 |
+
STEP = "step"
|
23 |
+
DAG = "dag"
|
24 |
+
VARIABLE = "variable"
|
25 |
+
TYPE = "type"
|
26 |
+
GENERAL = "general"
|
27 |
+
NAMING = "naming"
|
28 |
+
LENGTH = "length"
|
29 |
+
RANGE = "range"
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class ValidationError:
|
34 |
+
"""Represents a validation error with type and message"""
|
35 |
+
|
36 |
+
error_type: ValidationErrorType
|
37 |
+
message: str
|
38 |
+
step_id: Optional[str] = None
|
39 |
+
field_name: Optional[str] = None
|
40 |
+
|
41 |
+
|
42 |
+
class WorkflowValidationError(Exception):
|
43 |
+
"""Base class for workflow validation errors"""
|
44 |
+
|
45 |
+
def __init__(self, errors: List[ValidationError]):
|
46 |
+
self.errors = errors
|
47 |
+
super().__init__(f"Workflow validation failed with {len(errors)} errors")
|
48 |
+
|
49 |
+
|
50 |
+
class WorkflowValidator:
|
51 |
+
"""Validates workflows for correctness and consistency"""
|
52 |
+
|
53 |
+
def __init__(self):
|
54 |
+
self.errors: List[ValidationError] = []
|
55 |
+
self.workflow: Optional[Workflow] = None
|
56 |
+
|
57 |
+
def validate(self, workflow: Workflow) -> bool:
|
58 |
+
"""Main validation entry point"""
|
59 |
+
self.errors = []
|
60 |
+
self.workflow = workflow
|
61 |
+
|
62 |
+
# Basic workflow validation
|
63 |
+
if not self._validate_workflow_basic(workflow):
|
64 |
+
return False
|
65 |
+
|
66 |
+
# If it's a single-step workflow, use simple validation
|
67 |
+
if len(workflow.steps) == 1:
|
68 |
+
return self.validate_simple_workflow(workflow)
|
69 |
+
|
70 |
+
# Otherwise use complex validation
|
71 |
+
return self.validate_complex_workflow(workflow)
|
72 |
+
|
73 |
+
def validate_simple_workflow(self, workflow: Workflow) -> bool:
|
74 |
+
"""Validates a single-step workflow"""
|
75 |
+
if not self.workflow:
|
76 |
+
return False
|
77 |
+
|
78 |
+
# Get the single step
|
79 |
+
step = next(iter(workflow.steps.values()))
|
80 |
+
|
81 |
+
# Validate the step itself
|
82 |
+
if not self._validate_step(step):
|
83 |
+
return False
|
84 |
+
|
85 |
+
# Validate input variables
|
86 |
+
for input_var in workflow.inputs:
|
87 |
+
if not self._is_valid_external_input(input_var):
|
88 |
+
self.errors.append(
|
89 |
+
ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
|
90 |
+
)
|
91 |
+
return False
|
92 |
+
|
93 |
+
# Validate output variables
|
94 |
+
for output_name, output_var in workflow.outputs.items():
|
95 |
+
if not output_var:
|
96 |
+
self.errors.append(
|
97 |
+
ValidationError(ValidationErrorType.VARIABLE, f"Missing output variable for {output_name}")
|
98 |
+
)
|
99 |
+
return False
|
100 |
+
|
101 |
+
# Check if output variable references a valid step output
|
102 |
+
if not self._is_valid_variable_reference(output_var):
|
103 |
+
self.errors.append(
|
104 |
+
ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
|
105 |
+
)
|
106 |
+
return False
|
107 |
+
|
108 |
+
# Verify the output field exists in the step
|
109 |
+
_, field_name = self._parse_variable_reference(output_var)
|
110 |
+
if not any(field.name == field_name for field in step.output_fields):
|
111 |
+
self.errors.append(
|
112 |
+
ValidationError(
|
113 |
+
ValidationErrorType.VARIABLE,
|
114 |
+
f"Output field '{field_name}' not found in step '{step.id}'",
|
115 |
+
step.id,
|
116 |
+
field_name,
|
117 |
+
)
|
118 |
+
)
|
119 |
+
return False
|
120 |
+
|
121 |
+
return True
|
122 |
+
|
123 |
+
def validate_complex_workflow(self, workflow: Workflow) -> bool:
|
124 |
+
"""Validates a multi-step workflow"""
|
125 |
+
if not self.workflow:
|
126 |
+
return False
|
127 |
+
|
128 |
+
# Validate each step
|
129 |
+
for step in workflow.steps.values():
|
130 |
+
if not self._validate_step(step):
|
131 |
+
return False
|
132 |
+
|
133 |
+
# Validate input variables
|
134 |
+
for input_var in workflow.inputs:
|
135 |
+
if not self._is_valid_external_input(input_var):
|
136 |
+
self.errors.append(
|
137 |
+
ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
|
138 |
+
)
|
139 |
+
return False
|
140 |
+
|
141 |
+
# Validate output variables
|
142 |
+
for output_name, output_var in workflow.outputs.items():
|
143 |
+
if not output_var:
|
144 |
+
self.errors.append(
|
145 |
+
ValidationError(ValidationErrorType.VARIABLE, f"Missing output variable for {output_name}")
|
146 |
+
)
|
147 |
+
return False
|
148 |
+
|
149 |
+
if not self._is_valid_variable_reference(output_var):
|
150 |
+
self.errors.append(
|
151 |
+
ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
|
152 |
+
)
|
153 |
+
return False
|
154 |
+
|
155 |
+
# Verify the output field exists in the referenced step
|
156 |
+
step_id, field_name = self._parse_variable_reference(output_var)
|
157 |
+
if step_id not in workflow.steps:
|
158 |
+
self.errors.append(
|
159 |
+
ValidationError(ValidationErrorType.VARIABLE, f"Referenced step '{step_id}' not found")
|
160 |
+
)
|
161 |
+
return False
|
162 |
+
|
163 |
+
ref_step = workflow.steps[step_id]
|
164 |
+
if not any(field.name == field_name for field in ref_step.output_fields):
|
165 |
+
self.errors.append(
|
166 |
+
ValidationError(
|
167 |
+
ValidationErrorType.VARIABLE,
|
168 |
+
f"Output field '{field_name}' not found in step '{step_id}'",
|
169 |
+
step_id,
|
170 |
+
field_name,
|
171 |
+
)
|
172 |
+
)
|
173 |
+
return False
|
174 |
+
|
175 |
+
# Build dependency graph
|
176 |
+
dep_graph: Dict[str, Set[str]] = {}
|
177 |
+
for step_id, step in workflow.steps.items():
|
178 |
+
dep_graph[step_id] = self._get_step_dependencies(step)
|
179 |
+
|
180 |
+
# Check for cycles in step dependencies
|
181 |
+
visited = set()
|
182 |
+
path = set()
|
183 |
+
|
184 |
+
def has_cycle(node: str) -> bool:
|
185 |
+
if node in path:
|
186 |
+
return True
|
187 |
+
if node in visited:
|
188 |
+
return False
|
189 |
+
|
190 |
+
visited.add(node)
|
191 |
+
path.add(node)
|
192 |
+
|
193 |
+
for neighbor in dep_graph.get(node, set()):
|
194 |
+
if has_cycle(neighbor):
|
195 |
+
return True
|
196 |
+
|
197 |
+
path.remove(node)
|
198 |
+
return False
|
199 |
+
|
200 |
+
# Check each step for cycles
|
201 |
+
for step_id in workflow.steps:
|
202 |
+
if has_cycle(step_id):
|
203 |
+
self.errors.append(
|
204 |
+
ValidationError(ValidationErrorType.DAG, f"Circular dependency detected involving step: {step_id}")
|
205 |
+
)
|
206 |
+
return False
|
207 |
+
|
208 |
+
# Check for orphaned steps (steps that aren't used by any other step)
|
209 |
+
used_steps = set()
|
210 |
+
for deps in dep_graph.values():
|
211 |
+
used_steps.update(deps)
|
212 |
+
print("Used steps: ", used_steps)
|
213 |
+
for step_id in workflow.steps:
|
214 |
+
if step_id not in used_steps and not any(
|
215 |
+
output_var and self._parse_variable_reference(output_var)[0] == step_id
|
216 |
+
for output_var in workflow.outputs.values()
|
217 |
+
):
|
218 |
+
self.errors.append(ValidationError(ValidationErrorType.DAG, f"Orphaned step detected: {step_id}"))
|
219 |
+
return False
|
220 |
+
|
221 |
+
# Validate variable dependencies
|
222 |
+
if not self._validate_variable_dependencies(workflow):
|
223 |
+
return False
|
224 |
+
|
225 |
+
return True
|
226 |
+
|
227 |
+
def _validate_workflow_basic(self, workflow: Workflow) -> bool:
|
228 |
+
"""Validates basic workflow properties"""
|
229 |
+
# Check for atleast one input
|
230 |
+
if not workflow.inputs:
|
231 |
+
self.errors.append(
|
232 |
+
ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one input")
|
233 |
+
)
|
234 |
+
return False
|
235 |
+
|
236 |
+
if not workflow.outputs:
|
237 |
+
self.errors.append(
|
238 |
+
ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one output")
|
239 |
+
)
|
240 |
+
return False
|
241 |
+
|
242 |
+
for output_var in workflow.outputs.values():
|
243 |
+
if output_var is None:
|
244 |
+
self.errors.append(ValidationError(ValidationErrorType.GENERAL, "Output variable cannot be None"))
|
245 |
+
return False
|
246 |
+
|
247 |
+
# Check for empty workflow
|
248 |
+
if not workflow.steps:
|
249 |
+
self.errors.append(ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one step"))
|
250 |
+
return False
|
251 |
+
|
252 |
+
# Check for step ID consistency
|
253 |
+
for step_id, step in workflow.steps.items():
|
254 |
+
if step_id != step.id:
|
255 |
+
self.errors.append(
|
256 |
+
ValidationError(ValidationErrorType.STEP, f"Step ID mismatch: {step_id} != {step.id}", step_id)
|
257 |
+
)
|
258 |
+
return False
|
259 |
+
return True
|
260 |
+
|
261 |
+
def _validate_step(self, step: ModelStep) -> bool:
|
262 |
+
"""Validates a single step"""
|
263 |
+
# Validate required fields
|
264 |
+
if not step.id or not step.name or not step.model or not step.provider or not step.call_type:
|
265 |
+
self.errors.append(ValidationError(ValidationErrorType.STEP, "Step missing required fields", step.id))
|
266 |
+
return False
|
267 |
+
|
268 |
+
# Validate step ID and name
|
269 |
+
if not self._is_valid_identifier(step.id):
|
270 |
+
self.errors.append(
|
271 |
+
ValidationError(
|
272 |
+
ValidationErrorType.NAMING,
|
273 |
+
f"Invalid step ID format: {step.id}. Must be a valid Python identifier.",
|
274 |
+
step.id,
|
275 |
+
)
|
276 |
+
)
|
277 |
+
return False
|
278 |
+
|
279 |
+
# Validate temperature for LLM call type
|
280 |
+
if step.call_type == "llm":
|
281 |
+
if step.temperature is None:
|
282 |
+
self.errors.append(
|
283 |
+
ValidationError(ValidationErrorType.STEP, "LLM step must specify temperature", step.id)
|
284 |
+
)
|
285 |
+
return False
|
286 |
+
|
287 |
+
if not MIN_TEMPERATURE <= step.temperature <= MAX_TEMPERATURE:
|
288 |
+
self.errors.append(
|
289 |
+
ValidationError(
|
290 |
+
ValidationErrorType.RANGE,
|
291 |
+
f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}",
|
292 |
+
step.id,
|
293 |
+
)
|
294 |
+
)
|
295 |
+
return False
|
296 |
+
|
297 |
+
# Validate system prompt for LLM call type
|
298 |
+
if step.call_type == "llm":
|
299 |
+
if not step.system_prompt:
|
300 |
+
self.errors.append(
|
301 |
+
ValidationError(ValidationErrorType.STEP, "LLM step must specify system prompt", step.id)
|
302 |
+
)
|
303 |
+
return False
|
304 |
+
|
305 |
+
if len(step.system_prompt) > MAX_SYSTEM_PROMPT_LENGTH:
|
306 |
+
self.errors.append(
|
307 |
+
ValidationError(
|
308 |
+
ValidationErrorType.LENGTH,
|
309 |
+
f"System prompt exceeds maximum length of {MAX_SYSTEM_PROMPT_LENGTH} characters",
|
310 |
+
step.id,
|
311 |
+
)
|
312 |
+
)
|
313 |
+
return False
|
314 |
+
|
315 |
+
# Validate input fields
|
316 |
+
input_names = set()
|
317 |
+
for field in step.input_fields:
|
318 |
+
if not self._validate_input_field(field):
|
319 |
+
return False
|
320 |
+
if field.name in input_names:
|
321 |
+
self.errors.append(
|
322 |
+
ValidationError(
|
323 |
+
ValidationErrorType.STEP, f"Duplicate input field name: {field.name}", step.id, field.name
|
324 |
+
)
|
325 |
+
)
|
326 |
+
return False
|
327 |
+
input_names.add(field.name)
|
328 |
+
|
329 |
+
# Validate output fields
|
330 |
+
output_names = set()
|
331 |
+
for field in step.output_fields:
|
332 |
+
if not self._validate_output_field(field):
|
333 |
+
return False
|
334 |
+
if field.name in output_names:
|
335 |
+
self.errors.append(
|
336 |
+
ValidationError(
|
337 |
+
ValidationErrorType.STEP, f"Duplicate output field name: {field.name}", step.id, field.name
|
338 |
+
)
|
339 |
+
)
|
340 |
+
return False
|
341 |
+
output_names.add(field.name)
|
342 |
+
|
343 |
+
return True
|
344 |
+
|
345 |
+
def _validate_input_field(self, field: InputField) -> bool:
|
346 |
+
"""Validates an input field"""
|
347 |
+
# Validate required fields
|
348 |
+
if not field.name or not field.description or not field.variable:
|
349 |
+
self.errors.append(
|
350 |
+
ValidationError(ValidationErrorType.STEP, "Input field missing required fields", field_name=field.name)
|
351 |
+
)
|
352 |
+
return False
|
353 |
+
|
354 |
+
# Validate field name
|
355 |
+
if not self._is_valid_identifier(field.name):
|
356 |
+
self.errors.append(
|
357 |
+
ValidationError(
|
358 |
+
ValidationErrorType.NAMING,
|
359 |
+
f"Invalid field name format: {field.name}. Must be a valid Python identifier.",
|
360 |
+
field_name=field.name,
|
361 |
+
)
|
362 |
+
)
|
363 |
+
return False
|
364 |
+
|
365 |
+
# Validate field name length
|
366 |
+
if len(field.name) > MAX_FIELD_NAME_LENGTH:
|
367 |
+
self.errors.append(
|
368 |
+
ValidationError(
|
369 |
+
ValidationErrorType.LENGTH,
|
370 |
+
f"Field name exceeds maximum length of {MAX_FIELD_NAME_LENGTH} characters",
|
371 |
+
field_name=field.name,
|
372 |
+
)
|
373 |
+
)
|
374 |
+
return False
|
375 |
+
|
376 |
+
# Validate description length
|
377 |
+
if len(field.description) > MAX_DESCRIPTION_LENGTH:
|
378 |
+
self.errors.append(
|
379 |
+
ValidationError(
|
380 |
+
ValidationErrorType.LENGTH,
|
381 |
+
f"Description exceeds maximum length of {MAX_DESCRIPTION_LENGTH} characters",
|
382 |
+
field_name=field.name,
|
383 |
+
)
|
384 |
+
)
|
385 |
+
return False
|
386 |
+
|
387 |
+
# Validate variable reference
|
388 |
+
if not self._is_valid_variable_reference(field.variable):
|
389 |
+
self.errors.append(
|
390 |
+
ValidationError(
|
391 |
+
ValidationErrorType.VARIABLE,
|
392 |
+
f"Invalid variable reference: {field.variable}",
|
393 |
+
field_name=field.name,
|
394 |
+
)
|
395 |
+
)
|
396 |
+
return False
|
397 |
+
|
398 |
+
return True
|
399 |
+
|
400 |
+
def _validate_output_field(self, field: OutputField) -> bool:
|
401 |
+
"""Validates an output field"""
|
402 |
+
# Validate required fields
|
403 |
+
if not field.name or not field.description:
|
404 |
+
self.errors.append(
|
405 |
+
ValidationError(
|
406 |
+
ValidationErrorType.STEP, "Output field missing required fields", field_name=field.name
|
407 |
+
)
|
408 |
+
)
|
409 |
+
return False
|
410 |
+
|
411 |
+
# Validate field name
|
412 |
+
if not self._is_valid_identifier(field.name):
|
413 |
+
self.errors.append(
|
414 |
+
ValidationError(
|
415 |
+
ValidationErrorType.NAMING,
|
416 |
+
f"Invalid field name format: {field.name}. Must be a valid Python identifier.",
|
417 |
+
field_name=field.name,
|
418 |
+
)
|
419 |
+
)
|
420 |
+
return False
|
421 |
+
|
422 |
+
# Validate field name length
|
423 |
+
if len(field.name) > MAX_FIELD_NAME_LENGTH:
|
424 |
+
self.errors.append(
|
425 |
+
ValidationError(
|
426 |
+
ValidationErrorType.LENGTH,
|
427 |
+
f"Field name exceeds maximum length of {MAX_FIELD_NAME_LENGTH} characters",
|
428 |
+
field_name=field.name,
|
429 |
+
)
|
430 |
+
)
|
431 |
+
return False
|
432 |
+
|
433 |
+
# Validate description length
|
434 |
+
if len(field.description) > MAX_DESCRIPTION_LENGTH:
|
435 |
+
self.errors.append(
|
436 |
+
ValidationError(
|
437 |
+
ValidationErrorType.LENGTH,
|
438 |
+
f"Description exceeds maximum length of {MAX_DESCRIPTION_LENGTH} characters",
|
439 |
+
field_name=field.name,
|
440 |
+
)
|
441 |
+
)
|
442 |
+
return False
|
443 |
+
|
444 |
+
# Validate type
|
445 |
+
if field.type not in SUPPORTED_TYPES:
|
446 |
+
self.errors.append(
|
447 |
+
ValidationError(
|
448 |
+
ValidationErrorType.TYPE, f"Unsupported output type: {field.type}", field_name=field.name
|
449 |
+
)
|
450 |
+
)
|
451 |
+
return False
|
452 |
+
|
453 |
+
return True
|
454 |
+
|
455 |
+
def _validate_simple_workflow_variables(self, workflow: Workflow) -> bool:
|
456 |
+
"""Validates variables in a simple workflow"""
|
457 |
+
step = next(iter(workflow.steps.values()))
|
458 |
+
|
459 |
+
# Validate input variables
|
460 |
+
for input_var in workflow.inputs:
|
461 |
+
if not self._is_valid_external_input(input_var):
|
462 |
+
self.errors.append(
|
463 |
+
ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
|
464 |
+
)
|
465 |
+
return False
|
466 |
+
|
467 |
+
# Validate output variables
|
468 |
+
for output_name, output_var in workflow.outputs.items():
|
469 |
+
if output_var and not self._is_valid_variable_reference(output_var):
|
470 |
+
self.errors.append(
|
471 |
+
ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
|
472 |
+
)
|
473 |
+
return False
|
474 |
+
|
475 |
+
return True
|
476 |
+
|
477 |
+
def _validate_variable_dependencies(self, workflow: Workflow) -> bool:
|
478 |
+
"""Validates variable dependencies between steps"""
|
479 |
+
# Build variable dependency graph
|
480 |
+
var_graph: Dict[str, Set[str]] = {}
|
481 |
+
|
482 |
+
for step_id, step in workflow.steps.items():
|
483 |
+
for field in step.input_fields:
|
484 |
+
if field.variable not in var_graph:
|
485 |
+
var_graph[field.variable] = set()
|
486 |
+
|
487 |
+
# Add dependency from input variable to step's outputs
|
488 |
+
for output in step.output_fields:
|
489 |
+
var_graph[field.variable].add(f"{step_id}.{output.name}")
|
490 |
+
|
491 |
+
# Check for cycles in variable dependencies
|
492 |
+
visited = set()
|
493 |
+
path = set()
|
494 |
+
|
495 |
+
def has_cycle(node: str) -> bool:
|
496 |
+
if node in path:
|
497 |
+
return True
|
498 |
+
if node in visited:
|
499 |
+
return False
|
500 |
+
|
501 |
+
visited.add(node)
|
502 |
+
path.add(node)
|
503 |
+
|
504 |
+
for neighbor in var_graph.get(node, set()):
|
505 |
+
if has_cycle(neighbor):
|
506 |
+
return True
|
507 |
+
|
508 |
+
path.remove(node)
|
509 |
+
return False
|
510 |
+
|
511 |
+
# Check each variable for cycles
|
512 |
+
for var in var_graph:
|
513 |
+
if has_cycle(var):
|
514 |
+
self.errors.append(
|
515 |
+
ValidationError(ValidationErrorType.VARIABLE, f"Circular variable dependency detected: {var}")
|
516 |
+
)
|
517 |
+
return False
|
518 |
+
|
519 |
+
# Validate external input existence
|
520 |
+
external_inputs = set(workflow.inputs)
|
521 |
+
for step in workflow.steps.values():
|
522 |
+
for field in step.input_fields:
|
523 |
+
step_id, field_name = self._parse_variable_reference(field.variable)
|
524 |
+
if not step_id and field_name not in external_inputs:
|
525 |
+
self.errors.append(
|
526 |
+
ValidationError(
|
527 |
+
ValidationErrorType.VARIABLE,
|
528 |
+
f"External input '{field_name}' not found in workflow inputs",
|
529 |
+
field_name=field_name,
|
530 |
+
)
|
531 |
+
)
|
532 |
+
return False
|
533 |
+
|
534 |
+
return True
|
535 |
+
|
536 |
+
def _get_step_dependencies(self, step: ModelStep) -> Set[str]:
|
537 |
+
"""Gets set of step IDs that this step depends on"""
|
538 |
+
deps = set()
|
539 |
+
for field in step.input_fields:
|
540 |
+
step_id = self._parse_variable_reference(field.variable)[0]
|
541 |
+
if step_id:
|
542 |
+
deps.add(step_id)
|
543 |
+
return deps
|
544 |
+
|
545 |
+
def _parse_variable_reference(self, var: str) -> Tuple[Optional[str], str]:
|
546 |
+
"""Extracts step_id and field_name from variable reference"""
|
547 |
+
parts = var.split(".")
|
548 |
+
if len(parts) == 1:
|
549 |
+
return None, parts[0]
|
550 |
+
return parts[0], parts[1]
|
551 |
+
|
552 |
+
def _is_valid_variable_reference(self, var: str) -> bool:
|
553 |
+
"""Validates if a variable reference is properly formatted"""
|
554 |
+
if not self.workflow:
|
555 |
+
return False
|
556 |
+
parts = var.split(".")
|
557 |
+
if len(parts) == 1:
|
558 |
+
return True # External input
|
559 |
+
if len(parts) != 2:
|
560 |
+
return False
|
561 |
+
step_id, field_name = parts
|
562 |
+
return step_id in self.workflow.steps and any(
|
563 |
+
field.name == field_name for field in self.workflow.steps[step_id].output_fields
|
564 |
+
)
|
565 |
+
|
566 |
+
def _is_valid_external_input(self, var: str) -> bool:
|
567 |
+
"""Validates if a variable is a valid external input"""
|
568 |
+
if not var:
|
569 |
+
return False
|
570 |
+
if not self._is_valid_identifier(var):
|
571 |
+
return False
|
572 |
+
if keyword.iskeyword(var):
|
573 |
+
return False
|
574 |
+
if "." in var: # External inputs should not contain dots
|
575 |
+
return False
|
576 |
+
return True
|
577 |
+
|
578 |
+
def _is_valid_identifier(self, name: str) -> bool:
|
579 |
+
"""Validates if a string is a valid Python identifier"""
|
580 |
+
if not name:
|
581 |
+
return False
|
582 |
+
if keyword.iskeyword(name):
|
583 |
+
return False
|
584 |
+
if not name.strip(): # Check for whitespace-only strings
|
585 |
+
return False
|
586 |
+
return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
|
tests/conftest.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
# Add the src directory to the PYTHONPATH
|
5 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../src"))
|
tests/test_executors.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from unittest.mock import patch
|
3 |
+
|
4 |
+
import pytest
|
5 |
+
|
6 |
+
from workflows.errors import CyclicDependencyError, WorkflowError
|
7 |
+
from workflows.executors import (
|
8 |
+
create_processed_inputs,
|
9 |
+
execute_model_step,
|
10 |
+
execute_workflow,
|
11 |
+
lower,
|
12 |
+
upper,
|
13 |
+
)
|
14 |
+
from workflows.structs import InputField, ModelStep, OutputField, Workflow
|
15 |
+
|
16 |
+
# Tests for utility functions
|
17 |
+
|
18 |
+
|
19 |
+
def test_upper():
|
20 |
+
"""Test the upper function with different input types."""
|
21 |
+
assert upper("hello") == "HELLO"
|
22 |
+
assert upper("Hello World") == "HELLO WORLD"
|
23 |
+
assert upper("") == ""
|
24 |
+
# Non-string inputs should be returned unchanged
|
25 |
+
assert upper(123) == 123
|
26 |
+
assert upper([1, 2, 3]) == [1, 2, 3]
|
27 |
+
assert upper(None) is None
|
28 |
+
|
29 |
+
|
30 |
+
def test_lower():
|
31 |
+
"""Test the lower function with different input types."""
|
32 |
+
assert lower("HELLO") == "hello"
|
33 |
+
assert lower("Hello World") == "hello world"
|
34 |
+
assert lower("") == ""
|
35 |
+
# Non-string inputs should be returned unchanged
|
36 |
+
assert lower(123) == 123
|
37 |
+
assert lower([1, 2, 3]) == [1, 2, 3]
|
38 |
+
assert lower(None) is None
|
39 |
+
|
40 |
+
|
41 |
+
# Tests for create_processed_inputs
|
42 |
+
|
43 |
+
|
44 |
+
def test_create_processed_inputs_basic():
|
45 |
+
"""Test basic input processing without transformations."""
|
46 |
+
step = ModelStep(
|
47 |
+
id="test_step",
|
48 |
+
model="gpt-4",
|
49 |
+
provider="openai",
|
50 |
+
call_type="llm",
|
51 |
+
system_prompt="Test prompt",
|
52 |
+
input_fields=[InputField(name="text", description="Input text", variable="input_text")],
|
53 |
+
output_fields=[],
|
54 |
+
)
|
55 |
+
available_vars = {"input_text": "Hello World"}
|
56 |
+
|
57 |
+
result = create_processed_inputs(step, available_vars)
|
58 |
+
assert result == {"text": "Hello World"}
|
59 |
+
|
60 |
+
|
61 |
+
def test_create_processed_inputs_with_transformation():
|
62 |
+
"""Test input processing with transformation functions."""
|
63 |
+
step = ModelStep(
|
64 |
+
id="test_step",
|
65 |
+
model="gpt-4",
|
66 |
+
provider="openai",
|
67 |
+
call_type="llm",
|
68 |
+
system_prompt="Test prompt",
|
69 |
+
input_fields=[
|
70 |
+
InputField(name="upper_text", description="Uppercase text", variable="input_text", func="upper"),
|
71 |
+
InputField(name="lower_text", description="Lowercase text", variable="input_caps", func="lower"),
|
72 |
+
],
|
73 |
+
output_fields=[],
|
74 |
+
)
|
75 |
+
available_vars = {"input_text": "hello", "input_caps": "WORLD"}
|
76 |
+
|
77 |
+
result = create_processed_inputs(step, available_vars)
|
78 |
+
assert result == {"upper_text": "HELLO", "lower_text": "world"}
|
79 |
+
|
80 |
+
|
81 |
+
def test_create_processed_inputs_missing_var():
|
82 |
+
"""Test that appropriate error is raised when a variable is missing."""
|
83 |
+
step = ModelStep(
|
84 |
+
id="test_step",
|
85 |
+
model="gpt-4",
|
86 |
+
provider="openai",
|
87 |
+
call_type="llm",
|
88 |
+
system_prompt="Test prompt",
|
89 |
+
input_fields=[InputField(name="text", description="Input text", variable="missing_var")],
|
90 |
+
output_fields=[],
|
91 |
+
)
|
92 |
+
available_vars = {"input_text": "Hello World"}
|
93 |
+
|
94 |
+
with pytest.raises(KeyError):
|
95 |
+
create_processed_inputs(step, available_vars)
|
96 |
+
|
97 |
+
|
98 |
+
def test_create_processed_inputs_unknown_func():
|
99 |
+
"""Test that appropriate error is raised when an unknown function is specified."""
|
100 |
+
step = ModelStep(
|
101 |
+
id="test_step",
|
102 |
+
model="gpt-4",
|
103 |
+
provider="openai",
|
104 |
+
call_type="llm",
|
105 |
+
system_prompt="Test prompt",
|
106 |
+
input_fields=[InputField(name="text", description="Input text", variable="input_text", func="unknown_func")],
|
107 |
+
output_fields=[],
|
108 |
+
)
|
109 |
+
available_vars = {"input_text": "Hello World"}
|
110 |
+
|
111 |
+
# This should raise an error when the function isn't found
|
112 |
+
with pytest.raises(Exception):
|
113 |
+
create_processed_inputs(step, available_vars)
|
114 |
+
|
115 |
+
|
116 |
+
# Tests for execute_model_step
|
117 |
+
|
118 |
+
|
119 |
+
@patch("workflows.executors.litellm.completion")
|
120 |
+
def test_execute_model_step_success(mock_completion):
|
121 |
+
"""Test successful execution of a model step with mocked litellm response."""
|
122 |
+
# Mock the litellm response
|
123 |
+
mock_response = {"choices": [{"message": {"content": json.dumps({"summary": "This is a summary"})}}]}
|
124 |
+
mock_completion.return_value = mock_response
|
125 |
+
|
126 |
+
# Create a test step
|
127 |
+
step = ModelStep(
|
128 |
+
id="summarize",
|
129 |
+
model="gpt-3.5-turbo",
|
130 |
+
provider="openai",
|
131 |
+
call_type="llm",
|
132 |
+
system_prompt="Summarize the text",
|
133 |
+
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
134 |
+
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
135 |
+
)
|
136 |
+
|
137 |
+
# Execute the step
|
138 |
+
result = execute_model_step(step, {"input_text": "Long text to be summarized..."})
|
139 |
+
|
140 |
+
# Verify the results
|
141 |
+
assert result == {"summary": "This is a summary"}
|
142 |
+
|
143 |
+
# Verify the litellm call was made correctly
|
144 |
+
mock_completion.assert_called_once()
|
145 |
+
args, kwargs = mock_completion.call_args
|
146 |
+
assert kwargs["model"] == "gpt-3.5-turbo"
|
147 |
+
assert "Summarize the text" in kwargs["messages"][0]["content"]
|
148 |
+
|
149 |
+
|
150 |
+
@patch("workflows.executors.litellm.completion")
|
151 |
+
def test_execute_model_step_error(mock_completion):
|
152 |
+
"""Test handling of errors in model step execution."""
|
153 |
+
# Make litellm raise an exception
|
154 |
+
mock_completion.side_effect = Exception("API Error")
|
155 |
+
|
156 |
+
# Create a test step
|
157 |
+
step = ModelStep(
|
158 |
+
id="summarize",
|
159 |
+
model="gpt-3.5-turbo",
|
160 |
+
provider="openai",
|
161 |
+
call_type="llm",
|
162 |
+
system_prompt="Summarize the text",
|
163 |
+
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
164 |
+
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
165 |
+
)
|
166 |
+
|
167 |
+
# Execute the step - should raise an exception
|
168 |
+
with pytest.raises(Exception):
|
169 |
+
execute_model_step(step, {"input_text": "Long text to be summarized..."})
|
170 |
+
|
171 |
+
|
172 |
+
# Tests for execute_workflow
|
173 |
+
|
174 |
+
|
175 |
+
@patch("workflows.executors.execute_model_step")
|
176 |
+
def test_execute_workflow_simple(mock_execute_step):
|
177 |
+
"""Test execution of a simple workflow with a single step."""
|
178 |
+
# Configure mock to return expected outputs
|
179 |
+
mock_execute_step.return_value = {"summary": "This is a summary"}
|
180 |
+
|
181 |
+
# Create a simple workflow
|
182 |
+
step = ModelStep(
|
183 |
+
id="summarize",
|
184 |
+
model="gpt-3.5-turbo",
|
185 |
+
provider="openai",
|
186 |
+
call_type="llm",
|
187 |
+
system_prompt="Summarize the text",
|
188 |
+
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
189 |
+
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
190 |
+
)
|
191 |
+
|
192 |
+
workflow = Workflow(steps={"summarize": step}, inputs=["input_text"], outputs={"summary": "summarize.summary"})
|
193 |
+
|
194 |
+
# Execute the workflow
|
195 |
+
result = execute_workflow(workflow, {"input_text": "Long text to be summarized..."})
|
196 |
+
|
197 |
+
# Verify the results
|
198 |
+
assert result == {"summary": "This is a summary"}
|
199 |
+
|
200 |
+
# Verify execute_model_step was called correctly
|
201 |
+
mock_execute_step.assert_called_once()
|
202 |
+
|
203 |
+
|
204 |
+
@patch("workflows.executors.execute_model_step")
|
205 |
+
def test_execute_workflow_multi_step(mock_execute_step):
|
206 |
+
"""Test execution of a multi-step workflow with dependencies."""
|
207 |
+
|
208 |
+
# Configure mock to return different values based on the step
|
209 |
+
def side_effect(step, available_vars):
|
210 |
+
if step.id == "extract":
|
211 |
+
return {"entities": ["Apple", "product"]}
|
212 |
+
elif step.id == "analyze":
|
213 |
+
return {"sentiment": "positive"}
|
214 |
+
return {}
|
215 |
+
|
216 |
+
mock_execute_step.side_effect = side_effect
|
217 |
+
|
218 |
+
# Create extract step
|
219 |
+
extract_step = ModelStep(
|
220 |
+
id="extract",
|
221 |
+
model="gpt-3.5-turbo",
|
222 |
+
provider="openai",
|
223 |
+
call_type="llm",
|
224 |
+
system_prompt="Extract entities",
|
225 |
+
input_fields=[InputField(name="text", description="Text to analyze", variable="input_text")],
|
226 |
+
output_fields=[OutputField(name="entities", description="Extracted entities", type="list[str]")],
|
227 |
+
)
|
228 |
+
|
229 |
+
# Create analyze step that depends on extract step
|
230 |
+
analyze_step = ModelStep(
|
231 |
+
id="analyze",
|
232 |
+
model="gpt-4",
|
233 |
+
provider="openai",
|
234 |
+
call_type="llm",
|
235 |
+
system_prompt="Analyze sentiment",
|
236 |
+
input_fields=[InputField(name="entities", description="Entities to analyze", variable="extract.entities")],
|
237 |
+
output_fields=[OutputField(name="sentiment", description="Sentiment analysis", type="str")],
|
238 |
+
)
|
239 |
+
|
240 |
+
workflow = Workflow(
|
241 |
+
steps={"extract": extract_step, "analyze": analyze_step},
|
242 |
+
inputs=["input_text"],
|
243 |
+
outputs={"entities": "extract.entities", "sentiment": "analyze.sentiment"},
|
244 |
+
)
|
245 |
+
|
246 |
+
# Execute the workflow
|
247 |
+
result = execute_workflow(workflow, {"input_text": "Apple is launching a new product tomorrow."})
|
248 |
+
|
249 |
+
# Verify the results
|
250 |
+
assert result == {"entities": ["Apple", "product"], "sentiment": "positive"}
|
251 |
+
|
252 |
+
# Verify execute_model_step was called twice (once for each step)
|
253 |
+
assert mock_execute_step.call_count == 2
|
254 |
+
|
255 |
+
|
256 |
+
def test_execute_workflow_missing_input():
|
257 |
+
"""Test that an error is raised when a required input is missing."""
|
258 |
+
step = ModelStep(
|
259 |
+
id="summarize",
|
260 |
+
model="gpt-3.5-turbo",
|
261 |
+
provider="openai",
|
262 |
+
call_type="llm",
|
263 |
+
system_prompt="Summarize the text",
|
264 |
+
input_fields=[InputField(name="text", description="Text to summarize", variable="input_text")],
|
265 |
+
output_fields=[OutputField(name="summary", description="Summary of the text", type="str")],
|
266 |
+
)
|
267 |
+
|
268 |
+
workflow = Workflow(steps={"summarize": step}, inputs=["input_text"], outputs={"summary": "summarize.summary"})
|
269 |
+
|
270 |
+
# Execute with missing input
|
271 |
+
with pytest.raises(WorkflowError, match="Missing required workflow input"):
|
272 |
+
execute_workflow(workflow, {})
|
273 |
+
|
274 |
+
|
275 |
+
@patch("workflows.executors.create_dependency_graph")
|
276 |
+
def test_execute_workflow_cyclic_dependency(mock_dependency_graph):
|
277 |
+
"""Test that a cyclic dependency in the workflow raises an appropriate error."""
|
278 |
+
# Make create_dependency_graph raise a CyclicDependencyError
|
279 |
+
mock_dependency_graph.side_effect = CyclicDependencyError()
|
280 |
+
|
281 |
+
step = ModelStep(
|
282 |
+
id="test",
|
283 |
+
model="gpt-3.5-turbo",
|
284 |
+
provider="openai",
|
285 |
+
call_type="llm",
|
286 |
+
system_prompt="Test",
|
287 |
+
input_fields=[],
|
288 |
+
output_fields=[],
|
289 |
+
)
|
290 |
+
|
291 |
+
workflow = Workflow(steps={"test": step}, inputs=[], outputs=[])
|
292 |
+
|
293 |
+
# This should propagate the CyclicDependencyError
|
294 |
+
with pytest.raises(CyclicDependencyError):
|
295 |
+
execute_workflow(workflow, {})
|
tests/test_utils.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytest
|
2 |
+
|
3 |
+
from workflows.errors import CyclicDependencyError, UnknownVariableError, WorkflowError
|
4 |
+
from workflows.utils import _create_variable_step_mapping, create_dependency_graph, topological_sort
|
5 |
+
|
6 |
+
|
7 |
+
# Dummy classes to simulate Workflow, Step, and Field
|
8 |
+
class DummyField:
|
9 |
+
def __init__(self, name, type="str", variable=None):
|
10 |
+
self.name = name
|
11 |
+
self.type = type
|
12 |
+
# For input fields, variable property is needed
|
13 |
+
self.variable = variable if variable is not None else name
|
14 |
+
|
15 |
+
|
16 |
+
class DummyStep:
|
17 |
+
def __init__(self, input_fields, output_fields):
|
18 |
+
self.input_fields = input_fields
|
19 |
+
self.output_fields = output_fields
|
20 |
+
|
21 |
+
|
22 |
+
class DummyWorkflow:
|
23 |
+
def __init__(self, steps):
|
24 |
+
# steps is a dict with key as step_id and value as DummyStep
|
25 |
+
self.steps = steps
|
26 |
+
|
27 |
+
|
28 |
+
# Tests for _create_variable_step_mapping
|
29 |
+
|
30 |
+
|
31 |
+
def test_create_variable_step_mapping_success():
|
32 |
+
# Create a workflow with two steps producing unique output variables
|
33 |
+
step_a = DummyStep(input_fields=[], output_fields=[DummyField("out1")])
|
34 |
+
step_b = DummyStep(input_fields=[], output_fields=[DummyField("out2")])
|
35 |
+
workflow = DummyWorkflow({"A": step_a, "B": step_b})
|
36 |
+
mapping = _create_variable_step_mapping(workflow)
|
37 |
+
assert mapping == {"A.out1": "A", "B.out2": "B"}
|
38 |
+
|
39 |
+
|
40 |
+
def test_create_variable_step_mapping_duplicate():
|
41 |
+
# Create a workflow where two steps produce an output with same name
|
42 |
+
step_a = DummyStep(input_fields=[], output_fields=[DummyField("out"), DummyField("out")])
|
43 |
+
workflow = DummyWorkflow({"A": step_a})
|
44 |
+
with pytest.raises(WorkflowError):
|
45 |
+
_create_variable_step_mapping(workflow)
|
46 |
+
|
47 |
+
|
48 |
+
def test_create_variable_step_mapping_empty():
|
49 |
+
"""Test _create_variable_step_mapping with an empty workflow should return an empty mapping."""
|
50 |
+
workflow = DummyWorkflow({})
|
51 |
+
mapping = _create_variable_step_mapping(workflow)
|
52 |
+
assert mapping == {}
|
53 |
+
|
54 |
+
|
55 |
+
def test_create_variable_step_mapping_multiple_outputs():
|
56 |
+
"""Test a workflow where a single step produces multiple outputs with unique names."""
|
57 |
+
step = DummyStep(input_fields=[], output_fields=[DummyField("out1"), DummyField("out2")])
|
58 |
+
workflow = DummyWorkflow({"A": step})
|
59 |
+
mapping = _create_variable_step_mapping(workflow)
|
60 |
+
assert mapping == {"A.out1": "A", "A.out2": "A"}
|
61 |
+
|
62 |
+
|
63 |
+
# Tests for create_dependency_graph
|
64 |
+
|
65 |
+
|
66 |
+
def test_create_dependency_graph_success_with_dependency():
|
67 |
+
# Step A produces 'A.out', which is used as input in step B
|
68 |
+
step_a = DummyStep(input_fields=[], output_fields=[DummyField("out")])
|
69 |
+
# For input_fields, explicitly set variable to reference A.out
|
70 |
+
step_b = DummyStep(input_fields=[DummyField("dummy", variable="A.out")], output_fields=[DummyField("result")])
|
71 |
+
workflow = DummyWorkflow({"A": step_a, "B": step_b})
|
72 |
+
# No external input provided for A.out so dependency must be created
|
73 |
+
deps = create_dependency_graph(workflow, input_values={})
|
74 |
+
# Step B depends on step A
|
75 |
+
assert deps["B"] == {"A"}
|
76 |
+
# Step A has no dependencies
|
77 |
+
assert deps["A"] == set()
|
78 |
+
|
79 |
+
|
80 |
+
def test_create_dependency_graph_success_with_external_input():
|
81 |
+
# Step B expects an input, but it is provided externally
|
82 |
+
step_b = DummyStep(
|
83 |
+
input_fields=[DummyField("param", variable="external_param")], output_fields=[DummyField("result")]
|
84 |
+
)
|
85 |
+
workflow = DummyWorkflow({"B": step_b})
|
86 |
+
# Provide external input for external_param
|
87 |
+
deps = create_dependency_graph(workflow, input_values={"external_param": 42})
|
88 |
+
# With external input, no dependency is needed
|
89 |
+
assert deps["B"] == set()
|
90 |
+
|
91 |
+
|
92 |
+
def test_create_dependency_graph_unknown_variable():
|
93 |
+
# Step B expects an input that is neither produced by any step nor provided externally
|
94 |
+
step_b = DummyStep(
|
95 |
+
input_fields=[DummyField("param", variable="non_existent")], output_fields=[DummyField("result")]
|
96 |
+
)
|
97 |
+
workflow = DummyWorkflow({"B": step_b})
|
98 |
+
with pytest.raises(UnknownVariableError):
|
99 |
+
_ = create_dependency_graph(workflow, input_values={})
|
100 |
+
|
101 |
+
|
102 |
+
def test_create_dependency_graph_complex():
|
103 |
+
"""Test create_dependency_graph on a more complex workflow with multiple dependencies."""
|
104 |
+
# Step A produces A.out, Step B uses A.out, Step C uses B.out, and Step D uses both A.out and B.out
|
105 |
+
step_a = DummyStep(input_fields=[], output_fields=[DummyField("out")])
|
106 |
+
step_b = DummyStep(input_fields=[DummyField("inp", variable="A.out")], output_fields=[DummyField("out")])
|
107 |
+
step_c = DummyStep(input_fields=[DummyField("inp", variable="B.out")], output_fields=[DummyField("result")])
|
108 |
+
step_d = DummyStep(
|
109 |
+
input_fields=[DummyField("inp1", variable="A.out"), DummyField("inp2", variable="B.out")],
|
110 |
+
output_fields=[DummyField("final")],
|
111 |
+
)
|
112 |
+
|
113 |
+
workflow = DummyWorkflow({"A": step_a, "B": step_b, "C": step_c, "D": step_d})
|
114 |
+
# Provide external input for "B.out" so that step B's output isn't expected to come from a step
|
115 |
+
# However, to simulate dependency, assume external input is not provided for the dependencies used in step C and D
|
116 |
+
# Therefore, workflow must resolve A.out for step B, and then step B produces B.out for steps C and D.
|
117 |
+
# Let's not provide any external input, so both dependencies are created.
|
118 |
+
|
119 |
+
deps = create_dependency_graph(workflow, input_values={})
|
120 |
+
# Expected dependencies:
|
121 |
+
# B depends on A
|
122 |
+
# C depends on B
|
123 |
+
# D depends on both A and B
|
124 |
+
assert deps["B"] == {"A"}
|
125 |
+
assert deps["C"] == {"B"}
|
126 |
+
assert deps["D"] == {"A", "B"}
|
127 |
+
|
128 |
+
|
129 |
+
# Tests for topological_sort
|
130 |
+
|
131 |
+
|
132 |
+
def test_topological_sort_success():
|
133 |
+
# Create a simple dependency graph: A -> B -> C
|
134 |
+
deps = {"A": set(), "B": {"A"}, "C": {"B"}}
|
135 |
+
order = topological_sort(deps)
|
136 |
+
# Check that order satisfies dependencies: A before B, B before C
|
137 |
+
assert order.index("A") < order.index("B") < order.index("C")
|
138 |
+
|
139 |
+
|
140 |
+
def test_topological_sort_cycle():
|
141 |
+
# Create a cyclic dependency: A -> B and B -> A
|
142 |
+
deps = {"A": {"B"}, "B": {"A"}}
|
143 |
+
with pytest.raises(CyclicDependencyError):
|
144 |
+
_ = topological_sort(deps)
|
145 |
+
|
146 |
+
|
147 |
+
def test_topological_sort_single_node():
|
148 |
+
"""Test topological_sort on a graph with a single node and no dependencies."""
|
149 |
+
deps = {"A": set()}
|
150 |
+
order = topological_sort(deps)
|
151 |
+
assert order == ["A"]
|
152 |
+
|
153 |
+
|
154 |
+
def test_topological_sort_disconnected():
|
155 |
+
"""Test topological_sort on a graph with disconnected nodes (no dependencies among them)."""
|
156 |
+
deps = {"A": set(), "B": set(), "C": set()}
|
157 |
+
order = topological_sort(deps)
|
158 |
+
# The order can be in any permutation, but must contain all nodes
|
159 |
+
assert set(order) == {"A", "B", "C"}
|
tests/test_validators.py
ADDED
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List
|
2 |
+
|
3 |
+
import pytest
|
4 |
+
from pydantic import ValidationError as PydanticValidationError
|
5 |
+
|
6 |
+
from workflows.structs import InputField, ModelStep, OutputField, Workflow
|
7 |
+
from workflows.validators import ValidationError, ValidationErrorType, WorkflowValidator
|
8 |
+
|
9 |
+
|
10 |
+
# Test Data
|
11 |
+
def create_basic_step(step_id: str = "step1") -> ModelStep:
|
12 |
+
"""Creates a basic valid step for testing"""
|
13 |
+
return ModelStep(
|
14 |
+
id=step_id,
|
15 |
+
name="Test Step",
|
16 |
+
model="gpt-4",
|
17 |
+
provider="openai",
|
18 |
+
call_type="llm",
|
19 |
+
temperature=0.7,
|
20 |
+
system_prompt="Test prompt",
|
21 |
+
input_fields=[],
|
22 |
+
output_fields=[],
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
def create_basic_workflow(steps: List[ModelStep] | None = None) -> Workflow:
|
27 |
+
"""Creates a basic valid workflow for testing"""
|
28 |
+
if steps is None:
|
29 |
+
steps = [create_basic_step()]
|
30 |
+
return Workflow(inputs=[], outputs={}, steps={step.id: step for step in steps})
|
31 |
+
|
32 |
+
|
33 |
+
# Additional Test Data
|
34 |
+
def create_step_with_fields(
|
35 |
+
step_id: str, input_fields: List[InputField], output_fields: List[OutputField]
|
36 |
+
) -> ModelStep:
|
37 |
+
"""Creates a step with specific input and output fields"""
|
38 |
+
return ModelStep(
|
39 |
+
id=step_id,
|
40 |
+
name="Test Step",
|
41 |
+
model="gpt-4",
|
42 |
+
provider="openai",
|
43 |
+
call_type="llm",
|
44 |
+
temperature=0.7,
|
45 |
+
system_prompt="Test prompt",
|
46 |
+
input_fields=input_fields,
|
47 |
+
output_fields=output_fields,
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
def create_valid_workflow() -> Workflow:
|
52 |
+
# Create a step with input and output fields
|
53 |
+
step = create_step_with_fields(
|
54 |
+
"step1",
|
55 |
+
[InputField(name="input", description="test", variable="external_input")],
|
56 |
+
[OutputField(name="output", description="test", type="str")],
|
57 |
+
)
|
58 |
+
|
59 |
+
# Create workflow with the single step
|
60 |
+
workflow = create_basic_workflow([step])
|
61 |
+
workflow.inputs = ["external_input"]
|
62 |
+
workflow.outputs = {"output": "step1.output"}
|
63 |
+
return workflow
|
64 |
+
|
65 |
+
|
66 |
+
# Basic Workflow Validation Tests
|
67 |
+
class TestBasicWorkflowValidation:
|
68 |
+
def test_empty_workflow(self):
|
69 |
+
"""Test validation of empty workflow"""
|
70 |
+
validator = WorkflowValidator()
|
71 |
+
workflow = Workflow(inputs=["input"], outputs={"output": "input"}, steps={})
|
72 |
+
assert not validator.validate(workflow)
|
73 |
+
assert len(validator.errors) == 1
|
74 |
+
assert validator.errors[0].error_type == ValidationErrorType.GENERAL
|
75 |
+
assert "must contain at least one step" in validator.errors[0].message
|
76 |
+
|
77 |
+
def test_workflow_without_inputs(self):
|
78 |
+
"""Test validation of workflow without inputs"""
|
79 |
+
validator = WorkflowValidator()
|
80 |
+
workflow = create_basic_workflow()
|
81 |
+
workflow.inputs = []
|
82 |
+
workflow.outputs = {"output": "step1.field"}
|
83 |
+
assert not validator.validate(workflow)
|
84 |
+
assert len(validator.errors) == 1
|
85 |
+
assert validator.errors[0].error_type == ValidationErrorType.GENERAL
|
86 |
+
assert "must contain at least one input" in validator.errors[0].message
|
87 |
+
|
88 |
+
def test_workflow_without_outputs(self):
|
89 |
+
"""Test validation of workflow without outputs"""
|
90 |
+
validator = WorkflowValidator()
|
91 |
+
workflow = create_basic_workflow()
|
92 |
+
workflow.inputs = ["input"]
|
93 |
+
workflow.outputs = {}
|
94 |
+
assert not validator.validate(workflow)
|
95 |
+
assert len(validator.errors) == 1
|
96 |
+
assert validator.errors[0].error_type == ValidationErrorType.GENERAL
|
97 |
+
assert "must contain at least one output" in validator.errors[0].message
|
98 |
+
|
99 |
+
def test_single_step_workflow(self):
|
100 |
+
"""Test validation of valid single-step workflow"""
|
101 |
+
validator = WorkflowValidator()
|
102 |
+
|
103 |
+
# Create a step with input and output fields
|
104 |
+
workflow = create_valid_workflow()
|
105 |
+
|
106 |
+
assert validator.validate(workflow)
|
107 |
+
assert len(validator.errors) == 0
|
108 |
+
|
109 |
+
|
110 |
+
# Step Validation Tests
|
111 |
+
class TestStepValidation:
|
112 |
+
def test_missing_required_fields(self):
|
113 |
+
"""Test validation of step with missing required fields"""
|
114 |
+
validator = WorkflowValidator()
|
115 |
+
step = ModelStep(
|
116 |
+
id="step1",
|
117 |
+
name="", # Missing name
|
118 |
+
model="", # Missing model
|
119 |
+
provider="", # Missing provider
|
120 |
+
call_type="", # Missing call_type
|
121 |
+
temperature=0.7,
|
122 |
+
system_prompt="Test prompt",
|
123 |
+
input_fields=[],
|
124 |
+
output_fields=[],
|
125 |
+
)
|
126 |
+
workflow = create_basic_workflow([step])
|
127 |
+
workflow.inputs = ["input"]
|
128 |
+
workflow.outputs = {"output": "step1.field"}
|
129 |
+
assert not validator.validate(workflow)
|
130 |
+
assert len(validator.errors) == 1
|
131 |
+
assert validator.errors[0].error_type == ValidationErrorType.STEP
|
132 |
+
|
133 |
+
def test_invalid_step_id(self):
|
134 |
+
"""Test validation of step with invalid ID format"""
|
135 |
+
validator = WorkflowValidator()
|
136 |
+
step = create_basic_step("123invalid") # Invalid ID format
|
137 |
+
workflow = create_basic_workflow([step])
|
138 |
+
workflow.inputs = ["input"]
|
139 |
+
workflow.outputs = {"output": "step1.field"}
|
140 |
+
assert not validator.validate(workflow)
|
141 |
+
assert len(validator.errors) == 1
|
142 |
+
assert validator.errors[0].error_type == ValidationErrorType.NAMING
|
143 |
+
|
144 |
+
def test_llm_temperature_validation(self):
|
145 |
+
"""Test validation of LLM step temperature"""
|
146 |
+
validator = WorkflowValidator()
|
147 |
+
|
148 |
+
# Test invalid temperature
|
149 |
+
step = create_basic_step()
|
150 |
+
step.temperature = 1.5 # Invalid temperature
|
151 |
+
workflow = create_basic_workflow([step])
|
152 |
+
workflow.inputs = ["input"]
|
153 |
+
workflow.outputs = {"output": "step1.field"}
|
154 |
+
assert not validator.validate(workflow)
|
155 |
+
assert len(validator.errors) == 1
|
156 |
+
assert validator.errors[0].error_type == ValidationErrorType.RANGE
|
157 |
+
|
158 |
+
# Test missing temperature
|
159 |
+
step = create_basic_step()
|
160 |
+
step.temperature = None # Missing temperature
|
161 |
+
workflow = create_basic_workflow([step])
|
162 |
+
workflow.inputs = ["input"]
|
163 |
+
workflow.outputs = {"output": "step1.field"}
|
164 |
+
assert not validator.validate(workflow)
|
165 |
+
assert len(validator.errors) == 1
|
166 |
+
assert validator.errors[0].error_type == ValidationErrorType.STEP
|
167 |
+
|
168 |
+
def test_llm_system_prompt_validation(self):
|
169 |
+
"""Test validation of LLM step system prompt"""
|
170 |
+
validator = WorkflowValidator()
|
171 |
+
|
172 |
+
# Test missing system prompt
|
173 |
+
step = create_basic_step()
|
174 |
+
step.system_prompt = "" # Missing system prompt
|
175 |
+
workflow = create_basic_workflow([step])
|
176 |
+
workflow.inputs = ["input"]
|
177 |
+
workflow.outputs = {"output": "step1.field"}
|
178 |
+
assert not validator.validate(workflow)
|
179 |
+
assert len(validator.errors) == 1
|
180 |
+
assert validator.errors[0].error_type == ValidationErrorType.STEP
|
181 |
+
|
182 |
+
# Test too long system prompt
|
183 |
+
step = create_basic_step()
|
184 |
+
step.system_prompt = "x" * 4001 # Too long
|
185 |
+
workflow = create_basic_workflow([step])
|
186 |
+
workflow.inputs = ["input"]
|
187 |
+
workflow.outputs = {"output": "step1.field"}
|
188 |
+
assert not validator.validate(workflow)
|
189 |
+
assert len(validator.errors) == 1
|
190 |
+
assert validator.errors[0].error_type == ValidationErrorType.LENGTH
|
191 |
+
|
192 |
+
|
193 |
+
# Field Validation Tests
|
194 |
+
class TestFieldValidation:
|
195 |
+
def test_input_field_validation(self):
|
196 |
+
"""Test validation of input fields"""
|
197 |
+
validator = WorkflowValidator()
|
198 |
+
|
199 |
+
# Test missing required fields
|
200 |
+
step = create_basic_step()
|
201 |
+
step.input_fields = [InputField(name="", description="", variable="")]
|
202 |
+
workflow = create_basic_workflow([step])
|
203 |
+
workflow.inputs = ["input"]
|
204 |
+
workflow.outputs = {"output": "step1.field"}
|
205 |
+
assert not validator.validate(workflow)
|
206 |
+
assert len(validator.errors) == 1
|
207 |
+
assert validator.errors[0].error_type == ValidationErrorType.STEP
|
208 |
+
|
209 |
+
# Test invalid field name
|
210 |
+
step = create_basic_step()
|
211 |
+
step.input_fields = [InputField(name="123invalid", description="test", variable="test")]
|
212 |
+
workflow = create_basic_workflow([step])
|
213 |
+
workflow.inputs = ["input"]
|
214 |
+
workflow.outputs = {"output": "step1.field"}
|
215 |
+
assert not validator.validate(workflow)
|
216 |
+
assert len(validator.errors) == 1
|
217 |
+
assert validator.errors[0].error_type == ValidationErrorType.NAMING
|
218 |
+
|
219 |
+
# Test too long description
|
220 |
+
step = create_basic_step()
|
221 |
+
step.input_fields = [InputField(name="test", description="x" * 201, variable="test")]
|
222 |
+
workflow = create_basic_workflow([step])
|
223 |
+
workflow.inputs = ["input"]
|
224 |
+
workflow.outputs = {"output": "step1.field"}
|
225 |
+
assert not validator.validate(workflow)
|
226 |
+
assert len(validator.errors) == 1
|
227 |
+
assert validator.errors[0].error_type == ValidationErrorType.LENGTH
|
228 |
+
|
229 |
+
def test_output_field_validation(self):
|
230 |
+
"""Test validation of output fields"""
|
231 |
+
validator = WorkflowValidator()
|
232 |
+
|
233 |
+
# Test missing required fields
|
234 |
+
step = create_basic_step()
|
235 |
+
step.output_fields = [OutputField(name="", description="", type="str")]
|
236 |
+
workflow = create_basic_workflow([step])
|
237 |
+
workflow.inputs = ["input"]
|
238 |
+
workflow.outputs = {"output": "step1.field"}
|
239 |
+
assert not validator.validate(workflow)
|
240 |
+
assert len(validator.errors) == 1
|
241 |
+
assert validator.errors[0].error_type == ValidationErrorType.STEP
|
242 |
+
|
243 |
+
# Test invalid field name
|
244 |
+
step = create_basic_step()
|
245 |
+
step.output_fields = [OutputField(name="123invalid", description="test", type="str")]
|
246 |
+
workflow = create_basic_workflow([step])
|
247 |
+
workflow.inputs = ["input"]
|
248 |
+
workflow.outputs = {"output": "step1.field"}
|
249 |
+
assert not validator.validate(workflow)
|
250 |
+
assert len(validator.errors) == 1
|
251 |
+
assert validator.errors[0].error_type == ValidationErrorType.NAMING
|
252 |
+
|
253 |
+
def test_field_name_length(self):
|
254 |
+
"""Test validation of field name length"""
|
255 |
+
validator = WorkflowValidator()
|
256 |
+
|
257 |
+
# Test too long field name
|
258 |
+
step = create_basic_step()
|
259 |
+
step.input_fields = [InputField(name="x" * 51, description="test", variable="test")]
|
260 |
+
workflow = create_basic_workflow([step])
|
261 |
+
workflow.inputs = ["input"]
|
262 |
+
workflow.outputs = {"output": "step1.field"}
|
263 |
+
assert not validator.validate(workflow)
|
264 |
+
assert len(validator.errors) == 1
|
265 |
+
assert validator.errors[0].error_type == ValidationErrorType.LENGTH
|
266 |
+
|
267 |
+
def test_field_description_length(self):
|
268 |
+
"""Test validation of field description length"""
|
269 |
+
validator = WorkflowValidator()
|
270 |
+
|
271 |
+
# Test too long description
|
272 |
+
step = create_basic_step()
|
273 |
+
step.input_fields = [InputField(name="test", description="x" * 201, variable="test")]
|
274 |
+
workflow = create_basic_workflow([step])
|
275 |
+
workflow.inputs = ["input"]
|
276 |
+
workflow.outputs = {"output": "step1.field"}
|
277 |
+
assert not validator.validate(workflow)
|
278 |
+
assert len(validator.errors) == 1
|
279 |
+
assert validator.errors[0].error_type == ValidationErrorType.LENGTH
|
280 |
+
|
281 |
+
def test_whitespace_only_strings(self):
|
282 |
+
"""Test validation of whitespace-only strings"""
|
283 |
+
validator = WorkflowValidator()
|
284 |
+
|
285 |
+
# Test whitespace-only field name
|
286 |
+
step = create_basic_step()
|
287 |
+
step.input_fields = [InputField(name=" ", description="test", variable="test")]
|
288 |
+
workflow = create_basic_workflow([step])
|
289 |
+
workflow.inputs = ["input"]
|
290 |
+
workflow.outputs = {"output": "step1.field"}
|
291 |
+
assert not validator.validate(workflow)
|
292 |
+
assert len(validator.errors) == 1
|
293 |
+
assert validator.errors[0].error_type == ValidationErrorType.NAMING
|
294 |
+
|
295 |
+
def test_special_characters(self):
|
296 |
+
"""Test validation of special characters in names"""
|
297 |
+
validator = WorkflowValidator()
|
298 |
+
|
299 |
+
# Test special characters in field name
|
300 |
+
step = create_basic_step()
|
301 |
+
step.input_fields = [InputField(name="test@field", description="test", variable="test")]
|
302 |
+
workflow = create_basic_workflow([step])
|
303 |
+
workflow.inputs = ["input"]
|
304 |
+
workflow.outputs = {"output": "step1.field"}
|
305 |
+
assert not validator.validate(workflow)
|
306 |
+
assert len(validator.errors) == 1
|
307 |
+
assert validator.errors[0].error_type == ValidationErrorType.NAMING
|
308 |
+
|
309 |
+
|
310 |
+
# Variable Reference Tests
|
311 |
+
class TestVariableReference:
|
312 |
+
def test_external_input_validation(self):
|
313 |
+
"""Test validation of external input variables"""
|
314 |
+
validator = WorkflowValidator()
|
315 |
+
|
316 |
+
# Test invalid external input format
|
317 |
+
workflow = create_valid_workflow()
|
318 |
+
workflow.inputs = ["step1.field"] # Invalid format
|
319 |
+
assert not validator.validate(workflow)
|
320 |
+
assert len(validator.errors) == 1
|
321 |
+
assert validator.errors[0].error_type == ValidationErrorType.VARIABLE
|
322 |
+
|
323 |
+
def test_step_output_reference(self):
|
324 |
+
"""Test validation of step output references"""
|
325 |
+
validator = WorkflowValidator()
|
326 |
+
|
327 |
+
# Test invalid output reference
|
328 |
+
workflow = create_basic_workflow()
|
329 |
+
workflow.inputs = ["input"]
|
330 |
+
workflow.outputs = {"output": "nonexistent_step.field"}
|
331 |
+
assert not validator.validate(workflow)
|
332 |
+
assert len(validator.errors) == 1
|
333 |
+
assert validator.errors[0].error_type == ValidationErrorType.VARIABLE
|
334 |
+
|
335 |
+
# Test valid output reference
|
336 |
+
step = create_basic_step()
|
337 |
+
step.output_fields = [OutputField(name="field", description="test", type="str")]
|
338 |
+
workflow = create_basic_workflow([step])
|
339 |
+
workflow.inputs = ["input"]
|
340 |
+
workflow.outputs = {"output": "step1.field"}
|
341 |
+
assert validator.validate(workflow)
|
342 |
+
assert len(validator.errors) == 0
|
343 |
+
|
344 |
+
|
345 |
+
# DAG Validation Tests
|
346 |
+
class TestDAGValidation:
|
347 |
+
def test_cycle_detection(self):
|
348 |
+
"""Test detection of cycles in workflow"""
|
349 |
+
validator = WorkflowValidator()
|
350 |
+
|
351 |
+
# Create a workflow with a cycle
|
352 |
+
step1 = create_step_with_fields(
|
353 |
+
"step1",
|
354 |
+
[InputField(name="input", description="test", variable="step3.output")],
|
355 |
+
[OutputField(name="output", description="test", type="str")],
|
356 |
+
)
|
357 |
+
step2 = create_step_with_fields(
|
358 |
+
"step2",
|
359 |
+
[InputField(name="input", description="test", variable="step1.output")],
|
360 |
+
[OutputField(name="output", description="test", type="str")],
|
361 |
+
)
|
362 |
+
step3 = create_step_with_fields(
|
363 |
+
"step3",
|
364 |
+
[InputField(name="input", description="test", variable="step2.output")],
|
365 |
+
[OutputField(name="output", description="test", type="str")],
|
366 |
+
)
|
367 |
+
|
368 |
+
workflow = create_basic_workflow([step1, step2, step3])
|
369 |
+
workflow.inputs = ["input"]
|
370 |
+
workflow.outputs = {"output": "step3.output"}
|
371 |
+
assert not validator.validate(workflow)
|
372 |
+
assert len(validator.errors) == 1
|
373 |
+
assert validator.errors[0].error_type == ValidationErrorType.DAG
|
374 |
+
|
375 |
+
def test_orphaned_steps(self):
|
376 |
+
"""Test detection of orphaned steps"""
|
377 |
+
validator = WorkflowValidator()
|
378 |
+
|
379 |
+
# Create a workflow with an orphaned step
|
380 |
+
step1 = create_step_with_fields(
|
381 |
+
"step1",
|
382 |
+
[InputField(name="input", description="test", variable="step2.output")],
|
383 |
+
[OutputField(name="output", description="test", type="str")],
|
384 |
+
)
|
385 |
+
step2 = create_step_with_fields(
|
386 |
+
"step2",
|
387 |
+
[InputField(name="input", description="test", variable="step1.output")],
|
388 |
+
[OutputField(name="output", description="test", type="str")],
|
389 |
+
)
|
390 |
+
step3 = create_step_with_fields(
|
391 |
+
"step3",
|
392 |
+
[],
|
393 |
+
[OutputField(name="output", description="test", type="str")],
|
394 |
+
)
|
395 |
+
|
396 |
+
workflow = create_basic_workflow([step1, step2, step3])
|
397 |
+
workflow.inputs = ["input"]
|
398 |
+
workflow.outputs = {"output": "step3.output"}
|
399 |
+
assert not validator.validate(workflow)
|
400 |
+
assert len(validator.errors) == 1
|
401 |
+
assert validator.errors[0].error_type == ValidationErrorType.DAG
|
402 |
+
|
403 |
+
|
404 |
+
# Variable Dependency Tests
|
405 |
+
class TestVariableDependencies:
|
406 |
+
def test_circular_dependencies(self):
|
407 |
+
"""Test detection of circular variable dependencies"""
|
408 |
+
validator = WorkflowValidator()
|
409 |
+
|
410 |
+
# Create a workflow with circular variable dependencies
|
411 |
+
step1 = create_step_with_fields(
|
412 |
+
"step1",
|
413 |
+
[InputField(name="input", description="test", variable="step2.output")],
|
414 |
+
[OutputField(name="output", description="test", type="str")],
|
415 |
+
)
|
416 |
+
step2 = create_step_with_fields(
|
417 |
+
"step2",
|
418 |
+
[InputField(name="input", description="test", variable="step1.output")],
|
419 |
+
[OutputField(name="output", description="test", type="str")],
|
420 |
+
)
|
421 |
+
|
422 |
+
workflow = create_basic_workflow([step1, step2])
|
423 |
+
workflow.inputs = ["input"]
|
424 |
+
workflow.outputs = {"output": "step2.output"}
|
425 |
+
assert not validator.validate(workflow)
|
426 |
+
assert len(validator.errors) == 1
|
427 |
+
assert validator.errors[0].error_type == ValidationErrorType.DAG
|
428 |
+
|
429 |
+
def test_valid_dependencies(self):
|
430 |
+
"""Test validation of valid variable dependencies"""
|
431 |
+
validator = WorkflowValidator()
|
432 |
+
|
433 |
+
# Create a workflow with valid dependencies
|
434 |
+
step1 = create_step_with_fields(
|
435 |
+
"step1",
|
436 |
+
[InputField(name="input", description="test", variable="external_input")],
|
437 |
+
[OutputField(name="output", description="test", type="str")],
|
438 |
+
)
|
439 |
+
step2 = create_step_with_fields(
|
440 |
+
"step2",
|
441 |
+
[InputField(name="input", description="test", variable="step1.output")],
|
442 |
+
[OutputField(name="output", description="test", type="str")],
|
443 |
+
)
|
444 |
+
step3 = create_step_with_fields(
|
445 |
+
"step3",
|
446 |
+
[InputField(name="input", description="test", variable="step2.output")],
|
447 |
+
[OutputField(name="output", description="test", type="str")],
|
448 |
+
)
|
449 |
+
|
450 |
+
workflow = create_basic_workflow([step1, step2, step3])
|
451 |
+
workflow.inputs = ["external_input"]
|
452 |
+
workflow.outputs = {"output": "step3.output"}
|
453 |
+
assert validator.validate(workflow)
|
454 |
+
assert len(validator.errors) == 0
|
455 |
+
|
456 |
+
|
457 |
+
# Type Compatibility Tests
|
458 |
+
class TestTypeCompatibility:
|
459 |
+
def test_basic_type_compatibility(self):
|
460 |
+
"""Test validation of basic type compatibility"""
|
461 |
+
validator = WorkflowValidator()
|
462 |
+
|
463 |
+
# Create steps with type mismatch
|
464 |
+
step1 = create_step_with_fields(
|
465 |
+
"step1",
|
466 |
+
[InputField(name="input", description="test", variable="external_input")],
|
467 |
+
[OutputField(name="output", description="test", type="int")],
|
468 |
+
)
|
469 |
+
step2 = create_step_with_fields(
|
470 |
+
"step2",
|
471 |
+
[InputField(name="input", description="test", variable="step1.output")],
|
472 |
+
[OutputField(name="output", description="test", type="str")],
|
473 |
+
)
|
474 |
+
|
475 |
+
workflow = create_basic_workflow([step1, step2])
|
476 |
+
workflow.inputs = ["external_input"]
|
477 |
+
workflow.outputs = {"output": "step2.output"}
|
478 |
+
assert validator.validate(workflow)
|
479 |
+
|
480 |
+
# def test_list_type_compatibility(self):
|
481 |
+
# """Test validation of list type compatibility"""
|
482 |
+
# validator = WorkflowValidator()
|
483 |
+
|
484 |
+
# # Test compatible list types
|
485 |
+
# step1 = create_step_with_fields(
|
486 |
+
# "step1", [], [OutputField(name="output", description="test", type="list[str]")]
|
487 |
+
# )
|
488 |
+
# step2 = create_step_with_fields(
|
489 |
+
# "step2", [InputField(name="input", description="test", variable="step1.output")], []
|
490 |
+
# )
|
491 |
+
|
492 |
+
# workflow = create_basic_workflow([step1, step2])
|
493 |
+
# workflow.inputs = ["input"]
|
494 |
+
# workflow.outputs = {"output": "step2.output"}
|
495 |
+
# assert validator.validate(workflow)
|
496 |
+
# assert len(validator.errors) == 0
|
497 |
+
|
498 |
+
# # Test incompatible list types
|
499 |
+
# step1 = create_step_with_fields(
|
500 |
+
# "step1", [], [OutputField(name="output", description="test", type="list[int]")]
|
501 |
+
# )
|
502 |
+
# step2 = create_step_with_fields(
|
503 |
+
# "step2", [InputField(name="input", description="test", variable="step1.output")], []
|
504 |
+
# )
|
505 |
+
|
506 |
+
# workflow = create_basic_workflow([step1, step2])
|
507 |
+
# workflow.inputs = ["input"]
|
508 |
+
# workflow.outputs = {"output": "step2.output"}
|
509 |
+
# assert not validator.validate(workflow)
|
510 |
+
# assert len(validator.errors) == 1
|
511 |
+
# assert validator.errors[0].error_type == ValidationErrorType.TYPE
|
512 |
+
|
513 |
+
|
514 |
+
# Complex Workflow Tests
|
515 |
+
class TestComplexWorkflows:
|
516 |
+
def test_multi_output_workflow(self):
|
517 |
+
"""Test validation of workflow with multiple outputs"""
|
518 |
+
validator = WorkflowValidator()
|
519 |
+
|
520 |
+
# Create a workflow with multiple outputs
|
521 |
+
step1 = create_step_with_fields(
|
522 |
+
"step1",
|
523 |
+
[],
|
524 |
+
[
|
525 |
+
OutputField(name="output1", description="test", type="str"),
|
526 |
+
OutputField(name="output2", description="test", type="int"),
|
527 |
+
],
|
528 |
+
)
|
529 |
+
step2 = create_step_with_fields(
|
530 |
+
"step2",
|
531 |
+
[InputField(name="input", description="test", variable="step1.output1")],
|
532 |
+
[OutputField(name="output", description="test", type="str")],
|
533 |
+
)
|
534 |
+
|
535 |
+
workflow = create_basic_workflow([step1, step2])
|
536 |
+
workflow.inputs = ["input"]
|
537 |
+
workflow.outputs = {"output1": "step1.output1", "output2": "step1.output2", "output3": "step2.output"}
|
538 |
+
assert validator.validate(workflow)
|
539 |
+
assert len(validator.errors) == 0
|
540 |
+
|
541 |
+
def test_complex_dependencies(self):
|
542 |
+
"""Test validation of workflow with complex dependencies"""
|
543 |
+
validator = WorkflowValidator()
|
544 |
+
|
545 |
+
# Create a workflow with complex dependencies
|
546 |
+
step1 = create_step_with_fields(
|
547 |
+
"step1",
|
548 |
+
[InputField(name="input", description="test", variable="external_input")],
|
549 |
+
[OutputField(name="output", description="test", type="str")],
|
550 |
+
)
|
551 |
+
step2 = create_step_with_fields(
|
552 |
+
"step2",
|
553 |
+
[InputField(name="input", description="test", variable="step1.output")],
|
554 |
+
[OutputField(name="output", description="test", type="str")],
|
555 |
+
)
|
556 |
+
step3 = create_step_with_fields(
|
557 |
+
"step3",
|
558 |
+
[
|
559 |
+
InputField(name="input1", description="test", variable="step1.output"),
|
560 |
+
InputField(name="input2", description="test", variable="step2.output"),
|
561 |
+
],
|
562 |
+
[OutputField(name="output", description="test", type="str")],
|
563 |
+
)
|
564 |
+
|
565 |
+
workflow = create_basic_workflow([step1, step2, step3])
|
566 |
+
workflow.inputs = ["external_input"]
|
567 |
+
workflow.outputs = {"output": "step3.output"}
|
568 |
+
assert validator.validate(workflow)
|
569 |
+
assert len(validator.errors) == 0
|
570 |
+
|
571 |
+
|
572 |
+
# External Input Tests
|
573 |
+
class TestExternalInputs:
|
574 |
+
def test_external_input_existence(self):
|
575 |
+
"""Test validation of external input existence"""
|
576 |
+
validator = WorkflowValidator()
|
577 |
+
|
578 |
+
# Test missing external input
|
579 |
+
step = create_step_with_fields(
|
580 |
+
"step1", [InputField(name="input", description="test", variable="missing_input")], []
|
581 |
+
)
|
582 |
+
workflow = create_basic_workflow([step])
|
583 |
+
workflow.inputs = ["valid_input"]
|
584 |
+
workflow.outputs = {"output": "step1.output"}
|
585 |
+
assert not validator.validate(workflow)
|
586 |
+
assert len(validator.errors) == 1
|
587 |
+
assert validator.errors[0].error_type == ValidationErrorType.VARIABLE
|
588 |
+
|
589 |
+
def test_external_input_naming_conflicts(self):
|
590 |
+
"""Test validation of external input naming conflicts"""
|
591 |
+
validator = WorkflowValidator()
|
592 |
+
|
593 |
+
# Test conflict between external input and step output
|
594 |
+
step = create_step_with_fields("step1", [], [OutputField(name="output", description="test", type="str")])
|
595 |
+
workflow = create_basic_workflow([step])
|
596 |
+
workflow.inputs = ["step1.output"] # Conflict with step output
|
597 |
+
workflow.outputs = {"output": "step1.output"}
|
598 |
+
assert not validator.validate(workflow)
|
599 |
+
assert len(validator.errors) == 1
|
600 |
+
assert validator.errors[0].error_type == ValidationErrorType.VARIABLE
|
601 |
+
|
602 |
+
|
603 |
+
# Edge Cases
|
604 |
+
class TestEdgeCases:
|
605 |
+
def test_empty_workflow_with_inputs(self):
|
606 |
+
"""Test validation of empty workflow with inputs"""
|
607 |
+
validator = WorkflowValidator()
|
608 |
+
workflow = Workflow(inputs=["input"], outputs={}, steps={})
|
609 |
+
assert not validator.validate(workflow)
|
610 |
+
assert len(validator.errors) == 1
|
611 |
+
assert validator.errors[0].error_type == ValidationErrorType.GENERAL
|
612 |
+
|
613 |
+
def test_workflow_with_empty_outputs(self):
|
614 |
+
"""Test validation of workflow with empty outputs"""
|
615 |
+
validator = WorkflowValidator()
|
616 |
+
workflow = create_valid_workflow()
|
617 |
+
workflow.outputs = {} # Empty output reference
|
618 |
+
assert not validator.validate(workflow)
|
619 |
+
assert len(validator.errors) == 1
|
620 |
+
assert validator.errors[0].error_type == ValidationErrorType.GENERAL
|
621 |
+
|
622 |
+
def test_workflow_with_none_outputs(self):
|
623 |
+
"""Test validation of workflow with empty outputs"""
|
624 |
+
validator = WorkflowValidator()
|
625 |
+
workflow = create_valid_workflow()
|
626 |
+
workflow.outputs = {"output": None} # Empty output reference
|
627 |
+
assert not validator.validate(workflow)
|
628 |
+
assert len(validator.errors) == 1
|
629 |
+
assert validator.errors[0].error_type == ValidationErrorType.GENERAL
|
630 |
+
|
631 |
+
def test_workflow_with_duplicate_output_names(self):
|
632 |
+
"""Test validation of workflow with duplicate output names"""
|
633 |
+
validator = WorkflowValidator()
|
634 |
+
step = create_step_with_fields(
|
635 |
+
"step1",
|
636 |
+
[],
|
637 |
+
[
|
638 |
+
OutputField(name="output", description="test", type="str"),
|
639 |
+
OutputField(name="output", description="test", type="str"),
|
640 |
+
],
|
641 |
+
)
|
642 |
+
workflow = create_basic_workflow([step])
|
643 |
+
workflow.inputs = ["input"]
|
644 |
+
workflow.outputs = {"output": "step1.output"}
|
645 |
+
assert not validator.validate(workflow)
|
646 |
+
assert len(validator.errors) == 1
|
647 |
+
assert validator.errors[0].error_type == ValidationErrorType.STEP
|