# %%
from copy import deepcopy
from enum import Enum
from typing import Any, Literal, Optional
import numpy as np
from pydantic import BaseModel, Field, model_validator
from .configs import AVAILABLE_MODELS
"""
Core data structures for defining workflows and their components.
This module defines the primary classes used to model workflows, steps, and their
input/output fields. These data structures serve as the foundation for workflow
definition, validation, and execution throughout the workflows package.
The primary components are:
- InputField: Represents an input to a model step with name and source variable
- OutputField: Represents an output from a model step with name and type
- ModelStep: Represents a single step in a workflow with inputs and outputs
- Workflow: A collection of interconnected steps with defined inputs and outputs
All classes use Pydantic's BaseModel for validation and serialization support.
"""
FieldType = Literal["input", "output"]
SUPPORTED_TYPES = Literal["str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"]
"""Supported field types for input and output fields"""


class InputField(BaseModel):
    """
    Defines an input field for a model step.

    An input field specifies what data a step requires, where it comes from,
    and optional pre-processing to apply before use.

    Attributes:
        name: The name of the input field within the step's context
        description: Human-readable description of the input's purpose
        variable: Reference to the source variable (format: "{step_id}.{field_name}" or external input name)
        func: Optional function name to transform the input value before use
    """

    name: str
    description: str
    variable: str

    # function to call on the input before passing it to the model
    func: str | None = None

    class Config:
        frozen = True


class OutputField(BaseModel):
    """
    Defines an output field produced by a model step.

    An output field specifies a value that the step will produce, including
    its data type and optional post-processing.

    Attributes:
        name: The name of the output field within the step's context
        description: Human-readable description of the output's purpose
        type: The data type of the output (one of SUPPORTED_TYPES)
        func: Optional function name to transform the raw output value
    """

    name: str
    type: SUPPORTED_TYPES = Field(default="str")
    description: str

    # function to call on the output string from the model
    func: str | None = None

    class Config:
        frozen = True
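
# Illustrative sketch (not executed): constructing a pair of field definitions.
# The names, descriptions, and the "upper" post-processing function are hypothetical
# placeholders; `func` is only stored here and resolved elsewhere in the package.
#
#   question_input = InputField(
#       name="question",
#       description="The question text revealed so far",
#       variable="question_text",
#   )
#   answer_output = OutputField(
#       name="answer",
#       description="The guessed answer",
#       type="str",
#       func="upper",
#   )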


class CallType(str, Enum):
    LLM = "llm"
    SEARCH = "search"
    PYTHON_FUNC = "python_func"


class ModelStep(BaseModel):
    """
    Represents a single step in a workflow.

    A model step encapsulates the details of a specific operation within a workflow,
    including what model to use, what inputs it requires, and what outputs it produces.

    Attributes:
        id: Unique identifier for this step within a workflow
        name: Human-readable name of the step
        model: The model to use for this step (e.g., "gpt-4")
        provider: The provider of the model (e.g., "openai")
        call_type: The type of operation (e.g., "llm", "search")
        temperature: Optional sampling temperature for LLM calls
        system_prompt: Instructions for the model
        input_fields: List of input fields required by this step
        output_fields: List of output fields produced by this step
    """

    id: str
    name: str
    model: str
    provider: str
    call_type: CallType = CallType.LLM

    # TODO: Validate that this is not None for call_type = llm
    temperature: Optional[float] = None

    system_prompt: str
    input_fields: list[InputField]
    output_fields: list[OutputField]

    class Config:
        use_enum_values = True

    def fields(self, field_type: FieldType) -> list[InputField | OutputField]:
        return self.input_fields if field_type == "input" else self.output_fields

    def get_full_model_name(self) -> str:
        return f"{self.provider}/{self.model}"

    def get_produced_variables(self) -> list[str]:
        return [f"{self.id}.{field.name}" for field in self.output_fields if field.name]

    def update(self, update: dict[str, Any]) -> "ModelStep":
        """Returns a new copy with the updated properties."""
        return self.model_copy(update=update)

    def update_property(self, field: str, value: Any) -> "ModelStep":
        """Update the `field` key of the model step with `value`."""
        return self.update({field: value})

    def update_field(self, field_type: FieldType, index: int, key: str, value: str) -> "ModelStep":
        """Update a single attribute (`key`) of the input or output field at the given index, returning a new copy."""
        if field_type == "input":
            fields = deepcopy(self.input_fields)
            update_key = "input_fields"
        elif field_type == "output":
            fields = deepcopy(self.output_fields)
            update_key = "output_fields"
        else:
            raise ValueError(f"Invalid field type: {field_type}")
        # Indices past the end of the list are silently ignored.
        if index < len(fields):
            fields[index] = fields[index].model_copy(update={key: value})
        return self.model_copy(update={update_key: fields})

    @staticmethod
    def create_new_field(field_type: FieldType, input_var: str | None = None) -> InputField | OutputField:
        if field_type == "input":
            return InputField(name="", description="", variable=input_var)
        elif field_type == "output":
            return OutputField(name="", description="")
        else:
            raise ValueError(f"Invalid field type: {field_type}")

    def add_field(self, field_type: FieldType, index: int = -1, input_var: str | None = None) -> "ModelStep":
        """Add a new field to the step's inputs or outputs.

        Args:
            field_type: Type of field to add ('input' or 'output').
            index: Index after which to insert the new field (-1 to append).
            input_var: Source variable to pre-fill when adding an input field.

        Returns:
            A new ModelStep with the updated fields.
        """
        if field_type == "input":
            fields = deepcopy(self.input_fields)
            new_field = ModelStep.create_new_field(field_type, input_var)
            if index == -1:
                fields.append(new_field)
            else:
                fields.insert(index + 1, new_field)
            return self.model_copy(update={"input_fields": fields})
        else:
            fields = deepcopy(self.output_fields)
            new_field = ModelStep.create_new_field(field_type)
            if index == -1:
                fields.append(new_field)
            else:
                fields.insert(index + 1, new_field)
            return self.model_copy(update={"output_fields": fields})

    def delete_field(self, field_type: FieldType, index: int) -> "ModelStep":
        """
        Delete an input or output field at the given index.

        Args:
            field_type: Type of field to delete ('input' or 'output').
            index: Index of the field to delete (-1 to delete the last field).

        Returns:
            A new ModelStep with the updated fields.
        """
        fields = self.input_fields if field_type == "input" else self.output_fields
        fields = deepcopy(fields)
        fields.pop(index)
        return self.model_copy(update={"input_fields": fields} if field_type == "input" else {"output_fields": fields})
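
# Illustrative sketch (not executed): a single LLM step and the immutable-update
# pattern its helpers follow. The model/provider identifiers are hypothetical and
# must match whatever the rest of the package recognizes.
#
#   step = ModelStep(
#       id="guess",
#       name="Answer generator",
#       model="gpt-4o-mini",
#       provider="OpenAI",
#       call_type=CallType.LLM,
#       temperature=0.2,
#       system_prompt="Answer the quizbowl question.",
#       input_fields=[InputField(name="question", description="Question text", variable="question_text")],
#       output_fields=[OutputField(name="answer", description="Guessed answer")],
#   )
#   step.get_full_model_name()     # -> "OpenAI/gpt-4o-mini"
#   step.get_produced_variables()  # -> ["guess.answer"]
#   step2 = step.update_property("temperature", 0.0)  # `step` itself is unchanged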


class Workflow(BaseModel):
    """
    Represents a complete workflow composed of interconnected steps.

    A workflow defines a directed acyclic graph of model steps, where outputs
    from earlier steps can be used as inputs to later steps.

    Attributes:
        inputs: List of input variables required by the workflow
        outputs: Mapping from workflow output names to the producing variable
            ("{step_id}.{field_name}"), or None if the output is not yet assigned
        steps: Dictionary mapping step IDs to ModelStep instances

    Variables are referenced in the format "{step_id}.{field_name}"
    to uniquely identify them within the workflow.
    """

    # variables of form {node}.{field}
    inputs: list[str] = Field(default_factory=list)

    # variables of form {node}.{field}
    outputs: dict[str, str | None] = Field(default_factory=dict)

    steps: dict[str, ModelStep] = Field(default_factory=dict)

    def model_dump(self, *args, **kwargs):
        data = super().model_dump(*args, **kwargs)
        if "steps" in data:
            data["steps"] = list(data["steps"].values())
        return data

    @model_validator(mode="before")
    def dictify_steps(cls, data):
        if "steps" in data and isinstance(data["steps"], list):
            steps_dict = {}
            for step in data["steps"]:
                if isinstance(step, ModelStep):
                    step_id = step.id
                else:
                    step_id = step["id"]
                if step_id in steps_dict:
                    raise ValueError(f"Duplicate step ID: {step_id}")
                steps_dict[step_id] = step
            data["steps"] = steps_dict
        return data

    def get_step_variables(self, step_id: str) -> list[str]:
        """Get all variables from a specific step."""
        step = self.steps[step_id]
        variables = []
        for output in step.output_fields:
            if output.name == "":
                continue
            output_var = f"{step.id}.{output.name}"
            variables.append(output_var)
        return variables

    def get_available_variables(self) -> list[str]:
        """Get all output variables from all steps."""
        variables = set(self.inputs)
        for step in self.steps.values():
            variables.update(self.get_step_variables(step.id))
        return list(variables)

    def get_model_selections(self) -> dict[str, str]:
        """Get all model selections for all steps."""
        return {step_id: step.get_full_model_name() for step_id, step in self.steps.items()}

    def get_output_model_selections(self) -> dict[str, str | None]:
        """Map each workflow output variable to the ID of the step that produces it (None if unassigned)."""
        return {
            output_var: target_var.split(".")[0] if target_var else None
            for output_var, target_var in self.outputs.items()
        }

    # Step update methods
    def add_step(self, step: ModelStep) -> "Workflow":
        """Add a step to the workflow."""
        steps = self.steps | {step.id: step}
        return self.model_copy(update={"steps": steps})

    def remove_step(self, step_id: str) -> "Workflow":
        """Remove a step from the workflow."""
        steps = dict(self.steps)
        steps.pop(step_id)
        workflow = self.model_copy(update={"steps": steps})
        workflow.refresh_output_variables()
        return workflow

    def update_step(self, step: ModelStep) -> "Workflow":
        """Update a step in the workflow."""
        steps = self.steps | {step.id: step}
        workflow = self.model_copy(update={"steps": steps})
        workflow.refresh_output_variables()
        return workflow

    # Output variables
    def refresh_output_variables(self) -> "Workflow":
        """Refresh the output variables, unassigning any that no longer resolve to a produced variable."""
        produced_variables = self.get_available_variables()
        self.outputs = {k: (v if v in produced_variables else None) for k, v in self.outputs.items()}
        return self
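
# Illustrative sketch (not executed): wiring two steps into a workflow. `steps` may
# be passed as a list; the "before" validator converts it into the id-keyed dict.
# The objects `guess_step` and `judge_step` are hypothetical ModelSteps with ids
# "guess" and "judge".
#
#   workflow = Workflow(
#       inputs=["question_text"],
#       outputs={"answer": "judge.final_answer"},
#       steps=[guess_step, judge_step],
#   )
#   workflow.get_available_variables()  # inputs plus every produced "{step_id}.{field_name}"
#   workflow = workflow.remove_step("judge")
#   workflow.outputs                    # -> {"answer": None} after the refresh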


class BuzzerMethod(str, Enum):
    AND = "AND"
    OR = "OR"


class Buzzer(BaseModel):
    """Configuration for when to buzz in a tossup question."""

    method: BuzzerMethod = BuzzerMethod.AND  # Logic to combine thresholds
    confidence_threshold: float = Field(default=0.8, ge=0.0, le=1.0)  # Minimum confidence to trigger a buzz
    prob_threshold: float | None = None  # Optional probability threshold (compared against prob = exp(logprob))

    class Config:
        use_enum_values = True
        frozen = True

    def update(self, **kwargs) -> "Buzzer":
        """Return a copy of the buzzer updated with the given kwargs."""
        return self.model_copy(update=kwargs)

    def run(self, confidence: float, prob: float | None = None, logprob: float | None = None) -> bool:
        """Run the buzzer logic and decide whether to buzz."""
        if logprob is not None and prob is not None:
            raise ValueError("Cannot provide both logprob and prob")
        if self.prob_threshold is None:
            return confidence >= self.confidence_threshold
        if logprob is None and prob is None:
            raise ValueError("Must provide either logprob or prob if prob_threshold is not None")
        if prob is None:
            prob = float(np.exp(logprob))
        if self.method == BuzzerMethod.AND:
            return confidence >= self.confidence_threshold and prob >= self.prob_threshold
        elif self.method == BuzzerMethod.OR:
            return confidence >= self.confidence_threshold or prob >= self.prob_threshold
        else:
            raise ValueError(f"Invalid buzzer method: {self.method}")

    @model_validator(mode="after")
    def validate_method_with_log_prob(cls, data):
        """Validate that if prob_threshold is None, the method must be AND."""
        if data.prob_threshold is None and data.method != BuzzerMethod.AND:
            raise ValueError("If prob_threshold is None, method must be 'AND'")
        return data
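
# Illustrative sketch (not executed): how the two thresholds combine. With the AND
# method, both the self-reported confidence and the answer probability must clear
# their thresholds; a logprob of -0.1 corresponds to exp(-0.1) ≈ 0.905.
#
#   buzzer = Buzzer(method=BuzzerMethod.AND, confidence_threshold=0.8, prob_threshold=0.9)
#   buzzer.run(confidence=0.85, logprob=-0.1)   # True:  0.85 >= 0.8 and 0.905 >= 0.9
#   buzzer.run(confidence=0.85, prob=0.5)       # False: probability below the threshold
#   Buzzer(prob_threshold=None, method=BuzzerMethod.OR)  # raises: OR requires a prob_threshold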


class TossupWorkflow(Workflow):
    """Workflow specialized for tossup questions with buzzing capability."""

    buzzer: Buzzer

    def get_answer_model(self, answer_var: str | None = None) -> str | None:
        answer_var = answer_var or self.outputs["answer"]
        if answer_var is None:
            return None
        step_id = answer_var.split(".")[0]
        return self.steps[step_id].get_full_model_name()

    def is_token_probs_supported(self, answer_var: str | None = None) -> bool:
        model_name = self.get_answer_model(answer_var)
        if model_name is None:
            return True
        return AVAILABLE_MODELS[model_name].get("logprobs", False)

    def update_buzzer(self, buzzer: Buzzer) -> "TossupWorkflow":
        """Update the buzzer."""
        return self.model_copy(update={"buzzer": buzzer})

    def refresh_buzzer(self) -> "TossupWorkflow":
        """Drop the probability threshold if the answer model does not expose token logprobs."""
        if not self.is_token_probs_supported():
            return self.update_buzzer(self.buzzer.update(prob_threshold=None, method="AND"))
        return self
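
# Illustrative sketch (not executed): keeping the buzzer consistent with the answer
# model's capabilities. It assumes AVAILABLE_MODELS has an entry keyed by
# "{provider}/{model}" whose "logprobs" flag is False for the hypothetical model of
# the step that produces outputs["answer"]; `guess_step` is a hypothetical ModelStep.
#
#   tossup = TossupWorkflow(
#       inputs=["question_text"],
#       outputs={"answer": "guess.answer"},
#       steps=[guess_step],
#       buzzer=Buzzer(confidence_threshold=0.8, prob_threshold=0.9),
#   )
#   tossup = tossup.refresh_buzzer()
#   tossup.buzzer.prob_threshold  # -> None when the answer model cannot return logprobs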