Maharshi Gor
Bugfix: double remove of step_id in UI state
c225678
raw
history blame
6.65 kB
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, model_validator

from workflows.structs import ModelStep, TossupWorkflow, Workflow
def make_step_id(step_number: int) -> str:
    """Convert a zero-based step number into a letter-based step id.

    The mapping is 0 -> "A", ..., 25 -> "Z", 26 -> "AA", 27 -> "AB", ...,
    701 -> "ZZ", 702 -> "AAA", and so on (bijective base 26, like
    spreadsheet column names).

    Args:
        step_number: Zero-based index of the step.

    Returns:
        The uppercase letter id for the step.

    Raises:
        ValueError: If ``step_number`` is negative.
    """
    if step_number < 0:
        raise ValueError(f"step_number must be non-negative, got {step_number}")

    letters: list[str] = []
    # Work 1-based so that "A" behaves as digit 1 rather than digit 0; this is
    # what lets "AA" follow "Z" instead of colliding with "A".
    remaining = step_number + 1
    while remaining > 0:
        remaining, digit = divmod(remaining - 1, 26)
        letters.append(chr(ord("A") + digit))
    # Digits were produced least-significant first.
    return "".join(reversed(letters))
def make_step_number(step_id: str) -> int:
    """Convert a letter-based step id into its zero-based step number.

    The mapping is "A" -> 0, ..., "Z" -> 25, "AA" -> 26, "AB" -> 27, ...,
    "ZZ" -> 701, "AAA" -> 702 (bijective base 26), so that it round-trips
    with the ids produced by ``make_step_id``.

    Args:
        step_id: Step id made of uppercase letters.

    Returns:
        The zero-based step number encoded by ``step_id``.

    Raises:
        ValueError: If ``step_id`` is empty.
    """
    if not step_id:
        raise ValueError("step_id must be a non-empty string")

    number = 0
    for char in step_id:
        # Each letter is a 1-based digit ("A" == 1) so that "AA" sorts after "Z".
        number = number * 26 + (ord(char) - ord("A") + 1)
    # Shift back to a zero-based step number.
    return number - 1
class ModelStepUIState(BaseModel):
    """Represents the UI state for a model step component.

    Instances are frozen; use ``update`` to derive a modified copy.
    """

    # Pydantic v2 style configuration (the nested ``class Config`` form is
    # deprecated in v2 and emits a deprecation warning).
    model_config = ConfigDict(frozen=True)

    # Expanded/collapsed flag for the step; defaults to expanded.
    expanded: bool = True
    # Currently selected tab; restricted to the three literal tab ids.
    active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"

    def update(self, key: str, value: Any) -> "ModelStepUIState":
        """Return a copy of this state with ``key`` set to ``value``.

        Args:
            key: Name of the field to override in the copy.
            value: New value for that field.

        Returns:
            A new ``ModelStepUIState``; the receiver is left unchanged.
        """
        return self.model_copy(update={key: value})
class PipelineUIState(BaseModel):
    """Represents the UI state for a pipeline component.

    ``step_ids`` holds the ordered step ids and ``steps`` maps each step id to
    its ``ModelStepUIState``. The ``insert_step``, ``remove_step`` and
    ``update_step`` methods return a new instance and do not modify the
    receiver's containers.
    """

    step_ids: list[str] = Field(default_factory=list)
    steps: dict[str, ModelStepUIState] = Field(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        """Fill in default per-step states when only ``step_ids`` was given."""
        if not self.steps and self.step_ids:
            self.steps = {step_id: ModelStepUIState() for step_id in self.step_ids}
        return super().model_post_init(__context)

    def get_step_position(self, step_id: str) -> int | None:
        """Get the position of a step in the pipeline.

        Returns:
            The index of the first occurrence of ``step_id`` in ``step_ids``,
            or ``None`` if it is not present.
        """
        return next((i for i, step in enumerate(self.step_ids) if step == step_id), None)

    @property
    def n_steps(self) -> int:
        """Get the number of steps in the pipeline."""
        return len(self.step_ids)

    @classmethod
    def from_workflow(cls, workflow: Workflow) -> "PipelineUIState":
        """Create a pipeline UI state from a workflow.

        Every step of the workflow gets a default ``ModelStepUIState``.
        """
        step_ids = list(workflow.steps)
        # Use ``cls`` so subclasses get instances of their own type.
        return cls(
            step_ids=step_ids,
            steps={step_id: ModelStepUIState() for step_id in step_ids},
        )

    @classmethod
    def from_pipeline_state(cls, pipeline_state: "PipelineState") -> "PipelineUIState":
        """Create a pipeline UI state from a pipeline state."""
        return cls.from_workflow(pipeline_state.workflow)

    # Update methods
    def insert_step(self, step_id: str, position: int = -1) -> "PipelineUIState":
        """Insert a step into the pipeline at the given position.

        Args:
            step_id: Id of the step to add.
            position: Index to insert at; ``-1`` appends at the end.

        Returns:
            A new ``PipelineUIState`` containing the added step with a default
            ``ModelStepUIState``.
        """
        if position == -1:
            position = len(self.step_ids)
        # Work on a copy: ``model_copy`` is shallow, so mutating the list in
        # place would also change this (the old) state.
        step_ids = list(self.step_ids)
        step_ids.insert(position, step_id)
        steps = self.steps | {step_id: ModelStepUIState()}
        return self.model_copy(update={"step_ids": step_ids, "steps": steps})

    def remove_step(self, step_id: str) -> "PipelineUIState":
        """Remove a step from the pipeline.

        Args:
            step_id: Id of the step to remove.

        Returns:
            A new ``PipelineUIState`` without the step.

        Raises:
            ValueError: If ``step_id`` is not in ``step_ids``.
        """
        if step_id not in self.step_ids:
            raise ValueError(f"Step {step_id} not found in pipeline. Step IDs: {self.step_ids}")
        # Build fresh containers so the receiver keeps its own, still
        # consistent, ``step_ids`` and ``steps``.
        step_ids = list(self.step_ids)
        step_ids.remove(step_id)
        steps = {sid: state for sid, state in self.steps.items() if sid != step_id}
        return self.model_copy(update={"step_ids": step_ids, "steps": steps})

    def update_step(self, step_id: str, ui_state: ModelStepUIState) -> "PipelineUIState":
        """Update a step in the pipeline.

        Args:
            step_id: Id of the step whose UI state is replaced.
            ui_state: The new UI state for that step.

        Returns:
            A new ``PipelineUIState`` with the step's UI state replaced.

        Raises:
            ValueError: If ``step_id`` is not a key of ``steps``.
        """
        if step_id not in self.steps:
            raise ValueError(f"Step {step_id} not found in pipeline. Step IDs: {self.step_ids}")
        return self.model_copy(update={"steps": self.steps | {step_id: ui_state}})
class PipelineState(BaseModel):
    """Represents the state for a pipeline component.

    Bundles a ``Workflow`` with the ``PipelineUIState`` describing how its
    steps are shown. The update methods return a new ``PipelineState`` built
    with ``model_copy`` rather than reassigning fields on the receiver.
    """

    workflow: Workflow
    ui_state: PipelineUIState

    @classmethod
    def from_workflow(cls, workflow: Workflow) -> "PipelineState":
        """Create a pipeline state from a workflow."""
        return cls(workflow=workflow, ui_state=PipelineUIState.from_workflow(workflow))

    def update_workflow(self, workflow: Workflow) -> "PipelineState":
        """Return a copy of this state with ``workflow`` replaced.

        The ``ui_state`` is carried over as-is.
        """
        return self.model_copy(update={"workflow": workflow})

    def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
        """Insert a step into the pipeline at the given position.

        Args:
            position: Index between 0 and ``n_steps`` inclusive, or ``-1`` to
                append at the end.
            step: The step to add; its ``id`` must not already be in the
                workflow.

        Returns:
            A new ``PipelineState`` with the updated workflow and UI state.

        Raises:
            ValueError: If the step id already exists or ``position`` is out
                of range.
        """
        if step.id in self.workflow.steps:
            raise ValueError(f"Step {step.id} already exists in pipeline")

        # Validate position
        if position != -1 and (position < 0 or position > self.n_steps):
            raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")

        workflow = self.workflow.add_step(step)
        # Keep the new UI state in a local: assigning it to ``self.ui_state``
        # would leave this (old) state with a new UI state but the old workflow.
        ui_state = self.ui_state.insert_step(step.id, position)

        return self.model_copy(update={"workflow": workflow, "ui_state": ui_state})

    def remove_step(self, position: int) -> "PipelineState":
        """Remove the step at ``position`` in the UI step order.

        Args:
            position: Index into ``ui_state.step_ids``; an out-of-range value
                raises ``IndexError`` from the list lookup.

        Returns:
            A new ``PipelineState`` without that step.
        """
        step_id = self.ui_state.step_ids[position]
        workflow = self.workflow.remove_step(step_id)
        ui_state = self.ui_state.remove_step(step_id)
        return self.model_copy(update={"workflow": workflow, "ui_state": ui_state})

    def update_step(self, step: ModelStep, ui_state: ModelStepUIState | None = None) -> "PipelineState":
        """Update a step in the pipeline.

        Args:
            step: Replacement step; its ``id`` must already be in the workflow.
            ui_state: Optional new UI state for the step. When ``None`` the
                existing pipeline UI state is kept.

        Returns:
            A new ``PipelineState`` with the step updated.

        Raises:
            ValueError: If ``step.id`` is not in the workflow.
        """
        if step.id not in self.workflow.steps:
            raise ValueError(f"Step {step.id} not found in pipeline")

        workflow = self.workflow.update_step(step)
        update: dict[str, Any] = {"workflow": workflow}
        if ui_state is not None:
            update["ui_state"] = self.ui_state.update_step(step.id, ui_state)
        return self.model_copy(update=update)

    def get_available_variables(self, model_step_id: str | None = None) -> list[str]:
        """Get the workflow's available variables, optionally excluding one step.

        Args:
            model_step_id: When given, variables whose name starts with
                ``"<model_step_id>."`` are left out of the result.

        Returns:
            The list of available variable names.
        """
        available_variables = self.available_variables
        if model_step_id is None:
            return available_variables
        prefix = f"{model_step_id}."
        return [var for var in available_variables if not var.startswith(prefix)]

    @property
    def available_variables(self) -> list[str]:
        """Variables reported by ``workflow.get_available_variables()``."""
        return self.workflow.get_available_variables()

    @property
    def n_steps(self) -> int:
        """Number of steps in the workflow."""
        return len(self.workflow.steps)

    def get_new_step_id(self) -> str:
        """Get a step ID for a new step.

        Returns ``"A"`` for an empty workflow; otherwise the id that follows
        the highest step number among the existing step ids.
        """
        if not self.workflow.steps:
            return "A"
        last_step_number = max(map(make_step_number, self.workflow.steps.keys()))
        return make_step_id(last_step_number + 1)
class TossupPipelineState(PipelineState):
    """Pipeline state whose ``workflow`` field is typed as a ``TossupWorkflow``."""

    workflow: TossupWorkflow