File size: 14,038 Bytes
193db9d 9756440 0bab47c 193db9d 0bab47c 193db9d d0ae1a9 193db9d 0bab47c 193db9d 9756440 193db9d 9756440 193db9d 0bab47c 193db9d 0bab47c 193db9d 02b7dec 193db9d 4f5d1cb 193db9d 9756440 193db9d 9756440 193db9d 9756440 193db9d 9756440 193db9d e00ec4e 193db9d 38e3800 193db9d 9756440 d0ae1a9 9756440 da814b0 9756440 193db9d 0bab47c 02b7dec 9756440 02b7dec d0ae1a9 0bab47c d0ae1a9 0bab47c 9756440 d0ae1a9 9756440 d0ae1a9 |
|
# %%
from copy import deepcopy
from enum import Enum
from typing import Any, Literal, Optional

import numpy as np
from pydantic import BaseModel, ConfigDict, Field, model_validator

from .configs import AVAILABLE_MODELS
"""
Core data structures for defining workflows and their components.
This module defines the primary classes used to model workflows, steps, and their
input/output fields. These data structures serve as the foundation for workflow
definition, validation, and execution throughout the workflows package.
The primary components are:
- InputField: Represents an input to a model step with name and source variable
- OutputField: Represents an output from a model step with name and type
- ModelStep: Represents a single step in a workflow with inputs and outputs
- Workflow: A collection of interconnected steps with defined inputs and outputs
All classes use Pydantic's BaseModel for validation and serialization support.
"""
# Selector used by ModelStep helpers to choose between the input and output field lists.
FieldType = Literal["input", "output"]

# Type names accepted by OutputField.type: four scalar types and their list variants.
SUPPORTED_TYPES = Literal[
    "str",
    "int",
    "float",
    "bool",
    "list[str]",
    "list[int]",
    "list[float]",
    "list[bool]",
]
"""Supported field types for input and output fields"""
class InputField(BaseModel):
    """
    Defines an input field for a model step.

    An input field specifies what data a step requires, where it comes from,
    and optional pre-processing to apply before use. Instances are frozen
    (immutable); use ``model_copy(update=...)`` to derive a modified field.

    Attributes:
        name: The name of the input field within the step's context
        description: Human-readable description of the input's purpose
        variable: Reference to the source variable (format: "{step_id}.{field_name}" or external input name)
        func: Optional function name to transform the input value before use
    """

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = ConfigDict(frozen=True)

    name: str
    description: str
    variable: str
    # function to call on the input before passing it to the model
    func: str | None = None
class OutputField(BaseModel):
    """
    Defines an output field produced by a model step.

    An output field specifies a value that the step will produce, including
    its data type and optional post-processing. Instances are frozen
    (immutable); use ``model_copy(update=...)`` to derive a modified field.

    Attributes:
        name: The name of the output field within the step's context
        type: The data type of the output (one of SUPPORTED_TYPES); defaults to "str"
        description: Human-readable description of the output's purpose
        func: Optional function name to transform the raw output value
    """

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = ConfigDict(frozen=True)

    name: str
    type: SUPPORTED_TYPES = Field(default="str")
    description: str
    # function to call on the output string from the model
    func: str | None = None
class CallType(str, Enum):
    """Kind of operation a model step performs.

    Members are also plain strings, so they compare equal to their values.
    """

    LLM = "llm"
    SEARCH = "search"
    PYTHON_FUNC = "python_func"
class ModelStep(BaseModel):
    """
    Represents a single step in a workflow.

    A model step encapsulates the details of a specific operation within a workflow,
    including what model to use, what inputs it requires, and what outputs it produces.

    All ``update*``, ``add_field`` and ``delete_field`` helpers return a new
    ModelStep and leave the instance they are called on untouched.

    Attributes:
        id: Unique identifier for this step within a workflow
        name: Name of the step
        model: The model to use for this step (e.g., "gpt-4")
        provider: The provider of the model (e.g., "openai")
        call_type: The type of operation (e.g., "llm", "search")
        temperature: Optional temperature value; defaults to None
        system_prompt: Instructions for the model
        input_fields: List of input fields required by this step
        output_fields: List of output fields produced by this step
    """

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = ConfigDict(use_enum_values=True)

    id: str
    name: str
    model: str
    provider: str
    call_type: CallType = CallType.LLM
    # TODO: Validate that this is not None for call_type = llm
    temperature: Optional[float] = None
    system_prompt: str
    input_fields: list[InputField]
    output_fields: list[OutputField]

    def fields(self, field_type: FieldType) -> list[InputField | OutputField]:
        """Return the input fields for "input", otherwise the output fields."""
        return self.input_fields if field_type == "input" else self.output_fields

    def get_full_model_name(self) -> str:
        """Return the model identifier in "{provider}/{model}" form."""
        return f"{self.provider}/{self.model}"

    def get_produced_variables(self) -> list[str]:
        """Return "{step_id}.{field_name}" for every output field with a non-empty name."""
        return [f"{self.id}.{field.name}" for field in self.output_fields if field.name]

    def update(self, update: dict[str, Any]) -> "ModelStep":
        """Returns a new copy with the updated properties."""
        return self.model_copy(update=update)

    def update_property(self, field: str, value: Any) -> "ModelStep":
        """Update the `field` key of the model step with `value`."""
        return self.update({field: value})

    def update_field(self, field_type: FieldType, index: int, key: str, value: str) -> "ModelStep":
        """Update a specific field of an input or output field at the given index.

        Args:
            field_type: Which list to edit ('input' or 'output').
            index: Position of the field to edit. An index at or beyond the end
                of the list is ignored and an unchanged copy is returned.
            key: Attribute of the field to replace (e.g. "name").
            value: New value for that attribute.

        Returns:
            A new ModelStep with the edited field; this instance is not modified.

        Raises:
            ValueError: If ``field_type`` is neither 'input' nor 'output'.
        """
        if field_type == "input":
            attr = "input_fields"
        elif field_type == "output":
            attr = "output_fields"
        else:
            raise ValueError(f"Invalid field type: {field_type}")

        # Work on a copy so the list owned by this instance is never mutated.
        fields = deepcopy(getattr(self, attr))
        if index < len(fields):
            fields[index] = fields[index].model_copy(update={key: value})
        return self.model_copy(update={attr: fields})

    @staticmethod
    def create_new_field(field_type: FieldType, input_var: str | None = None) -> InputField | OutputField:
        """Create a blank field of the requested kind.

        Args:
            field_type: 'input' or 'output'.
            input_var: Source variable for a new input field; ignored for outputs.
                When omitted, the variable is left as an empty string.

        Raises:
            ValueError: If ``field_type`` is neither 'input' nor 'output'.
        """
        if field_type == "input":
            # InputField.variable is a required str, so None must not be forwarded.
            return InputField(name="", description="", variable=input_var or "")
        elif field_type == "output":
            return OutputField(name="", description="")
        else:
            raise ValueError(f"Invalid field type: {field_type}")

    def add_field(self, field_type: FieldType, index: int = -1, input_var: str | None = None) -> "ModelStep":
        """Add a new blank field to the step.

        Args:
            field_type: Type of field to add ('input' or 'output').
            index: The new field is inserted directly after this position;
                -1 appends it to the end.
            input_var: Source variable for a new input field (unused for outputs).

        Returns:
            A new ModelStep with the updated fields.

        Raises:
            ValueError: If ``field_type`` is neither 'input' nor 'output'.
        """
        attr = "input_fields" if field_type == "input" else "output_fields"
        new_field = ModelStep.create_new_field(field_type, input_var)
        fields = deepcopy(getattr(self, attr))
        if index == -1:
            fields.append(new_field)
        else:
            fields.insert(index + 1, new_field)
        return self.model_copy(update={attr: fields})

    def delete_field(self, field_type: FieldType, index: int) -> "ModelStep":
        """
        Delete an input or output field from the step.

        Args:
            field_type: Type of field to delete ('input' or 'output').
            index: Index of the field to delete. [-1 to delete the last field]

        Returns:
            A new ModelStep with the updated fields.

        Raises:
            IndexError: If ``index`` is out of range for the selected list.
        """
        attr = "input_fields" if field_type == "input" else "output_fields"
        fields = deepcopy(getattr(self, attr))
        fields.pop(index)
        return self.model_copy(update={attr: fields})
class Workflow(BaseModel):
    """
    Represents a complete workflow composed of interconnected steps.

    A workflow defines a directed acyclic graph of model steps, where outputs
    from earlier steps can be used as inputs to later steps.

    Attributes:
        inputs: List of input variables required by the workflow
        outputs: Mapping from workflow output name to the "{step_id}.{field_name}"
            variable that provides it, or None when no variable is assigned
        steps: Dictionary mapping step IDs to ModelStep instances
            (a list of steps is also accepted at construction time)

    Step-produced variables use the format "{step_id}.{field_name}"
    to uniquely identify them within the workflow.
    """

    # variables of form {node}.{field}
    inputs: list[str] = Field(default_factory=list)
    # output name -> variable of form {node}.{field} (or None if unassigned)
    outputs: dict[str, str | None] = Field(default_factory=dict)
    steps: dict[str, ModelStep] = Field(default_factory=dict)

    def model_dump(self, *args, **kwargs):
        """Dump the model, serializing ``steps`` as a list of step dicts rather than a mapping."""
        data = super().model_dump(*args, **kwargs)
        if "steps" in data:
            data["steps"] = list(data["steps"].values())
        return data

    @model_validator(mode="before")
    @classmethod
    def dictify_steps(cls, data: Any) -> Any:
        """Convert a list of steps into the ``{step_id: step}`` mapping the model stores.

        Steps may be ModelStep instances or plain dicts carrying an "id" key.
        Input that is not a dict, or whose "steps" is not a list, is passed
        through unchanged.

        Raises:
            ValueError: If two steps share the same ID.
        """
        if not isinstance(data, dict) or not isinstance(data.get("steps"), list):
            return data

        steps_dict = {}
        for step in data["steps"]:
            step_id = step.id if isinstance(step, ModelStep) else step["id"]
            if step_id in steps_dict:
                raise ValueError(f"Duplicate step ID: {step_id}")
            steps_dict[step_id] = step
        # Return a new dict so the caller's input mapping is not modified.
        return {**data, "steps": steps_dict}

    def get_step_variables(self, step_id: str) -> list[str]:
        """Get all variables from a specific step.

        Raises:
            KeyError: If ``step_id`` is not a step of this workflow.
        """
        return self.steps[step_id].get_produced_variables()

    def get_available_variables(self) -> list[str]:
        """Get the workflow inputs plus all output variables from all steps.

        Duplicates are removed; the order is inputs first, then step outputs
        in step order.
        """
        # A dict is used as an ordered set so the result is deterministic.
        ordered: dict[str, None] = dict.fromkeys(self.inputs)
        for step in self.steps.values():
            ordered.update(dict.fromkeys(step.get_produced_variables()))
        return list(ordered)

    def get_model_selections(self) -> dict[str, str]:
        """Get all model selections for all steps."""
        return {step_id: step.get_full_model_name() for step_id, step in self.steps.items()}

    def get_output_model_selections(self) -> dict[str, str | None]:
        """Map each workflow output name to the ID of the step providing it.

        The step ID is the part of the assigned variable before the first ".";
        outputs with no assigned variable map to None.
        """
        return {
            output_var: target_var.split(".")[0] if target_var else None
            for output_var, target_var in self.outputs.items()
        }

    # Step update method
    def add_step(self, step: ModelStep) -> "Workflow":
        """Return a copy of the workflow with ``step`` added (or replaced) under its ID."""
        steps = self.steps | {step.id: step}
        return self.model_copy(update={"steps": steps})

    def remove_step(self, step_id: str) -> "Workflow":
        """Return a copy of the workflow without the given step.

        Outputs that referenced variables of the removed step are reset to None
        on the returned copy. This instance is not modified.

        Raises:
            KeyError: If ``step_id`` is not a step of this workflow.
        """
        steps = dict(self.steps)
        del steps[step_id]
        workflow = self.model_copy(update={"steps": steps})
        workflow.refresh_output_variables()
        return workflow

    def update_step(self, step: ModelStep) -> "Workflow":
        """Return a copy of the workflow with ``step`` stored under its ID.

        Outputs whose variables are no longer produced are reset to None on the
        returned copy. This instance is not modified.
        """
        steps = self.steps | {step.id: step}
        workflow = self.model_copy(update={"steps": steps})
        workflow.refresh_output_variables()
        return workflow

    # Output variables
    def refresh_output_variables(self) -> "Workflow":
        """Reset to None every output whose variable is no longer available.

        Updates ``self.outputs`` in place and returns ``self``.
        """
        produced_variables = set(self.get_available_variables())
        self.outputs = {k: (v if v in produced_variables else None) for k, v in self.outputs.items()}
        return self
class BuzzerMethod(str, Enum):
    """How a buzzer combines its confidence and probability checks.

    Members are also plain strings, so they compare equal to their values.
    """

    AND = "AND"
    OR = "OR"
class Buzzer(BaseModel):
    """Configuration for when to buzz in a tossup question.

    Instances are frozen; use :meth:`update` to derive a modified buzzer.

    Attributes:
        method: Logic used to combine the confidence and probability checks.
        confidence_threshold: Minimum confidence (between 0.0 and 1.0) to trigger a buzz.
        prob_threshold: Optional probability threshold. When None, only the
            confidence check is applied and ``method`` must be AND.
    """

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = ConfigDict(use_enum_values=True, frozen=True)

    method: BuzzerMethod = BuzzerMethod.AND  # Logic to combine thresholds
    confidence_threshold: float = Field(default=0.8, ge=0.0, le=1.0)  # Minimum confidence to trigger a buzz
    # Optional threshold, compared in probability space (log probs are exponentiated in run()).
    prob_threshold: float | None = None

    def update(self, **kwargs) -> "Buzzer":
        """Return a copy of the buzzer with the given attributes replaced."""
        return self.model_copy(update=kwargs)

    def run(self, confidence: float, prob: float | None = None, logprob: float | None = None) -> bool:
        """Decide whether to buzz.

        Args:
            confidence: Confidence value compared against ``confidence_threshold``.
            prob: Probability compared against ``prob_threshold``.
            logprob: Log probability; converted with ``exp`` and used in place of ``prob``.

        Returns:
            True when the configured combination of threshold checks passes.

        Raises:
            ValueError: If both ``prob`` and ``logprob`` are given, if neither is
                given while ``prob_threshold`` is set, or if ``method`` is unknown.
        """
        if logprob is not None and prob is not None:
            raise ValueError("Cannot provide both logprob and prob")

        if self.prob_threshold is None:
            return confidence >= self.confidence_threshold

        if logprob is None and prob is None:
            raise ValueError("Must provide either logprob or prob if prob_threshold is not None")

        # Explicit None check: a legitimate prob of 0.0 is falsy and must not
        # fall through to exponentiating a missing logprob.
        if prob is None:
            prob = float(np.exp(logprob))

        if self.method == BuzzerMethod.AND:
            return confidence >= self.confidence_threshold and prob >= self.prob_threshold
        elif self.method == BuzzerMethod.OR:
            return confidence >= self.confidence_threshold or prob >= self.prob_threshold
        else:
            raise ValueError(f"Invalid buzzer method: {self.method}")

    @model_validator(mode="after")
    def validate_method_with_log_prob(self) -> "Buzzer":
        """Validate that if prob_threshold is None, method must be 'and'."""
        if self.prob_threshold is None and self.method != BuzzerMethod.AND:
            raise ValueError("If prob_threshold is None, method must be 'and'")
        return self
class TossupWorkflow(Workflow):
    """Workflow specialized for tossup questions with buzzing capability."""

    buzzer: Buzzer

    def get_answer_model(self, answer_var: str | None = None) -> str | None:
        """Return the "{provider}/{model}" name of the step producing the answer.

        Args:
            answer_var: Variable of form "{step_id}.{field_name}". When omitted,
                the variable assigned to the workflow's "answer" output is used.

        Returns:
            The full model name, or None when no answer variable is available.

        Raises:
            KeyError: If the variable refers to a step that is not in the workflow.
        """
        # .get() so a workflow without an "answer" output yields None rather than KeyError.
        answer_var = answer_var or self.outputs.get("answer")
        if not answer_var:
            return None
        step_id = answer_var.split(".")[0]
        return self.steps[step_id].get_full_model_name()

    def is_token_probs_supported(self, answer_var: str | None = None) -> bool:
        """Report whether the answer model's AVAILABLE_MODELS entry enables "logprobs".

        Returns True when no answer model can be determined.
        """
        model_name = self.get_answer_model(answer_var)
        if model_name is None:
            return True
        return AVAILABLE_MODELS[model_name].get("logprobs", False)

    def update_buzzer(self, buzzer: Buzzer) -> "TossupWorkflow":
        """Return a copy of the workflow using the given buzzer."""
        return self.model_copy(update={"buzzer": buzzer})

    def refresh_buzzer(self) -> "TossupWorkflow":
        """Drop the probability threshold when the answer model lacks token probabilities.

        Returns a copy whose buzzer has ``prob_threshold=None`` and method "AND"
        in that case; otherwise returns this workflow unchanged.
        """
        if not self.is_token_probs_supported():
            return self.update_buzzer(self.buzzer.update(prob_threshold=None, method="AND"))
        return self
|