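from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field

from workflows.structs import ModelStep, TossupWorkflow, Workflow


def make_step_id(step_number: int) -> str:
    """Make a step ID from a zero-based step number.

    Numbers 0-25 map to "A"-"Z"; numbers 26-701 map to two-letter IDs
    "AA"-"ZZ" (26 -> "AA", 27 -> "AB", 52 -> "BA", 701 -> "ZZ").
    """
    if step_number < 0:
        raise ValueError(f"Step number must be non-negative, got {step_number}")
    if step_number < 26:
        return chr(ord("A") + step_number)
    if step_number >= 26 * 27:
        raise ValueError(f"Step number {step_number} exceeds the two-letter ID range (max {26 * 27 - 1})")
    # For 26 or more steps, use two letters: AA, AB, ..., AZ, BA, ..., ZZ.
    first_char = chr(ord("A") + (step_number // 26) - 1)
    second_char = chr(ord("A") + (step_number % 26))
    return f"{first_char}{second_char}"


def make_step_number(step_id: str) -> int:
    """Make a zero-based step number from a step ID; the inverse of `make_step_id`."""
    if len(step_id) == 1:
        return ord(step_id) - ord("A")
    # Two-letter IDs start at 26 ("AA"), so the first letter contributes (index + 1) * 26.
    return (ord(step_id[0]) - ord("A") + 1) * 26 + (ord(step_id[1]) - ord("A"))


class ModelStepUIState(BaseModel):
    """Represents the UI state for a model step component."""

    model_config = ConfigDict(frozen=True)

    expanded: bool = True
    active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"

    def update(self, key: str, value: Any) -> "ModelStepUIState":
        """Return a copy of the UI state with `key` set to `value` (the update is not validated)."""
        return self.model_copy(update={key: value})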


class PipelineUIState(BaseModel):
    """Represents the UI state for a pipeline component."""

    step_ids: list[str] = Field(default_factory=list)
    steps: dict[str, ModelStepUIState] = Field(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        if not self.steps and self.step_ids:
            self.steps = {step_id: ModelStepUIState() for step_id in self.step_ids}
        return super().model_post_init(__context)

    def get_step_position(self, step_id: str):
        """Get the position of a step in the pipeline."""
        return next((i for i, step in enumerate(self.step_ids) if step == step_id), None)

    @property
    def n_steps(self) -> int:
        """Get the number of steps in the pipeline."""
        return len(self.step_ids)

    @classmethod
    def from_workflow(cls, workflow: Workflow):
        """Create a pipeline UI state from a workflow."""
        return cls(
            step_ids=list(workflow.steps.keys()),
            steps={step_id: ModelStepUIState() for step_id in workflow.steps.keys()},
        )

    @classmethod
    def from_pipeline_state(cls, pipeline_state: "PipelineState"):
        """Create a pipeline UI state from a pipeline state."""
        return cls.from_workflow(pipeline_state.workflow)

    # Update methods (each returns a new object and leaves self unchanged)

    def insert_step(self, step_id: str, position: int = -1) -> "PipelineUIState":
        """Return a copy of the UI state with a step inserted at `position` (-1 appends)."""
        if position == -1:
            position = len(self.step_ids)
        # Copy before inserting: model_copy() is shallow, so mutating self.step_ids
        # would also change the original state.
        step_ids = list(self.step_ids)
        step_ids.insert(position, step_id)
        steps = self.steps | {step_id: ModelStepUIState()}
        return self.model_copy(update={"step_ids": step_ids, "steps": steps})

    def remove_step(self, step_id: str) -> "PipelineUIState":
        """Return a copy of the UI state with a step removed."""
        if step_id not in self.step_ids:
            raise ValueError(f"Step {step_id} not found in pipeline. Step IDs: {self.step_ids}")
        # Build new containers instead of mutating the shared list and dict.
        step_ids = [s for s in self.step_ids if s != step_id]
        steps = {k: v for k, v in self.steps.items() if k != step_id}
        return self.model_copy(update={"step_ids": step_ids, "steps": steps})

    def update_step(self, step_id: str, ui_state: ModelStepUIState) -> "PipelineUIState":
        """Update a step in the pipeline."""
        if step_id not in self.steps:
            raise ValueError(f"Step {step_id} not found in pipeline. Step IDs: {self.step_ids}")
        return self.model_copy(update={"steps": self.steps | {step_id: ui_state}})
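
# Illustrative usage of PipelineUIState (a sketch; the step IDs "A", "B", "C" are arbitrary):
#   ui = PipelineUIState(step_ids=["A", "B"])  # model_post_init fills `steps` with default ModelStepUIState()s
#   ui = ui.insert_step("C", position=1)       # ui.step_ids == ["A", "C", "B"]
#   ui = ui.update_step("C", ModelStepUIState(expanded=False))
#   ui = ui.remove_step("A")                   # ui.step_ids == ["C", "B"]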


class PipelineState(BaseModel):
    """Represents the state for a pipeline component."""

    workflow: Workflow
    ui_state: PipelineUIState

    @classmethod
    def from_workflow(cls, workflow: Workflow):
        """Create a pipeline state from a workflow."""
        return cls(workflow=workflow, ui_state=PipelineUIState.from_workflow(workflow))

    def update_workflow(self, workflow: Workflow) -> "PipelineState":
        """Return a copy with the workflow replaced; the UI state is left unchanged."""
        return self.model_copy(update={"workflow": workflow})

    def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
        """Return a copy of the pipeline state with `step` inserted at `position` (-1 appends)."""
        if step.id in self.workflow.steps:
            raise ValueError(f"Step {step.id} already exists in pipeline")

        # Validate position
        if position != -1 and (position < 0 or position > self.n_steps):
            raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")

        # Create a new workflow with the added step
        workflow = self.workflow.add_step(step)

        # Compute the new UI state without reassigning self.ui_state
        ui_state = self.ui_state.insert_step(step.id, position)

        # Return a new PipelineState with the updated workflow and UI state
        return self.model_copy(update={"workflow": workflow, "ui_state": ui_state})

    def remove_step(self, position: int) -> "PipelineState":
        """Return a copy with the step at `position` (an index into the UI step order) removed."""
        step_id = self.ui_state.step_ids[position]

        workflow = self.workflow.remove_step(step_id)
        ui_state = self.ui_state.remove_step(step_id)
        return self.model_copy(update={"workflow": workflow, "ui_state": ui_state})

    def update_step(self, step: ModelStep, ui_state: ModelStepUIState | None = None) -> "PipelineState":
        """Update a step in the pipeline."""
        if step.id not in self.workflow.steps:
            raise ValueError(f"Step {step.id} not found in pipeline")
        workflow = self.workflow.update_step(step)
        update = {"workflow": workflow}
        if ui_state is not None:
            update["ui_state"] = self.ui_state.update_step(step.id, ui_state)
        return self.model_copy(update=update)

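    def get_available_variables(self, model_step_id: str | None = None) -> list[str]:
        """Get the variables available in the workflow.

        If `model_step_id` is given, variables produced by that step (those prefixed
        with "<model_step_id>.") are excluded; otherwise all variables are returned.
        """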
        available_variables = self.available_variables
        if model_step_id is None:
            return available_variables
        prefix = f"{model_step_id}."
        return [var for var in available_variables if not var.startswith(prefix)]

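    @property
    def available_variables(self) -> list[str]:
        """All variables reported by `Workflow.get_available_variables()`."""
        return self.workflow.get_available_variables()

    @property
    def n_steps(self) -> int:
        """Number of steps in the workflow."""
        return len(self.workflow.steps)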

    def get_new_step_id(self) -> str:
        """Get a step ID for a new step."""
        if not self.workflow.steps:
            return "A"
        else:
            last_step_number = max(map(make_step_number, self.workflow.steps.keys()))
            return make_step_id(last_step_number + 1)


class TossupPipelineState(PipelineState):
    """Pipeline state whose workflow is a `TossupWorkflow`."""

    workflow: TossupWorkflow