File size: 14,038 Bytes
193db9d
9756440
0bab47c
193db9d
 
0bab47c
193db9d
 
d0ae1a9
 
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bab47c
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9756440
 
 
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9756440
 
 
193db9d
0bab47c
 
 
 
 
 
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bab47c
193db9d
 
 
 
 
 
 
 
02b7dec
 
 
193db9d
 
 
4f5d1cb
 
193db9d
 
 
 
 
9756440
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9756440
 
 
 
 
 
 
 
 
 
193db9d
 
 
 
 
 
 
 
 
 
 
 
9756440
 
193db9d
9756440
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e00ec4e
 
193db9d
 
 
 
 
 
 
38e3800
 
 
 
 
 
 
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9756440
 
 
 
d0ae1a9
 
 
 
 
 
 
9756440
 
 
 
 
 
 
 
 
 
da814b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9756440
193db9d
0bab47c
 
 
 
 
 
 
 
 
 
 
 
02b7dec
 
9756440
02b7dec
d0ae1a9
 
 
 
0bab47c
 
 
 
 
 
d0ae1a9
 
 
0bab47c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9756440
d0ae1a9
 
 
 
 
 
 
 
 
 
 
 
 
9756440
 
 
d0ae1a9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
# %%
from copy import deepcopy
from enum import Enum
from typing import Any, Literal, Optional

import numpy as np
from pydantic import BaseModel, Field, model_validator

from .configs import AVAILABLE_MODELS

"""
Core data structures for defining workflows and their components.

This module defines the primary classes used to model workflows, steps, and their
input/output fields. These data structures serve as the foundation for workflow
definition, validation, and execution throughout the workflows package.

The primary components are:
- InputField: Represents an input to a model step with name and source variable
- OutputField: Represents an output from a model step with name and type
- ModelStep: Represents a single step in a workflow with inputs and outputs
- Workflow: A collection of interconnected steps with defined inputs and outputs

All classes use Pydantic's BaseModel for validation and serialization support.
"""
FieldType = Literal["input", "output"]


SUPPORTED_TYPES = Literal["str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"]
"""Supported field types for input and output fields"""


class InputField(BaseModel):
    """
    Defines an input field for a model step.

    An input field specifies what data a step requires, where it comes from,
    and optional pre-processing to apply before use.

    Attributes:
        name: The name of the input field within the step's context
        description: Human-readable description of the input's purpose
        variable: Reference to the source variable (format: "{step_id}.{field_name}" or external input name)
        func: Optional function name to transform the input value before use
    """

    name: str
    description: str
    variable: str

    # function to call on the input before passing it to the model
    func: str | None = None

    class Config:
        frozen = True


class OutputField(BaseModel):
    """
    Defines an output field produced by a model step.

    An output field specifies a value that the step will produce, including
    its data type and optional post-processing.

    Attributes:
        name: The name of the output field within the step's context
        description: Human-readable description of the output's purpose
        type: The data type of the output (one of SUPPORTED_TYPES)
        func: Optional function name to transform the raw output value
    """

    name: str
    type: SUPPORTED_TYPES = Field(default="str")
    description: str

    # function to call on the output string from the model
    func: str | None = None

    class Config:
        frozen = True


class CallType(str, Enum):
    LLM = "llm"
    SEARCH = "search"
    PYTHON_FUNC = "python_func"


class ModelStep(BaseModel):
    """
    Represents a single step in a workflow.

    A model step encapsulates the details of a specific operation within a workflow,
    including what model to use, what inputs it requires, and what outputs it produces.

    Attributes:
        id: Unique identifier for this step within a workflow
        model: The model to use for this step (e.g., "gpt-4")
        provider: The provider of the model (e.g., "openai")
        call_type: The type of operation (e.g., "llm", "search")
        system_prompt: Instructions for the model
        input_fields: List of input fields required by this step
        output_fields: List of output fields produced by this step
    """

    id: str
    name: str
    model: str
    provider: str
    call_type: CallType = CallType.LLM

    # TODO: Validate that this is not None for call_type = llm
    temperature: Optional[float] = None

    system_prompt: str
    input_fields: list[InputField]
    output_fields: list[OutputField]

    class Config:
        use_enum_values = True

    def fields(self, field_type: FieldType) -> list[InputField | OutputField]:
        return self.input_fields if field_type == "input" else self.output_fields

    def get_full_model_name(self) -> str:
        return f"{self.provider}/{self.model}"

    def get_produced_variables(self) -> list[str]:
        return [f"{self.id}.{field.name}" for field in self.output_fields if field.name]

    def update(self, update: dict[str, Any]) -> "ModelStep":
        """Returns a new copy with the updated properties."""
        return self.model_copy(update=update)

    def update_property(self, field: str, value: Any) -> "ModelStep":
        "Update the `field` key of the model step with `value`."
        return self.update({field: value})

    def update_field(self, field_type: FieldType, index: int, key: str, value: str) -> "ModelStep":
        """Update a specific field of an input or output field at the given index."""
        if field_type == "input":
            fields = self.input_fields
        elif field_type == "output":
            fields = self.output_fields
        else:
            raise ValueError(f"Invalid field type: {field_type}")

        if index < len(fields):
            fields[index] = fields[index].model_copy(update={key: value})
        return self.model_copy()

    @staticmethod
    def create_new_field(field_type: FieldType, input_var: str | None = None) -> InputField | OutputField:
        if field_type == "input":
            return InputField(name="", description="", variable=input_var)
        elif field_type == "output":
            return OutputField(name="", description="")
        else:
            raise ValueError(f"Invalid field type: {field_type}")

    def add_field(self, field_type: FieldType, index: int = -1, input_var: str | None = None) -> "ModelStep":
        """Add a new field to the state and update visibility.

        Args:
            field_type: Type of field to add ('input' or 'output').
            index: Position to insert the new field (-1 to append).
        Returns:
            A new ModelStep with the updated fields.
        """
        if field_type == "input":
            fields = deepcopy(self.input_fields)
            new_field = ModelStep.create_new_field(field_type, input_var)
            fields.insert(index + 1, new_field) if index != -1 else fields.append(new_field)
            return self.model_copy(update={"input_fields": fields})
        else:
            fields = deepcopy(self.output_fields)
            new_field = ModelStep.create_new_field(field_type)
            fields.insert(index + 1, new_field) if index != -1 else fields.append(new_field)
            return self.model_copy(update={"output_fields": fields})

    def delete_field(self, field_type: FieldType, index: int) -> "ModelStep":
        """
        Delete an input or output field from the state and update visibility.

        Args:
            field_type: Type of field to delete ('input' or 'output').
            index: Index of the field to delete. [-1 to delete the last field]

        Returns:
            A new ModelStep with the updated fields.
        """
        fields = self.input_fields if field_type == "input" else self.output_fields
        fields = deepcopy(fields)
        fields.pop(index)
        return self.model_copy(update={"input_fields": fields} if field_type == "input" else {"output_fields": fields})


class Workflow(BaseModel):
    """
    Represents a complete workflow composed of interconnected steps.

    A workflow defines a directed acyclic graph of model steps, where outputs
    from earlier steps can be used as inputs to later steps.

    Attributes:
        inputs: List of input variables required by the workflow
        outputs: List of output variables produced by the workflow
        steps: Dictionary mapping step IDs to ModelStep instances

    The inputs and outputs lists use the format "{step_id}.{field_name}"
    to uniquely identify variables within the workflow.
    """

    # variables of form {node}.{field}
    inputs: list[str] = Field(default_factory=list)

    # variables of form {node}.{field}
    outputs: dict[str, str | None] = Field(default_factory=dict)
    steps: dict[str, ModelStep] = Field(default_factory=dict)

    def model_dump(self, *args, **kwargs):
        data = super().model_dump(*args, **kwargs)
        if "steps" in data:
            data["steps"] = list(data["steps"].values())
        return data

    @model_validator(mode="before")
    def dictify_steps(cls, data):
        if "steps" in data and isinstance(data["steps"], list):
            steps_dict = {}
            for step in data["steps"]:
                if isinstance(step, ModelStep):
                    step_id = step.id
                else:
                    step_id = step["id"]
                if step_id in steps_dict:
                    raise ValueError(f"Duplicate step ID: {step_id}")
                steps_dict[step_id] = step
            data["steps"] = steps_dict
        return data

    def get_step_variables(self, step_id: str) -> list[str]:
        """Get all variables from a specific step."""
        step = self.steps[step_id]
        variables = []
        for output in step.output_fields:
            if output.name == "":
                continue
            output_var = f"{step.id}.{output.name}"
            variables.append(output_var)
        return variables

    def get_available_variables(self) -> list[str]:
        """Get all output variables from all steps."""
        variables = set(self.inputs)
        for step in self.steps.values():
            variables.update(self.get_step_variables(step.id))
        return list(variables)

    def get_model_selections(self) -> dict[str, str]:
        """Get all model selections for all steps."""
        return {step_id: step.get_full_model_name() for step_id, step in self.steps.items()}

    def get_output_model_selections(self) -> dict[str, str]:
        """Get all output model selections for all steps."""
        return {
            output_var: target_var.split(".")[0] if target_var else None
            for output_var, target_var in self.outputs.items()
        }

    # Step update method

    def add_step(self, step: ModelStep) -> "Workflow":
        """Add a step to the workflow."""
        steps = self.steps | {step.id: step}
        return self.model_copy(update={"steps": steps})

    def remove_step(self, step_id: str) -> "Workflow":
        """Remove a step from the workflow."""
        self.steps.pop(step_id)
        workflow = self.model_copy(update={"steps": self.steps})
        workflow.refresh_output_variables()
        return workflow

    def update_step(self, step: ModelStep) -> "Workflow":
        """Update a step in the workflow."""
        self.steps[step.id] = step
        steps = self.steps | {step.id: step}
        workflow = self.model_copy(update={"steps": steps})
        workflow.refresh_output_variables()
        return workflow

    # Output variables
    def refresh_output_variables(self) -> "Workflow":
        """Refresh the output variables for the workflow."""
        produced_variables = self.get_available_variables()
        self.outputs = {k: (v if v in produced_variables else None) for k, v in self.outputs.items()}
        return self


class BuzzerMethod(str, Enum):
    AND = "AND"
    OR = "OR"


class Buzzer(BaseModel):
    """Configuration for when to buzz in a tossup question."""

    method: BuzzerMethod = BuzzerMethod.AND  # Logic to combine thresholds
    confidence_threshold: float = Field(default=0.8, ge=0.0, le=1.0)  # Minimum confidence to trigger a buzz
    prob_threshold: float | None = None  # Optional log probability threshold

    class Config:
        use_enum_values = True
        frozen = True

    def update(self, **kwargs) -> "Buzzer":
        """Update the buzzer with the given kwargs."""
        return self.model_copy(update=kwargs)

    def run(self, confidence: float, prob: float | None = None, logprob: float | None = None) -> bool:
        """Run the buzzer logic."""
        if logprob is not None and prob is not None:
            raise ValueError("Cannot provide both logprob and prob")
        if self.prob_threshold is None:
            return confidence >= self.confidence_threshold
        if logprob is None and prob is None:
            raise ValueError("Must provide either logprob or prob if prob_threshold is not None")
        prob = prob or float(np.exp(logprob))
        if self.method == BuzzerMethod.AND:
            return confidence >= self.confidence_threshold and prob >= self.prob_threshold
        elif self.method == BuzzerMethod.OR:
            return confidence >= self.confidence_threshold or prob >= self.prob_threshold
        else:
            raise ValueError(f"Invalid buzzer method: {self.method}")

    @model_validator(mode="after")
    def validate_method_with_log_prob(cls, data):
        """Validate that if prob_threshold is None, method must be 'and'."""
        if data.prob_threshold is None and data.method != BuzzerMethod.AND:
            raise ValueError("If prob_threshold is None, method must be 'and'")
        return data


class TossupWorkflow(Workflow):
    """Workflow specialized for tossup questions with buzzing capability."""

    buzzer: Buzzer

    def get_answer_model(self, answer_var: str | None = None) -> str | None:
        answer_var = answer_var or self.outputs["answer"]
        if answer_var is None:
            return None
        step_id = answer_var.split(".")[0]
        return self.steps[step_id].get_full_model_name()

    def is_token_probs_supported(self, answer_var: str | None = None) -> bool:
        model_name = self.get_answer_model(answer_var)
        if model_name is None:
            return True
        return AVAILABLE_MODELS[model_name].get("logprobs", False)

    def update_buzzer(self, buzzer: Buzzer) -> "TossupWorkflow":
        """Update the buzzer."""
        return self.model_copy(update={"buzzer": buzzer})

    def refresh_buzzer(self) -> "TossupWorkflow":
        if not self.is_token_probs_supported():
            return self.update_buzzer(self.buzzer.update(prob_threshold=None, method="AND"))
        return self