import json
from typing import Any, Literal

import gradio as gr
import yaml
from loguru import logger
from pydantic import BaseModel, Field

from components import utils
from workflows.factory import create_new_llm_step
from workflows.structs import ModelStep, Workflow


def make_step_id(step_number: int) -> str:
    """Make a step id from a zero-based step number.

    Ids are bijective base-26 letter strings, like spreadsheet column names:
    0 -> "A", 25 -> "Z", 26 -> "AA", 27 -> "AB", 701 -> "ZZ", 702 -> "AAA".
    `make_step_number` is the exact inverse of this function.

    Args:
        step_number: Non-negative, zero-based step index.

    Returns:
        The letter id for that index.

    Raises:
        ValueError: If `step_number` is negative.
    """
    if step_number < 0:
        raise ValueError(f"Step number must be non-negative, got {step_number}")
    letters = []
    # Bijective base 26 has digits 1..26 ("A".."Z") and no zero digit, which is
    # why "Z" rolls over to "AA"; shift to 1-based and subtract 1 per digit.
    remaining = step_number + 1
    while remaining > 0:
        remaining, digit = divmod(remaining - 1, 26)
        letters.append(chr(ord("A") + digit))
    return "".join(reversed(letters))


def make_step_number(step_id: str) -> int:
    """Make a zero-based step number from a step id.

    Inverse of `make_step_id`: "A" -> 0, "Z" -> 25, "AA" -> 26, "ZZ" -> 701.
    Single characters map to `ord(char) - ord("A")`.

    Raises:
        ValueError: If `step_id` is empty.
    """
    if not step_id:
        raise ValueError("Step id must be a non-empty string")
    number = 0
    for char in step_id:
        # Each letter is a bijective base-26 digit in 1..26.
        number = number * 26 + (ord(char) - ord("A") + 1)
    return number - 1


class ModelStepUIState(BaseModel):
    """Represents the UI state for a model step component."""

    # Whether the step's panel is expanded.
    expanded: bool = True
    # Which tab of the step's panel is selected.
    active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"

    def update(self, key: str, value: Any) -> "ModelStepUIState":
        """Return a copy of this UI state with `key` set to `value`."""
        new_state = self.model_copy(update={key: value})
        return new_state


class PipelineUIState(BaseModel):
    """Represents the UI state for a pipeline component."""

    # Display order of the steps, by step id.
    step_ids: list[str] = Field(default_factory=list)
    # Per-step UI state, keyed by step id.
    steps: dict[str, ModelStepUIState] = Field(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        # When only the ordering is supplied, give every listed step a default UI state.
        if not self.steps and self.step_ids:
            self.steps = {step_id: ModelStepUIState() for step_id in self.step_ids}
        return super().model_post_init(__context)

    def get_step_position(self, step_id: str):
        """Return the index of `step_id` in the display order, or None if absent."""
        return next((i for i, step in enumerate(self.step_ids) if step == step_id), None)

    @property
    def n_steps(self) -> int:
        """Get the number of steps in the pipeline."""
        return len(self.step_ids)

    @classmethod
    def from_workflow(cls, workflow: Workflow):
        """Create a pipeline UI state with one default step UI state per workflow step,
        ordered as the workflow's `steps` mapping is iterated."""
        return cls(
            step_ids=list(workflow.steps.keys()),
            steps={step_id: ModelStepUIState() for step_id in workflow.steps.keys()},
        )


class PipelineState(BaseModel):
    """Represents the state for a pipeline component: the workflow plus its UI state."""

    workflow: Workflow
    ui_state: PipelineUIState

    def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
        """Add `step` to the workflow and to the UI ordering.

        Args:
            position: Index in the UI ordering to insert at, or -1 to append.
            step: The step to add; its id must not already be in the workflow.

        Returns:
            This state (mutated in place), for chaining.

        Raises:
            ValueError: If the step id already exists or `position` is out of range.
        """
        if step.id in self.workflow.steps:
            raise ValueError(f"Step {step.id} already exists in pipeline")
        # Validate position
        if position != -1 and (position < 0 or position > self.n_steps):
            raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")
        self.workflow.steps[step.id] = step
        # Build fresh containers: a shallow model_copy() alone would share, and
        # then mutate, the previous UI state's list and dict.
        step_ids = list(self.ui_state.step_ids)
        if position == -1:
            step_ids.append(step.id)
        else:
            step_ids.insert(position, step.id)
        steps = {**self.ui_state.steps, step.id: ModelStepUIState()}
        self.ui_state = self.ui_state.model_copy(update={"step_ids": step_ids, "steps": steps})
        return self

    def remove_step(self, position: int) -> "PipelineState":
        """Remove the step at `position` in the UI ordering from the workflow and UI state.

        Afterwards, any workflow output mapped to a variable that is no longer
        available is reset to None.

        Returns:
            This state (mutated in place), for chaining.

        Raises:
            IndexError: If `position` is out of range for the UI ordering.
            KeyError: If the step id at `position` is not in the workflow.
        """
        # Work on copies so a failure below leaves the current UI state untouched
        # and the previous UI state object is never mutated.
        step_ids = list(self.ui_state.step_ids)
        step_id = step_ids.pop(position)
        self.workflow.steps.pop(step_id)
        steps = dict(self.ui_state.steps)
        steps.pop(step_id, None)
        self.ui_state = self.ui_state.model_copy(update={"step_ids": step_ids, "steps": steps})
        self.update_output_variables_mapping()
        return self

    def update_output_variables_mapping(self) -> "PipelineState":
        """Reset to None every workflow output whose mapped variable is no longer available."""
        available_variables = set(self.available_variables)
        for output_field in self.workflow.outputs:
            # Reassigning existing keys while iterating is safe: the dict's size does not change.
            if self.workflow.outputs[output_field] not in available_variables:
                self.workflow.outputs[output_field] = None
        return self

    @property
    def available_variables(self) -> list[str]:
        """Variables reported by `Workflow.get_available_variables()`."""
        return self.workflow.get_available_variables()

    @property
    def n_steps(self) -> int:
        """Number of steps in the workflow."""
        return len(self.workflow.steps)

    def get_new_step_id(self) -> str:
        """Get an unused step ID for a new step.

        The ID is the letter id following the highest existing one (see
        `make_step_id`). It is guaranteed not to collide with an existing key.
        """
        if not self.workflow.steps:
            return "A"
        last_step_number = max(map(make_step_number, self.workflow.steps.keys()))
        # Ids that are not uppercase letters can map to negative numbers; clamp at "A".
        next_number = max(last_step_number + 1, 0)
        step_id = make_step_id(next_number)
        while step_id in self.workflow.steps:
            next_number += 1
            step_id = make_step_id(next_number)
        return step_id


class PipelineStateManager:
    """Manages a pipeline of multiple steps."""

    def get_formatted_config(self, state: PipelineState, format: Literal["json", "yaml"] = "yaml"):
        """Serialize the workflow (fields at their defaults excluded).

        Returns YAML when `format == "yaml"`, otherwise JSON, both with 4-space indentation.
        """
        config = state.workflow.model_dump(exclude_defaults=True)
        if format == "yaml":
            return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
        else:
            return json.dumps(config, indent=4, sort_keys=False)

    def count_state(self):
        # NOTE(review): this class defines no `steps` attribute, so this raises
        # AttributeError unless a subclass or caller provides one — confirm intent.
        return gr.State(len(self.steps))

    def add_step(self, state: PipelineState, position: int = -1, name=""):
        """Create a new LLM step and insert it at `position` (-1 appends).

        Returns:
            (state, state.ui_state, state.available_variables)
        """
        step_id = state.get_new_step_id()
        step_name = name or f"Step {state.n_steps + 1}"
        new_step = create_new_llm_step(step_id=step_id, name=step_name)
        state = state.insert_step(position, new_step)
        return state, state.ui_state, state.available_variables

    def remove_step(self, state: PipelineState, position: int):
        """Remove the step at `position` from the pipeline.

        Returns:
            (state, state.ui_state, state.available_variables)

        Raises:
            ValueError: If `position` is not in `[0, state.n_steps)`.
        """
        if 0 <= position < state.n_steps:
            state = state.remove_step(position)
        else:
            raise ValueError(f"Invalid step position: {position}")
        return state, state.ui_state, state.available_variables

    def move_up(self, ui_state: PipelineUIState, position: int):
        """Move a step up in the pipeline via `utils.move_item`; returns a shallow copy of `ui_state`."""
        utils.move_item(ui_state.step_ids, position, "up")
        return ui_state.model_copy()

    def move_down(self, ui_state: PipelineUIState, position: int):
        """Move a step down in the pipeline via `utils.move_item`; returns a shallow copy of `ui_state`."""
        utils.move_item(ui_state.step_ids, position, "down")
        return ui_state.model_copy()

    def update_model_step_state(self, state: PipelineState, model_step: ModelStep, ui_state: ModelStepUIState):
        """Store copies of `model_step` and its UI state, then drop output mappings that became invalid.

        Returns:
            (state, state.ui_state, state.available_variables)
        """
        state.workflow.steps[model_step.id] = model_step.model_copy()
        state.ui_state.steps[model_step.id] = ui_state.model_copy()
        state.ui_state = state.ui_state.model_copy()
        state.update_output_variables_mapping()
        return state, state.ui_state, state.available_variables

    def update_output_variables(self, state: PipelineState, target: str, produced_variable: str):
        """Map workflow output `target` to `produced_variable`.

        The sentinel string "Choose variable..." is treated as "no variable" and stored as None.
        """
        if produced_variable == "Choose variable...":
            produced_variable = None
        state.workflow.outputs.update({target: produced_variable})
        return state

    def update_model_step_ui(self, state: PipelineState, step_ui: ModelStepUIState, step_id: str):
        """Store a copy of `step_ui` as the UI state of `step_id`.

        Returns:
            (state, state.ui_state)
        """
        state.ui_state.steps[step_id] = step_ui.model_copy()
        return state, state.ui_state

    def get_all_variables(self, state: PipelineState, model_step_id: str | None = None) -> list[str]:
        """Get the pipeline's available variables.

        When `model_step_id` is given, variables prefixed with "<model_step_id>."
        are excluded (presumably that step's own outputs — confirm with callers).
        """
        available_variables = state.available_variables
        if model_step_id is None:
            return available_variables
        else:
            prefix = f"{model_step_id}."
            return [var for var in available_variables if not var.startswith(prefix)]

    def get_pipeline_config(self):
        """Get the full pipeline configuration."""
        # NOTE(review): this class defines no `workflow` attribute, so this raises
        # AttributeError unless a subclass or caller provides one — confirm intent.
        return self.workflow