from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, model_validator

from workflows.structs import ModelStep, TossupWorkflow, Workflow

# Number of letters available for each character of a step id ("A".."Z").
_ALPHABET_SIZE = 26


def make_step_id(step_number: int) -> str:
    """Make a spreadsheet-style step id from a zero-based step number.

    0 -> "A", 25 -> "Z", 26 -> "AA", 27 -> "AB", ..., 701 -> "ZZ", 702 -> "AAA".

    Raises:
        ValueError: If ``step_number`` is negative.
    """
    if step_number < 0:
        raise ValueError(f"Step number must be non-negative, got {step_number}")

    # Bijective base-26: there is no "zero" digit, so shift by one before each
    # division. This keeps "A" and "AA" distinct and works for any length.
    letters: list[str] = []
    remaining = step_number + 1
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, _ALPHABET_SIZE)
        letters.append(chr(ord("A") + offset))
    return "".join(reversed(letters))


def make_step_number(step_id: str) -> int:
    """Make a zero-based step number from a step id (inverse of ``make_step_id``).

    "A" -> 0, "Z" -> 25, "AA" -> 26, "AB" -> 27, "BA" -> 52.

    Raises:
        ValueError: If ``step_id`` is empty.
    """
    if not step_id:
        raise ValueError("Step id must be a non-empty string")

    # Each letter is a digit in 1..26; the trailing -1 converts back to the
    # zero-based numbering used by make_step_id.
    number = 0
    for char in step_id:
        number = number * _ALPHABET_SIZE + (ord(char) - ord("A") + 1)
    return number - 1


class ModelStepUIState(BaseModel):
    """Represents the UI state for a model step component."""

    # Frozen: instances are never edited in place; use `update` to get a copy.
    model_config = ConfigDict(frozen=True)

    expanded: bool = True
    active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"

    def update(self, key: str, value: Any) -> "ModelStepUIState":
        """Return a copy of the UI state with ``key`` set to ``value``."""
        return self.model_copy(update={key: value})


class PipelineUIState(BaseModel):
    """Represents the UI state for a pipeline component."""

    # Step ids in display order.
    step_ids: list[str] = Field(default_factory=list)
    # Per-step UI state keyed by step id.
    steps: dict[str, ModelStepUIState] = Field(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        # When only step ids were supplied, give every step a default UI state.
        if not self.steps and self.step_ids:
            self.steps = {step_id: ModelStepUIState() for step_id in self.step_ids}
        return super().model_post_init(__context)

    def get_step_position(self, step_id: str) -> int | None:
        """Get the position of a step in the pipeline, or None if it is absent."""
        return next((i for i, step in enumerate(self.step_ids) if step == step_id), None)

    @property
    def n_steps(self) -> int:
        """Get the number of steps in the pipeline."""
        return len(self.step_ids)

    @classmethod
    def from_workflow(cls, workflow: Workflow):
        """Create a pipeline UI state from a workflow."""
        # Use cls (not the concrete class name) so subclasses build themselves.
        return cls(
            step_ids=list(workflow.steps.keys()),
            steps={step_id: ModelStepUIState() for step_id in workflow.steps.keys()},
        )

    @classmethod
    def from_pipeline_state(cls, pipeline_state: "PipelineState"):
        """Create a pipeline UI state from a pipeline state."""
        return cls.from_workflow(pipeline_state.workflow)

    # Update methods
    def insert_step(self, step_id: str, position: int = -1) -> "PipelineUIState":
        """Return a copy with ``step_id`` inserted at ``position``.

        A position of -1 appends the step at the end. The receiver is left
        untouched.
        """
        if position == -1:
            position = len(self.step_ids)
        # model_copy is shallow, so build fresh containers instead of editing
        # the ones owned by this instance.
        step_ids = list(self.step_ids)
        step_ids.insert(position, step_id)
        steps = self.steps | {step_id: ModelStepUIState()}
        return self.model_copy(update={"step_ids": step_ids, "steps": steps})

    def remove_step(self, step_id: str) -> "PipelineUIState":
        """Return a copy with ``step_id`` removed. The receiver is left untouched.

        Raises:
            ValueError: If ``step_id`` is not in ``step_ids``.
        """
        if step_id not in self.step_ids:
            raise ValueError(f"Step {step_id} not found in pipeline. Step IDs: {self.step_ids}")
        # Copy before removing so the original state keeps its own containers.
        step_ids = list(self.step_ids)
        step_ids.remove(step_id)
        steps = {key: state for key, state in self.steps.items() if key != step_id}
        return self.model_copy(update={"step_ids": step_ids, "steps": steps})

    def update_step(self, step_id: str, ui_state: ModelStepUIState) -> "PipelineUIState":
        """Return a copy in which the UI state of ``step_id`` is replaced.

        Raises:
            ValueError: If ``step_id`` has no entry in ``steps``.
        """
        if step_id not in self.steps:
            raise ValueError(f"Step {step_id} not found in pipeline. Step IDs: {self.step_ids}")
        return self.model_copy(update={"steps": self.steps | {step_id: ui_state}})


class PipelineState(BaseModel):
    """Represents the state for a pipeline component."""

    workflow: Workflow
    ui_state: PipelineUIState

    @classmethod
    def from_workflow(cls, workflow: Workflow):
        """Create a pipeline state from a workflow."""
        return cls(workflow=workflow, ui_state=PipelineUIState.from_workflow(workflow))

    def update_workflow(self, workflow: Workflow) -> "PipelineState":
        """Return a copy that uses ``workflow``; the UI state is carried over as is."""
        return self.model_copy(update={"workflow": workflow})

    def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
        """Return a copy with ``step`` added to the workflow and the UI state.

        Raises:
            ValueError: If the step id already exists, or if ``position`` is
                neither -1 nor within ``0..n_steps``.
        """
        if step.id in self.workflow.steps:
            raise ValueError(f"Step {step.id} already exists in pipeline")

        # Validate position
        if position != -1 and (position < 0 or position > self.n_steps):
            raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")

        # Create a new workflow with updated steps
        workflow = self.workflow.add_step(step)
        # Keep the new UI state in a local: assigning it to self.ui_state would
        # modify the state this method is supposed to leave unchanged.
        ui_state = self.ui_state.insert_step(step.id, position)

        # Return a new PipelineState with the updated workflow
        return self.model_copy(update={"workflow": workflow, "ui_state": ui_state})

    def remove_step(self, position: int) -> "PipelineState":
        """Return a copy without the step displayed at ``position``.

        ``position`` indexes ``ui_state.step_ids``; an out-of-range value
        raises IndexError from the list lookup.
        """
        step_id = self.ui_state.step_ids[position]
        workflow = self.workflow.remove_step(step_id)
        ui_state = self.ui_state.remove_step(step_id)
        return self.model_copy(update={"workflow": workflow, "ui_state": ui_state})

    def update_step(self, step: ModelStep, ui_state: ModelStepUIState | None = None) -> "PipelineState":
        """Update a step in the pipeline.

        The step's UI state is replaced only when ``ui_state`` is given.

        Raises:
            ValueError: If ``step.id`` is not part of the workflow.
        """
        if step.id not in self.workflow.steps:
            raise ValueError(f"Step {step.id} not found in pipeline")
        workflow = self.workflow.update_step(step)
        update: dict[str, Any] = {"workflow": workflow}
        if ui_state is not None:
            update["ui_state"] = self.ui_state.update_step(step.id, ui_state)
        return self.model_copy(update=update)

    def get_available_variables(self, model_step_id: str | None = None) -> list[str]:
        """Get the workflow's available variables.

        When ``model_step_id`` is given, variables whose name starts with
        ``"<model_step_id>."`` are left out.
        """
        available_variables = self.available_variables
        if model_step_id is None:
            return available_variables
        prefix = f"{model_step_id}."
        return [var for var in available_variables if not var.startswith(prefix)]

    @property
    def available_variables(self) -> list[str]:
        """Variables reported by the workflow's ``get_available_variables``."""
        return self.workflow.get_available_variables()

    @property
    def n_steps(self) -> int:
        """Number of steps in the workflow."""
        return len(self.workflow.steps)

    def get_new_step_id(self) -> str:
        """Get a step ID for a new step.

        Returns "A" for an empty workflow, otherwise the id that follows the
        highest-numbered existing step id.
        """
        if not self.workflow.steps:
            return "A"
        last_step_number = max(map(make_step_number, self.workflow.steps.keys()))
        return make_step_id(last_step_number + 1)


class TossupPipelineState(PipelineState):
    """Pipeline state whose workflow is a ``TossupWorkflow``."""

    workflow: TossupWorkflow