File size: 7,926 Bytes
193db9d
 
 
 
 
e00ec4e
193db9d
 
 
 
0bab47c
193db9d
 
e00ec4e
193db9d
e00ec4e
 
193db9d
 
e00ec4e
 
193db9d
 
 
e00ec4e
 
 
 
 
 
 
 
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e00ec4e
 
 
 
 
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e00ec4e
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e00ec4e
193db9d
 
 
 
 
e00ec4e
193db9d
e00ec4e
193db9d
 
 
 
 
 
 
e00ec4e
193db9d
 
 
e00ec4e
193db9d
 
e00ec4e
 
 
 
 
 
 
 
193db9d
 
 
 
 
 
 
0bab47c
 
 
193db9d
 
 
 
 
 
 
 
 
 
e00ec4e
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import json
from typing import Any, Literal

import gradio as gr
import yaml
from loguru import logger
from pydantic import BaseModel, Field

from components import utils
from workflows.factory import create_new_llm_step
from workflows.structs import ModelStep, TossupWorkflow, Workflow


def make_step_id(step_number: int):
    """Make a step id from a zero-based step number.

    Ids use spreadsheet-column lettering (bijective base-26):
    0 -> "A", 25 -> "Z", 26 -> "AA", 27 -> "AB", ..., 701 -> "ZZ", 702 -> "AAA".
    This is the inverse of ``make_step_number``.

    Args:
        step_number: Zero-based index of the step.

    Returns:
        The uppercase letter id for ``step_number``.

    Raises:
        ValueError: If ``step_number`` is negative.
    """
    if step_number < 0:
        raise ValueError(f"Step number must be non-negative, got {step_number}")
    letters = []
    # Shift to 1-based: bijective base-26 has no zero digit, which is what makes
    # "Z" (25) roll over to "AA" (26) instead of "BA".
    remaining = step_number + 1
    while remaining > 0:
        remaining, digit = divmod(remaining - 1, 26)
        letters.append(chr(ord("A") + digit))
    return "".join(reversed(letters))


def make_step_number(step_id: str):
    """Make a zero-based step number from a step id (inverse of ``make_step_id``).

    Parses spreadsheet-column lettering (bijective base-26) of any length:
    "A" -> 0, "Z" -> 25, "AA" -> 26, "BA" -> 52, "ZZ" -> 701, "AAA" -> 702.
    Ids are expected to consist of uppercase A-Z; other characters still yield a
    number, but it will not round-trip through ``make_step_id``.

    Args:
        step_id: The letter id of a step.

    Returns:
        The zero-based step number.

    Raises:
        ValueError: If ``step_id`` is empty.
    """
    if not step_id:
        raise ValueError("Step id must be a non-empty string")
    number = 0
    for char in step_id:
        # Each letter is a digit 1..26 (A..Z); there is no zero digit.
        number = number * 26 + (ord(char) - ord("A") + 1)
    return number - 1


class ModelStepUIState(BaseModel):
    """UI state of a single model-step component.

    Tracks whether the component is expanded and which of its tabs is active.
    """

    # Whether the step's panel is shown expanded.
    expanded: bool = True
    # Which tab of the step panel is currently selected.
    active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"

    def update(self, key: str, value: Any) -> "ModelStepUIState":
        """Return a copy of this state with field ``key`` set to ``value``.

        ``self`` is not modified. NOTE: pydantic's ``model_copy`` does not validate
        values passed through ``update``.
        """
        return self.model_copy(update={key: value})


class PipelineUIState(BaseModel):
    """Represents the UI state for a pipeline component.

    ``step_ids`` gives the display order of the steps; ``steps`` maps each step id
    to its per-step ``ModelStepUIState``.
    """

    step_ids: list[str] = Field(default_factory=list)
    steps: dict[str, ModelStepUIState] = Field(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        """Give every id in ``step_ids`` a default UI state if ``steps`` lacks one."""
        missing = [step_id for step_id in self.step_ids if step_id not in self.steps]
        if missing:
            # Build a new dict instead of mutating the (possibly caller-supplied) one.
            self.steps = {**self.steps, **{step_id: ModelStepUIState() for step_id in missing}}
        return super().model_post_init(__context)

    def get_step_position(self, step_id: str) -> int | None:
        """Get the position of a step in the pipeline, or ``None`` if it is absent."""
        try:
            return self.step_ids.index(step_id)
        except ValueError:
            return None

    @property
    def n_steps(self) -> int:
        """Get the number of steps in the pipeline."""
        return len(self.step_ids)

    @classmethod
    def from_workflow(cls, workflow: Workflow):
        """Create a pipeline UI state from a workflow.

        Steps are ordered as in ``workflow.steps``, each with a default UI state.
        """
        step_ids = list(workflow.steps)
        return cls(
            step_ids=step_ids,
            steps={step_id: ModelStepUIState() for step_id in step_ids},
        )


class PipelineState(BaseModel):
    """Represents the state for a pipeline component.

    Pairs the ``workflow`` being edited with the ``ui_state`` describing how its
    steps are displayed. Mutating methods modify this object in place and return
    ``self``; ``ui_state`` is replaced with a new object on each structural change.
    """

    workflow: Workflow
    ui_state: PipelineUIState

    def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
        """Add ``step`` to the workflow and place it at ``position`` in the display order.

        Args:
            position: Index in ``ui_state.step_ids`` to insert at, or ``-1`` to append.
            step: The step to add; its ``id`` must not already be in the workflow.

        Raises:
            ValueError: If the step id already exists, or ``position`` is neither
                ``-1`` nor within ``[0, n_steps]``.
        """
        if step.id in self.workflow.steps:
            raise ValueError(f"Step {step.id} already exists in pipeline")

        # Validate position
        if position != -1 and (position < 0 or position > self.n_steps):
            raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")

        self.workflow.steps[step.id] = step

        step_ids = list(self.ui_state.step_ids)
        if position == -1:
            step_ids.append(step.id)
        else:
            step_ids.insert(position, step.id)
        steps = {**self.ui_state.steps, step.id: ModelStepUIState()}
        # model_copy() is shallow: give the new UI state fresh containers so the
        # previous ui_state object is not mutated along with it.
        self.ui_state = self.ui_state.model_copy(update={"step_ids": step_ids, "steps": steps})
        return self

    def remove_step(self, position: int) -> "PipelineState":
        """Remove the step displayed at ``position`` and drop output mappings it fed.

        Raises:
            IndexError: If ``position`` is out of range for ``ui_state.step_ids``.
            KeyError: If the step id at ``position`` is not in ``workflow.steps``.
        """
        # Work on copies first so a failure below leaves the UI state untouched.
        step_ids = list(self.ui_state.step_ids)
        step_id = step_ids.pop(position)
        self.workflow.steps.pop(step_id)
        steps = {sid: ui for sid, ui in self.ui_state.steps.items() if sid != step_id}
        self.ui_state = self.ui_state.model_copy(update={"step_ids": step_ids, "steps": steps})
        self.update_output_variables_mapping()
        return self

    def update_output_variables_mapping(self) -> "PipelineState":
        """Reset to ``None`` every workflow output whose source variable is no longer available."""
        available_variables = set(self.available_variables)
        for output_field in self.workflow.outputs:
            if self.workflow.outputs[output_field] not in available_variables:
                self.workflow.outputs[output_field] = None
        return self

    @property
    def available_variables(self) -> list[str]:
        """Variables reported by ``workflow.get_available_variables()``."""
        return self.workflow.get_available_variables()

    @property
    def n_steps(self) -> int:
        """Number of steps in the workflow."""
        return len(self.workflow.steps)

    def get_new_step_id(self) -> str:
        """Get an unused step ID for a new step, following the highest existing one."""
        if not self.workflow.steps:
            return "A"
        next_number = max(map(make_step_number, self.workflow.steps.keys())) + 1
        candidate = make_step_id(next_number)
        # Guard against ids that don't round-trip through make_step_number, so
        # insert_step is never handed a duplicate id.
        while candidate in self.workflow.steps:
            next_number += 1
            candidate = make_step_id(next_number)
        return candidate


class PipelineStateManager:
    """Manages a pipeline of multiple steps.

    The manager defines no pipeline data of its own; operations take the relevant
    state object(s) as arguments and return the updated values.
    """

    def get_formatted_config(self, state: PipelineState, format: Literal["json", "yaml"] = "yaml"):
        """Get the full pipeline configuration as a string.

        Fields equal to their defaults are omitted, except that for a
        ``TossupWorkflow`` the complete buzzer configuration is always included.
        Returns YAML when ``format == "yaml"``, JSON otherwise.
        """
        config = state.workflow.model_dump(exclude_defaults=True)
        if isinstance(state.workflow, TossupWorkflow):
            buzzer_config = state.workflow.buzzer.model_dump(exclude_defaults=False)
            config["buzzer"] = buzzer_config
        if format == "yaml":
            return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
        else:
            return json.dumps(config, indent=4, sort_keys=False)

    def count_state(self, state: PipelineState | None = None):
        """Return a ``gr.State`` holding the number of steps in ``state``.

        Without ``state`` this falls back to ``len(self.steps)`` (the original
        behavior). This class defines no ``steps`` attribute, so that path only
        works in a subclass that provides one.
        """
        if state is None:
            return gr.State(len(self.steps))
        return gr.State(state.n_steps)

    def add_step(self, state: PipelineState, position: int = -1, name=""):
        """Create a new step and insert it at ``position`` (``-1`` appends).

        Returns:
            ``(state, state.ui_state, state.available_variables)``.
        """
        step_id = state.get_new_step_id()
        step_name = name or f"Step {state.n_steps + 1}"
        new_step = create_new_llm_step(step_id=step_id, name=step_name)
        state = state.insert_step(position, new_step)
        return state, state.ui_state, state.available_variables

    def remove_step(self, state: PipelineState, position: int):
        """Remove the step at ``position`` from the pipeline.

        Raises:
            ValueError: If ``position`` is not in ``[0, n_steps)``.
        """
        if 0 <= position < state.n_steps:
            state = state.remove_step(position)
        else:
            raise ValueError(f"Invalid step position: {position}")
        return state, state.ui_state, state.available_variables

    def move_up(self, ui_state: PipelineUIState, position: int):
        """Move a step up in the pipeline; returns a copy of the updated UI state."""
        utils.move_item(ui_state.step_ids, position, "up")
        return ui_state.model_copy()

    def move_down(self, ui_state: PipelineUIState, position: int):
        """Move a step down in the pipeline; returns a copy of the updated UI state."""
        utils.move_item(ui_state.step_ids, position, "down")
        return ui_state.model_copy()

    def update_model_step_state(self, state: PipelineState, model_step: ModelStep, ui_state: ModelStepUIState):
        """Replace a step and its UI state, then drop output mappings that became invalid."""
        state.workflow.steps[model_step.id] = model_step.model_copy()
        state.ui_state.steps[model_step.id] = ui_state.model_copy()
        state.ui_state = state.ui_state.model_copy()
        state.update_output_variables_mapping()
        return state, state.ui_state, state.available_variables

    def update_output_variables(self, state: PipelineState, target: str, produced_variable: str):
        """Map workflow output ``target`` to ``produced_variable``.

        The dropdown placeholder "Choose variable..." clears the mapping (sets it to None).
        """
        if produced_variable == "Choose variable...":
            produced_variable = None
        state.workflow.outputs.update({target: produced_variable})
        return state

    def update_model_step_ui(self, state: PipelineState, step_ui: ModelStepUIState, step_id: str):
        """Update the UI state of the step ``step_id``."""
        state.ui_state.steps[step_id] = step_ui.model_copy()
        return state, state.ui_state

    def get_all_variables(self, state: PipelineState, model_step_id: str | None = None) -> list[str]:
        """Get all available variables, excluding those prefixed by ``model_step_id`` if given."""
        available_variables = state.available_variables
        if model_step_id is None:
            return available_variables
        else:
            prefix = f"{model_step_id}."
            return [var for var in available_variables if not var.startswith(prefix)]

    def get_pipeline_config(self, state: PipelineState | None = None):
        """Get the full pipeline configuration (the workflow) of ``state``.

        Without ``state`` this falls back to ``self.workflow`` (the original
        behavior). This class defines no such attribute, so that path only works
        in a subclass that provides one.
        """
        if state is None:
            return self.workflow
        return state.workflow