|
import json |
|
from typing import Any, Literal |
|
|
|
import gradio as gr |
|
import yaml |
|
from loguru import logger |
|
from pydantic import BaseModel, Field |
|
|
|
from components import utils |
|
from workflows.factory import create_new_llm_step |
|
from workflows.structs import ModelStep, Workflow |
|
|
|
|
|
def make_step_id(step_number: int) -> str:
    """Convert a zero-based step number into a spreadsheet-style letter id.

    Numbers are written in bijective base 26 over ``A``-``Z``:
    ``0 -> "A"``, ``25 -> "Z"``, ``26 -> "AA"``, ``27 -> "AB"``, ``52 -> "BA"``,
    ``701 -> "ZZ"``, ``702 -> "AAA"``, and so on with no upper bound.

    Args:
        step_number: Zero-based index of the step.

    Returns:
        The uppercase letter id for ``step_number``.

    Raises:
        ValueError: If ``step_number`` is negative.
    """
    if step_number < 0:
        raise ValueError(f"step_number must be non-negative, got {step_number}")
    letters = []
    # Bijective base 26 has no zero digit, so work with the 1-based value and
    # subtract one before each divmod to map the digit range onto 0..25 (A..Z).
    remaining = step_number + 1
    while remaining > 0:
        remaining, digit = divmod(remaining - 1, 26)
        letters.append(chr(ord("A") + digit))
    return "".join(reversed(letters))
|
|
|
|
|
def make_step_number(step_id: str) -> int:
    """Convert a spreadsheet-style letter id back into its zero-based step number.

    The id is read as a bijective base-26 numeral over ``A``-``Z``:
    ``"A" -> 0``, ``"Z" -> 25``, ``"AA" -> 26``, ``"AB" -> 27``, ``"BA" -> 52``,
    ``"ZZ" -> 701``, ``"AAA" -> 702``. Ids of any length are accepted.

    Args:
        step_id: Non-empty id made of uppercase ASCII letters.

    Returns:
        The zero-based step number encoded by ``step_id``.

    Raises:
        ValueError: If ``step_id`` is empty.
    """
    if not step_id:
        raise ValueError("step_id must be a non-empty string")
    value = 0
    for char in step_id:
        # Each letter is a digit in 1..26 (A=1); bijective base 26 has no zero digit.
        value = value * 26 + (ord(char) - ord("A") + 1)
    # Shift back from the 1-based numeral value to a zero-based step number.
    return value - 1
|
|
|
|
|
class ModelStepUIState(BaseModel):
    """UI-only state of a single model step component.

    Records whether the step's panel is expanded and which of its tabs is selected.
    """

    # Whether the step's panel is displayed expanded.
    expanded: bool = True
    # Tab currently selected inside the step component.
    active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"

    def update(self, key: str, value: Any) -> "ModelStepUIState":
        """Return a copy of this state with field ``key`` set to ``value``.

        The current instance is not modified. Pydantic does not re-validate values
        passed through ``model_copy(update=...)``.
        """
        return self.model_copy(update={key: value})
|
|
|
|
|
class PipelineUIState(BaseModel):
    """UI-only state of a pipeline component: step display order plus per-step UI state."""

    # Step ids in display order.
    step_ids: list[str] = Field(default_factory=list)
    # Per-step UI state, keyed by step id.
    steps: dict[str, ModelStepUIState] = Field(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        """Make sure every id in ``step_ids`` has an entry in ``steps``.

        Ids without an entry get a default ``ModelStepUIState``. Entries that are
        already present are kept as-is.
        """
        # Fill in every missing id, not only when ``steps`` is completely empty, so a
        # partially populated mapping cannot leave a step without UI state.
        missing = {
            step_id: ModelStepUIState()
            for step_id in self.step_ids
            if step_id not in self.steps
        }
        if missing:
            self.steps = {**self.steps, **missing}
        return super().model_post_init(__context)

    def get_step_position(self, step_id: str) -> int | None:
        """Return the index of ``step_id`` in ``step_ids``, or ``None`` if it is absent."""
        return next((i for i, step in enumerate(self.step_ids) if step == step_id), None)

    @property
    def n_steps(self) -> int:
        """Number of step ids tracked by this UI state."""
        return len(self.step_ids)

    @classmethod
    def from_workflow(cls, workflow: Workflow) -> "PipelineUIState":
        """Create a UI state with one default entry per workflow step.

        Steps are ordered like ``workflow.steps``. ``cls`` is used so that
        subclasses get instances of their own type.
        """
        step_ids = list(workflow.steps.keys())
        return cls(
            step_ids=step_ids,
            steps={step_id: ModelStepUIState() for step_id in step_ids},
        )
|
|
|
|
|
class PipelineState(BaseModel):
    """Combined pipeline state: the workflow definition together with its UI state.

    Mutating methods change this instance in place and return ``self``.
    """

    # Workflow definition (steps and output mapping).
    workflow: Workflow
    # UI-only state (step order and per-step UI state).
    ui_state: PipelineUIState

    def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
        """Add ``step`` to the workflow and register it in the UI state.

        Args:
            position: Index in ``ui_state.step_ids`` at which to insert
                (0..n_steps inclusive), or -1 to append at the end.
            step: Step to add. Its ``id`` must not already be in the workflow.

        Returns:
            ``self``, updated in place.

        Raises:
            ValueError: If the step id already exists or ``position`` is out of range.
        """
        if step.id in self.workflow.steps:
            raise ValueError(f"Step {step.id} already exists in pipeline")
        appending = position == -1
        if not appending and not 0 <= position <= self.n_steps:
            raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")

        self.workflow.steps[step.id] = step

        # model_copy() is shallow. ui_state gets a new identity, but the step_ids
        # list and steps dict below are still shared with the previous ui_state object.
        self.ui_state = self.ui_state.model_copy()
        self.ui_state.steps[step.id] = ModelStepUIState()
        ordered_ids = self.ui_state.step_ids
        ordered_ids.insert(len(ordered_ids) if appending else position, step.id)
        return self

    def remove_step(self, position: int) -> "PipelineState":
        """Remove the step at index ``position`` of ``ui_state.step_ids``.

        The step is dropped from the workflow and from the UI state. Output mappings
        whose variable is no longer available are then reset to ``None``.
        Returns ``self``.
        """
        removed_id = self.ui_state.step_ids.pop(position)
        self.workflow.steps.pop(removed_id)
        # Shallow copy: new identity, containers still shared with the old ui_state.
        self.ui_state = self.ui_state.model_copy()
        self.ui_state.steps.pop(removed_id)
        self.update_output_variables_mapping()
        return self

    def update_output_variables_mapping(self) -> "PipelineState":
        """Reset to ``None`` every workflow output mapped to an unavailable variable.

        Returns ``self``.
        """
        still_available = set(self.available_variables)
        outputs = self.workflow.outputs
        for output_name, source_variable in outputs.items():
            if source_variable not in still_available:
                # Overwriting an existing key does not resize the dict, so this is safe mid-iteration.
                outputs[output_name] = None
        return self

    @property
    def available_variables(self) -> list[str]:
        """Variables the workflow reports as available (``workflow.get_available_variables()``)."""
        return self.workflow.get_available_variables()

    @property
    def n_steps(self) -> int:
        """Number of steps in the workflow."""
        return len(self.workflow.steps)

    def get_new_step_id(self) -> str:
        """Return an id for a new step.

        Returns ``"A"`` for an empty workflow. Otherwise returns the id that follows
        the largest ``make_step_number`` value among the existing step ids.
        """
        existing_steps = self.workflow.steps
        if not existing_steps:
            return "A"
        highest_number = max(make_step_number(step_id) for step_id in existing_steps)
        return make_step_id(highest_number + 1)
|
|
|
|
|
class PipelineStateManager:
    """Applies edits to pipeline state objects on behalf of the UI.

    The manager holds no pipeline data of its own. Every method receives the state
    it operates on and returns the updated objects for the caller to push back
    into the interface.
    """

    def get_formatted_config(self, state: PipelineState, format: Literal["json", "yaml"] = "yaml"):
        """Serialize ``state.workflow`` (without default-valued fields) as YAML or JSON.

        Any ``format`` other than ``"yaml"`` produces JSON. Both outputs use
        4-space indentation and keep the model's field order (``sort_keys=False``).
        """
        config = state.workflow.model_dump(exclude_defaults=True)
        if format == "yaml":
            return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
        else:
            return json.dumps(config, indent=4, sort_keys=False)

    def count_state(self, state: PipelineState | None = None):
        """Return a ``gr.State`` holding the number of steps in ``state``.

        Raises:
            TypeError: If ``state`` is not provided. The manager stores no steps
                itself, so there is nothing else to count.
        """
        # The previous implementation read ``self.steps``, which this class never
        # defines, so every call failed with AttributeError.
        if state is None:
            raise TypeError("count_state requires the PipelineState whose steps should be counted")
        return gr.State(state.n_steps)

    def add_step(self, state: PipelineState, position: int = -1, name=""):
        """Create a new LLM step and insert it into ``state``.

        Args:
            state: Pipeline state to modify (updated in place).
            position: Insertion index, or -1 to append at the end.
            name: Display name. Defaults to ``"Step <k>"``, where k is one more
                than the current number of steps.

        Returns:
            ``(state, state.ui_state, state.available_variables)``.
        """
        step_id = state.get_new_step_id()
        step_name = name or f"Step {state.n_steps + 1}"
        new_step = create_new_llm_step(step_id=step_id, name=step_name)
        state = state.insert_step(position, new_step)
        return state, state.ui_state, state.available_variables

    def remove_step(self, state: PipelineState, position: int):
        """Remove the step at ``position`` (``0 <= position < state.n_steps``).

        Returns:
            ``(state, state.ui_state, state.available_variables)``.

        Raises:
            ValueError: If ``position`` is out of range.
        """
        if 0 <= position < state.n_steps:
            state = state.remove_step(position)
        else:
            raise ValueError(f"Invalid step position: {position}")
        return state, state.ui_state, state.available_variables

    def move_up(self, ui_state: PipelineUIState, position: int):
        """Move the step at ``position`` one slot up, using ``utils.move_item``.

        ``ui_state.step_ids`` is reordered in place, and a shallow copy of
        ``ui_state`` (a new object identity) is returned.
        """
        utils.move_item(ui_state.step_ids, position, "up")
        return ui_state.model_copy()

    def move_down(self, ui_state: PipelineUIState, position: int):
        """Move the step at ``position`` one slot down, using ``utils.move_item``.

        ``ui_state.step_ids`` is reordered in place, and a shallow copy of
        ``ui_state`` (a new object identity) is returned.
        """
        utils.move_item(ui_state.step_ids, position, "down")
        return ui_state.model_copy()

    def update_model_step_state(self, state: PipelineState, model_step: ModelStep, ui_state: ModelStepUIState):
        """Store copies of ``model_step`` and its UI state under ``model_step.id``.

        Afterwards, output mappings whose variables are no longer available are
        reset to ``None``.

        Returns:
            ``(state, state.ui_state, state.available_variables)``.
        """
        state.workflow.steps[model_step.id] = model_step.model_copy()
        state.ui_state.steps[model_step.id] = ui_state.model_copy()
        # Shallow copy to give the pipeline UI state a new object identity.
        state.ui_state = state.ui_state.model_copy()
        state.update_output_variables_mapping()
        return state, state.ui_state, state.available_variables

    def update_output_variables(self, state: PipelineState, target: str, produced_variable: str):
        """Map the workflow output ``target`` to ``produced_variable``.

        The dropdown placeholder ``"Choose variable..."`` is stored as ``None``
        (no mapping). Returns ``state``, updated in place.
        """
        if produced_variable == "Choose variable...":
            produced_variable = None
        state.workflow.outputs.update({target: produced_variable})
        return state

    def update_model_step_ui(self, state: PipelineState, step_ui: ModelStepUIState, step_id: str):
        """Store a copy of ``step_ui`` as the UI state of ``step_id``.

        Returns:
            ``(state, state.ui_state)``.
        """
        state.ui_state.steps[step_id] = step_ui.model_copy()
        return state, state.ui_state

    def get_all_variables(self, state: PipelineState, model_step_id: str | None = None) -> list[str]:
        """Return the pipeline's available variables.

        When ``model_step_id`` is given, variables that start with
        ``"<model_step_id>."`` are excluded.
        """
        available_variables = state.available_variables
        if model_step_id is None:
            return available_variables
        else:
            prefix = f"{model_step_id}."
            return [var for var in available_variables if not var.startswith(prefix)]

    def get_pipeline_config(self, state: PipelineState | None = None):
        """Return the workflow held by ``state``.

        Raises:
            TypeError: If ``state`` is not provided. The manager stores no workflow
                itself.
        """
        # The previous implementation returned ``self.workflow``, which this class
        # never defines, so every call failed with AttributeError.
        if state is None:
            raise TypeError("get_pipeline_config requires the PipelineState whose workflow should be returned")
        return state.workflow
|
|