|
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, model_validator

from workflows.structs import ModelStep, TossupWorkflow, Workflow
|
|
|
|
|
def make_step_id(step_number: int) -> str:
    """Make a step id from a zero-based step number.

    Ids use spreadsheet-column style (bijective base 26):
    0 -> "A", 25 -> "Z", 26 -> "AA", 27 -> "AB", 701 -> "ZZ", 702 -> "AAA".

    Args:
        step_number: Zero-based index of the step. Must be non-negative.

    Returns:
        The uppercase-letter id for ``step_number``.

    Raises:
        ValueError: If ``step_number`` is negative.
    """
    if step_number < 0:
        raise ValueError(f"step_number must be non-negative, got {step_number}")

    letters = []
    # Shift to 1-based so every "digit" ranges over 1..26 (there is no zero digit).
    remaining = step_number + 1
    while remaining > 0:
        remaining, digit = divmod(remaining - 1, 26)
        letters.append(chr(ord("A") + digit))
    return "".join(reversed(letters))
|
|
|
|
|
def make_step_number(step_id: str) -> int:
    """Make a zero-based step number from a step id.

    This is the inverse of ``make_step_id``, which uses bijective base 26:
    "A" -> 0, "Z" -> 25, "AA" -> 26, "AZ" -> 51, "BA" -> 52, "ZZ" -> 701.

    Args:
        step_id: One or more uppercase ASCII letters ``A``-``Z``.

    Returns:
        The zero-based step number encoded by ``step_id``.

    Raises:
        ValueError: If ``step_id`` is empty or contains anything other than
            uppercase ASCII letters.
    """
    if not step_id or not all("A" <= char <= "Z" for char in step_id):
        raise ValueError(f"Invalid step id: {step_id!r}. Expected one or more uppercase letters A-Z.")

    value = 0
    for char in step_id:
        # Each letter is a digit in 1..26 (bijective base 26).
        value = value * 26 + (ord(char) - ord("A") + 1)
    # Convert back to the zero-based numbering used by make_step_id.
    return value - 1
|
|
|
|
|
class ModelStepUIState(BaseModel):
    """Represents the UI state for a model step component.

    Instances are frozen; use ``update`` to obtain a modified copy.
    """

    # pydantic v2 configuration (replaces the deprecated inner ``class Config``).
    model_config = ConfigDict(frozen=True)

    # Expanded/collapsed flag for the step component.
    expanded: bool = True
    # Currently selected tab; restricted to the three tab identifiers below.
    active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"

    def update(self, key: str, value: Any) -> "ModelStepUIState":
        """Return a copy of this UI state with field ``key`` set to ``value``.

        Args:
            key: Name of the field to change.
            value: New value for that field; it is validated like constructor input.

        Returns:
            A new, validated ``ModelStepUIState`` (or subclass) instance.

        Raises:
            ValueError: If ``key`` is not a field of this model, or if ``value``
                fails validation (pydantic's ``ValidationError`` subclasses
                ``ValueError``).
        """
        if key not in type(self).model_fields:
            raise ValueError(f"Unknown UI state field: {key!r}")
        # model_copy(update=...) skips validation, so rebuild through the constructor.
        return type(self)(**(self.model_dump() | {key: value}))
|
|
|
|
|
class PipelineUIState(BaseModel):
    """Represents the UI state for a pipeline component.

    Holds the display order of steps (``step_ids``) and a ``ModelStepUIState``
    per step id (``steps``). The ``insert_step``/``remove_step``/``update_step``
    methods return a new instance and leave the current one unchanged.
    """

    # Step ids in pipeline order.
    step_ids: list[str] = Field(default_factory=list)
    # Per-step UI state, keyed by step id.
    steps: dict[str, ModelStepUIState] = Field(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        """Give every id in ``step_ids`` a default UI state if it lacks one."""
        missing = [step_id for step_id in self.step_ids if step_id not in self.steps]
        if missing:
            # Build a new dict instead of mutating the existing one in place.
            self.steps = {**self.steps, **{step_id: ModelStepUIState() for step_id in missing}}
        return super().model_post_init(__context)

    def get_step_position(self, step_id: str) -> int | None:
        """Return the index of ``step_id`` in ``step_ids``, or None if it is absent."""
        try:
            return self.step_ids.index(step_id)
        except ValueError:
            return None

    @property
    def n_steps(self) -> int:
        """Get the number of steps in the pipeline."""
        return len(self.step_ids)

    @classmethod
    def from_workflow(cls, workflow: Workflow) -> "PipelineUIState":
        """Create a pipeline UI state with a default step state for each workflow step.

        Step order follows the iteration order of ``workflow.steps.keys()``.
        """
        step_ids = list(workflow.steps.keys())
        # Use cls so subclasses get instances of their own type.
        return cls(
            step_ids=step_ids,
            steps={step_id: ModelStepUIState() for step_id in step_ids},
        )

    @classmethod
    def from_pipeline_state(cls, pipeline_state: "PipelineState") -> "PipelineUIState":
        """Create a pipeline UI state from a pipeline state's workflow."""
        return cls.from_workflow(pipeline_state.workflow)

    def insert_step(self, step_id: str, position: int = -1) -> "PipelineUIState":
        """Return a copy with ``step_id`` inserted at ``position``.

        Args:
            step_id: Id of the step to add; it receives a fresh ``ModelStepUIState``.
            position: Index to insert at. ``-1`` (the default) appends to the end;
                any other value follows ``list.insert`` semantics.

        Raises:
            ValueError: If ``step_id`` is already in the pipeline.
        """
        if step_id in self.step_ids:
            raise ValueError(f"Step {step_id} already exists in pipeline. Step IDs: {self.step_ids}")
        # Work on a copy: model_copy is shallow, so mutating self.step_ids would
        # also change this instance and share the list with the result.
        step_ids = list(self.step_ids)
        if position == -1:
            step_ids.append(step_id)
        else:
            step_ids.insert(position, step_id)
        steps = self.steps | {step_id: ModelStepUIState()}
        return self.model_copy(update={"step_ids": step_ids, "steps": steps})

    def remove_step(self, step_id: str) -> "PipelineUIState":
        """Return a copy with ``step_id`` removed from the order and the step states.

        Raises:
            ValueError: If ``step_id`` is not in the pipeline.
        """
        if step_id not in self.step_ids:
            raise ValueError(f"Step {step_id} not found in pipeline. Step IDs: {self.step_ids}")
        # Copy both containers so the current instance is left untouched.
        step_ids = list(self.step_ids)
        step_ids.remove(step_id)
        steps = {key: state for key, state in self.steps.items() if key != step_id}
        return self.model_copy(update={"step_ids": step_ids, "steps": steps})

    def update_step(self, step_id: str, ui_state: ModelStepUIState) -> "PipelineUIState":
        """Return a copy with the UI state of ``step_id`` replaced by ``ui_state``.

        Raises:
            ValueError: If ``step_id`` has no UI state in this pipeline.
        """
        if step_id not in self.steps:
            raise ValueError(f"Step {step_id} not found in pipeline. Step IDs: {self.step_ids}")
        return self.model_copy(update={"steps": self.steps | {step_id: ui_state}})
|
|
|
|
|
class PipelineState(BaseModel):
    """Represents the state for a pipeline component.

    Pairs a ``Workflow`` with the ``PipelineUIState`` for its steps. The
    update methods build a new ``PipelineState`` via ``model_copy`` rather
    than assigning to the current instance.
    """

    workflow: Workflow
    ui_state: PipelineUIState

    @classmethod
    def from_workflow(cls, workflow: Workflow) -> "PipelineState":
        """Create a pipeline state from a workflow, with default UI state for each step."""
        return cls(workflow=workflow, ui_state=PipelineUIState.from_workflow(workflow))

    def update_workflow(self, workflow: Workflow) -> "PipelineState":
        """Return a copy with ``workflow`` replaced; ``ui_state`` is carried over unchanged."""
        # NOTE(review): ui_state is not re-synced with the new workflow's steps — confirm callers expect this.
        return self.model_copy(update={"workflow": workflow})

    def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
        """Return a copy with ``step`` added to the workflow and its UI state inserted at ``position``.

        Args:
            position: UI position in ``0..n_steps``, or ``-1`` to append.
            step: The step to add; its ``id`` must not already be in the workflow.

        Raises:
            ValueError: If the step id already exists or ``position`` is out of range.
        """
        if step.id in self.workflow.steps:
            raise ValueError(f"Step {step.id} already exists in pipeline")

        if position != -1 and (position < 0 or position > self.n_steps):
            raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")

        workflow = self.workflow.add_step(step)
        # Keep the new UI state local; assigning to self.ui_state would mutate this instance.
        ui_state = self.ui_state.insert_step(step.id, position)
        return self.model_copy(update={"workflow": workflow, "ui_state": ui_state})

    def remove_step(self, position: int) -> "PipelineState":
        """Return a copy without the step at UI ``position`` (``list`` indexing rules apply)."""
        step_id = self.ui_state.step_ids[position]

        workflow = self.workflow.remove_step(step_id)
        ui_state = self.ui_state.remove_step(step_id)
        return self.model_copy(update={"workflow": workflow, "ui_state": ui_state})

    def update_step(self, step: ModelStep, ui_state: ModelStepUIState | None = None) -> "PipelineState":
        """Return a copy with ``step`` updated in the workflow, and optionally its UI state.

        Raises:
            ValueError: If ``step.id`` is not in the workflow.
        """
        if step.id not in self.workflow.steps:
            raise ValueError(f"Step {step.id} not found in pipeline")
        workflow = self.workflow.update_step(step)
        update: dict[str, Any] = {"workflow": workflow}
        if ui_state is not None:
            update["ui_state"] = self.ui_state.update_step(step.id, ui_state)
        return self.model_copy(update=update)

    def get_available_variables(self, model_step_id: str | None = None) -> list[str]:
        """Get all variables from all steps.

        If ``model_step_id`` is given, variables whose names start with
        ``"<model_step_id>."`` are excluded.
        """
        available_variables = self.available_variables
        if model_step_id is None:
            return available_variables
        prefix = f"{model_step_id}."
        return [var for var in available_variables if not var.startswith(prefix)]

    @property
    def available_variables(self) -> list[str]:
        """Variables reported by ``self.workflow.get_available_variables()``."""
        return self.workflow.get_available_variables()

    @property
    def n_steps(self) -> int:
        """Number of steps in the workflow."""
        return len(self.workflow.steps)

    def get_new_step_id(self) -> str:
        """Get a step ID for a new step.

        Returns the id one past the highest id that ``make_step_number`` can
        parse, or ``"A"`` when there is none.
        """
        step_numbers = []
        for step_id in self.workflow.steps.keys():
            try:
                step_numbers.append(make_step_number(step_id))
            except ValueError:
                # Ignore ids that make_step_number cannot parse.
                continue
        if not step_numbers:
            return "A"
        return make_step_id(max(step_numbers) + 1)
|
|
|
|
|
class TossupPipelineState(PipelineState):
    """Pipeline state whose workflow is a ``TossupWorkflow``."""

    # Narrows the inherited ``workflow`` field to the tossup workflow type.
    workflow: TossupWorkflow
|
|