Maharshi Gor
BugFix: Step Creation and removal.
e00ec4e
raw
history blame
7.73 kB
import json
from typing import Any, Literal
import gradio as gr
import yaml
from loguru import logger
from pydantic import BaseModel, Field
from components import utils
from workflows.factory import create_new_llm_step
from workflows.structs import ModelStep, Workflow
def make_step_id(step_number: int) -> str:
    """Convert a zero-based step number into a letter-based step id.

    Ids use bijective base-26 lettering (spreadsheet-column style):
    0 -> "A", 25 -> "Z", 26 -> "AA", 27 -> "AB", 51 -> "AZ", 52 -> "BA",
    701 -> "ZZ", 702 -> "AAA". Numbers 0-25 give one letter, 26-701 give two
    letters, and larger numbers get as many letters as needed.

    Args:
        step_number: Zero-based index of the step; must be non-negative.

    Returns:
        The step id, made of uppercase ASCII letters.

    Raises:
        ValueError: If ``step_number`` is negative.
    """
    if step_number < 0:
        raise ValueError(f"step_number must be non-negative, got {step_number}")
    letters = []
    # Shift to 1-based so every position is a digit in 1..26 (A..Z); bijective
    # base-26 has no zero digit, which is what makes "AA" follow "Z".
    remaining = step_number + 1
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, 26)
        letters.append(chr(ord("A") + offset))
    return "".join(reversed(letters))
def make_step_number(step_id: str) -> int:
    """Convert a letter-based step id back into its zero-based step number.

    This is the inverse of ``make_step_id``:
    "A" -> 0, "Z" -> 25, "AA" -> 26, "AB" -> 27, "AZ" -> 51, "BA" -> 52,
    "ZZ" -> 701, "AAA" -> 702.

    Args:
        step_id: The step id (uppercase letters).

    Returns:
        The zero-based step number.

    Raises:
        ValueError: If ``step_id`` is empty.
    """
    if not step_id:
        raise ValueError("step_id must be a non-empty string")
    value = 0
    for char in step_id:
        # Each letter is a bijective base-26 digit in 1..26. Two-letter ids must
        # land after all 26 one-letter ids (e.g. "AA" -> 26); otherwise the
        # max-id computation for new steps would re-generate an existing id.
        value = value * 26 + (ord(char) - ord("A") + 1)
    return value - 1
class ModelStepUIState(BaseModel):
    """UI-only state of a single model-step panel.

    Records whether the panel is expanded and which of its tabs is selected.
    """

    # Whether the step panel is displayed expanded.
    expanded: bool = True
    # Which tab of the step panel is currently selected.
    active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"

    def update(self, key: str, value: Any) -> "ModelStepUIState":
        """Return a copy of this state with field ``key`` set to ``value``.

        The receiver itself is not modified; a new instance is produced via
        ``model_copy``.
        """
        return self.model_copy(update={key: value})
class PipelineUIState(BaseModel):
    """UI-only state of a pipeline: display order of steps plus per-step UI state."""

    # Step ids in the order they are displayed.
    step_ids: list[str] = Field(default_factory=list)
    # Per-step UI state, keyed by step id.
    steps: dict[str, ModelStepUIState] = Field(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        """Ensure every id in ``step_ids`` has an entry in ``steps``.

        Ids without an entry get a default ``ModelStepUIState``; entries that were
        supplied are kept as-is. This also covers a partially populated ``steps``
        mapping, not just an empty one, so later lookups by step id cannot miss.
        """
        missing = [step_id for step_id in self.step_ids if step_id not in self.steps]
        if missing:
            self.steps = {**self.steps, **{step_id: ModelStepUIState() for step_id in missing}}
        return super().model_post_init(__context)

    def get_step_position(self, step_id: str) -> int | None:
        """Return the display index of ``step_id``, or None if it is not in ``step_ids``."""
        return next((i for i, step in enumerate(self.step_ids) if step == step_id), None)

    @property
    def n_steps(self) -> int:
        """Number of steps in the display order."""
        return len(self.step_ids)

    @classmethod
    def from_workflow(cls, workflow: Workflow) -> "PipelineUIState":
        """Build a UI state with one default step entry per step in ``workflow``.

        Steps are ordered as they appear in ``workflow.steps``.
        """
        step_ids = list(workflow.steps.keys())
        return cls(
            step_ids=step_ids,
            steps={step_id: ModelStepUIState() for step_id in step_ids},
        )
class PipelineState(BaseModel):
    """Combined state of a pipeline: the workflow definition plus its UI state."""

    workflow: Workflow
    ui_state: PipelineUIState

    def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
        """Add ``step`` to the workflow and show it at display index ``position``.

        Args:
            position: Display index in ``[0, n_steps]``, or -1 to append at the end.
            step: Step to add; its ``id`` must not already be in the workflow.

        Returns:
            ``self`` (mutated in place), for chaining.

        Raises:
            ValueError: If ``step.id`` already exists or ``position`` is out of range.
        """
        if step.id in self.workflow.steps:
            raise ValueError(f"Step {step.id} already exists in pipeline")
        # Validate position before mutating anything.
        if position != -1 and (position < 0 or position > self.n_steps):
            raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")
        self.workflow.steps[step.id] = step
        # Rebind ui_state to a new object (presumably so UI listeners see a change).
        # NOTE(review): model_copy() is shallow, so step_ids/steps stay shared
        # with the previous ui_state object and are mutated for it too.
        self.ui_state = self.ui_state.model_copy()
        self.ui_state.steps[step.id] = ModelStepUIState()
        if position == -1:
            self.ui_state.step_ids.append(step.id)
        else:
            self.ui_state.step_ids.insert(position, step.id)
        return self

    def remove_step(self, position: int) -> "PipelineState":
        """Remove the step displayed at ``position`` from both the workflow and UI state.

        Afterwards, workflow outputs whose mapped variable is no longer available
        are reset to None (see ``update_output_variables_mapping``).

        Returns:
            ``self`` (mutated in place), for chaining.

        Raises:
            IndexError: If ``position`` is not a valid index into ``ui_state.step_ids``.
        """
        step_id = self.ui_state.step_ids.pop(position)
        # Tolerate a step missing from either mapping: step_ids has already been
        # mutated at this point, so raising here would leave a half-removed step.
        self.workflow.steps.pop(step_id, None)
        self.ui_state = self.ui_state.model_copy()
        self.ui_state.steps.pop(step_id, None)
        self.update_output_variables_mapping()
        return self

    def update_output_variables_mapping(self) -> "PipelineState":
        """Reset to None every workflow output mapped to a variable that is no longer available.

        Returns:
            ``self`` (mutated in place), for chaining.
        """
        available_variables = set(self.available_variables)
        for output_field in self.workflow.outputs:
            if self.workflow.outputs[output_field] not in available_variables:
                self.workflow.outputs[output_field] = None
        return self

    @property
    def available_variables(self) -> list[str]:
        """Variables reported by ``workflow.get_available_variables()``."""
        return self.workflow.get_available_variables()

    @property
    def n_steps(self) -> int:
        """Number of steps in the workflow."""
        return len(self.workflow.steps)

    def get_new_step_id(self) -> str:
        """Return the id that follows the highest existing step id ("A" when empty)."""
        if not self.workflow.steps:
            return "A"
        last_step_number = max(map(make_step_number, self.workflow.steps.keys()))
        return make_step_id(last_step_number + 1)
class PipelineStateManager:
    """Handlers that read and update a ``PipelineState``.

    Most mutating handlers return ``(state, state.ui_state, state.available_variables)``.
    """

    def get_formatted_config(self, state: PipelineState, format: Literal["json", "yaml"] = "yaml"):
        """Serialize the pipeline's workflow as YAML (default) or JSON.

        Fields equal to their defaults are omitted (``exclude_defaults=True``).
        Any ``format`` other than "yaml" produces JSON.
        """
        config = state.workflow.model_dump(exclude_defaults=True)
        if format == "yaml":
            return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
        else:
            return json.dumps(config, indent=4, sort_keys=False)

    def count_state(self, state: PipelineState | None = None):
        """Wrap the number of pipeline steps in a ``gr.State``.

        Args:
            state: Pipeline whose steps are counted. When omitted, falls back to a
                ``self.steps`` attribute, which this class does not define itself
                (it must be set externally, otherwise AttributeError is raised).
        """
        if state is not None:
            return gr.State(state.n_steps)
        return gr.State(len(self.steps))

    def add_step(self, state: PipelineState, position: int = -1, name=""):
        """Create a new LLM step and insert it at ``position`` (-1 appends).

        The step gets the next free id; its name defaults to "Step <n+1>".
        """
        step_id = state.get_new_step_id()
        step_name = name or f"Step {state.n_steps + 1}"
        new_step = create_new_llm_step(step_id=step_id, name=step_name)
        state = state.insert_step(position, new_step)
        return state, state.ui_state, state.available_variables

    def remove_step(self, state: PipelineState, position: int):
        """Remove the step at display index ``position``.

        Raises:
            ValueError: If ``position`` is outside ``[0, state.n_steps)``.
        """
        if 0 <= position < state.n_steps:
            state = state.remove_step(position)
        else:
            raise ValueError(f"Invalid step position: {position}")
        return state, state.ui_state, state.available_variables

    def move_up(self, ui_state: PipelineUIState, position: int):
        """Move the step at ``position`` up one slot and return a copy of the UI state."""
        utils.move_item(ui_state.step_ids, position, "up")
        return ui_state.model_copy()

    def move_down(self, ui_state: PipelineUIState, position: int):
        """Move the step at ``position`` down one slot and return a copy of the UI state."""
        utils.move_item(ui_state.step_ids, position, "down")
        return ui_state.model_copy()

    def update_model_step_state(self, state: PipelineState, model_step: ModelStep, ui_state: ModelStepUIState):
        """Store copies of ``model_step`` and its UI state, then drop stale output mappings."""
        state.workflow.steps[model_step.id] = model_step.model_copy()
        state.ui_state.steps[model_step.id] = ui_state.model_copy()
        state.ui_state = state.ui_state.model_copy()
        state.update_output_variables_mapping()
        return state, state.ui_state, state.available_variables

    def update_output_variables(self, state: PipelineState, target: str, produced_variable: str | None):
        """Map workflow output ``target`` to ``produced_variable``.

        The placeholder value "Choose variable..." (presumably a UI dropdown's
        empty choice) is stored as None, i.e. the output is left unmapped.
        """
        if produced_variable == "Choose variable...":
            produced_variable = None
        state.workflow.outputs.update({target: produced_variable})
        return state

    def update_model_step_ui(self, state: PipelineState, step_ui: ModelStepUIState, step_id: str):
        """Store a copy of ``step_ui`` as the UI state of step ``step_id``."""
        state.ui_state.steps[step_id] = step_ui.model_copy()
        return state, state.ui_state

    def get_all_variables(self, state: PipelineState, model_step_id: str | None = None) -> list[str]:
        """Return the available variables, excluding those of ``model_step_id`` if given.

        A variable belongs to a step when it starts with "<step_id>.".
        """
        available_variables = state.available_variables
        if model_step_id is None:
            return available_variables
        prefix = f"{model_step_id}."
        return [var for var in available_variables if not var.startswith(prefix)]

    def get_pipeline_config(self, state: PipelineState | None = None):
        """Return the pipeline's workflow object.

        Args:
            state: Pipeline whose workflow is returned. When omitted, falls back to a
                ``self.workflow`` attribute, which this class does not define itself
                (it must be set externally, otherwise AttributeError is raised).
        """
        if state is not None:
            return state.workflow
        return self.workflow