from typing import Any, Literal, Union

import gradio as gr

from components.model_pipeline.state_manager import ModelStepUIState
from components.utils import DIRECTIONS, move_item
from utils import get_model_and_provider
from workflows.structs import FieldType, ModelStep


class ModelStepStateManager:
    """State-transition handlers for a single model step's Gradio UI.

    Each handler takes the current ``ModelStep`` (or ``ModelStepUIState``) and
    returns the updated state, plus any Gradio component updates needed to
    keep the field rows in sync with it.
    """

    def __init__(self, max_input_fields: int, max_output_fields: int):
        # Number of field rows emitted per kind. Every row-update helper below
        # produces exactly this many rows, so the output length is fixed
        # regardless of how many fields the step currently has (presumably one
        # per pre-built row of components in the UI -- confirm with the caller).
        self.max_fields = {
            "input": max_input_fields,
            "output": max_output_fields,
        }

    # UI state update functions
    def update_ui_state(self, ui_state: ModelStepUIState, key: str, value: Any) -> ModelStepUIState:
        """Return ``ui_state`` with ``key`` set to ``value`` (via ``ModelStepUIState.update``)."""
        return ui_state.update(key, value)

    # Property update functions
    def update_step_name(self, model_step: ModelStep, value: str) -> ModelStep:
        """Return the model step with its ``name`` property set to ``value``."""
        return model_step.update_property("name", value)

    def update_temperature(self, model_step: ModelStep, value: float) -> ModelStep:
        """Return the model step with its ``temperature`` property set to ``value``."""
        return model_step.update_property("temperature", value)

    def update_model_and_provider(self, model_step: ModelStep, value: str) -> ModelStep:
        """Split ``value`` into model and provider and store both on the step."""
        model, provider = get_model_and_provider(value)
        return model_step.update({"model": model, "provider": provider})

    def update_system_prompt(self, model_step: ModelStep, value: str) -> ModelStep:
        """Return the model step with its ``system_prompt`` property set to ``value``."""
        return model_step.update_property("system_prompt", value)

    # Field update functions
    def update_input_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``name`` of the input field at ``index``."""
        return model_step.update_field("input", index, "name", value)

    def update_input_field_variable(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``variable`` of the input field at ``index``."""
        return model_step.update_field("input", index, "variable", value)

    def update_input_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``description`` of the input field at ``index``."""
        return model_step.update_field("input", index, "description", value)

    def update_output_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``name`` of the output field at ``index``."""
        return model_step.update_field("output", index, "name", value)

    def update_output_field_type(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``type`` of the output field at ``index``."""
        return model_step.update_field("output", index, "type", value)

    def update_output_field_variable(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``variable`` of the output field at ``index``."""
        return model_step.update_field("output", index, "variable", value)

    def update_output_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``description`` of the output field at ``index``."""
        return model_step.update_field("output", index, "description", value)

    # Row update helpers
    @staticmethod
    def _field_value_updates(fields: list, max_rows: int, attrs: tuple[str, ...]) -> list[gr.State | dict[str, Any]]:
        """Build one update per (row, attribute) pair, row-major.

        Rows backed by an existing field get ``gr.update(value=...)`` for each
        attribute in ``attrs``; rows past the end of ``fields`` get
        ``gr.skip()`` so those components are left unchanged.
        """
        updates: list[gr.State | dict[str, Any]] = []
        for i in range(max_rows):
            if i < len(fields):
                updates.extend(gr.update(value=getattr(fields[i], attr)) for attr in attrs)
            else:
                updates.extend(gr.skip() for _ in attrs)
        return updates

    def _row_visibility_updates(self, field_type: FieldType, count: int) -> list[dict[str, Any]]:
        """Show the first ``count`` rows of ``field_type`` and hide the rest."""
        return [gr.update(visible=i < count) for i in range(self.max_fields[field_type])]

    def make_input_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Value updates for every input row: (name, variable, description) per row."""
        return self._field_value_updates(
            model_step.input_fields, self.max_fields["input"], ("name", "variable", "description")
        )

    def make_output_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Value updates for every output row: (name, type, description) per row."""
        return self._field_value_updates(
            model_step.output_fields, self.max_fields["output"], ("name", "type", "description")
        )

    def _add_field(
        self, model_step: ModelStep, field_type: FieldType, index: int = -1, input_var: str | None = None
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Add a field and return ``(new_step, field_count, *row_visibility_updates)``."""
        new_step = model_step.add_field(field_type, index, input_var)
        fields = new_step.fields(field_type)
        return new_step, len(fields), *self._row_visibility_updates(field_type, len(fields))

    def _delete_field(
        self, model_step: ModelStep, field_type: FieldType, index: int
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Delete a field and return ``(new_step, field_count, *row_visibility_updates)``."""
        new_step = model_step.delete_field(field_type, index)
        fields = new_step.fields(field_type)
        return new_step, len(fields), *self._row_visibility_updates(field_type, len(fields))

    # Field add/delete functions
    # Row values are rebuilt from the *updated* step (first element returned by
    # _add_field/_delete_field); using the incoming step would leave rows
    # showing the pre-change values after an insert or delete shifts them.
    def add_input_field(self, model_step: ModelStep, index: int = -1):
        """Add an input field (bound to ``question_text``) and refresh all input rows."""
        updates = self._add_field(model_step, "input", index, input_var="question_text")
        return *updates, *self.make_input_field_updates(updates[0])

    def add_output_field(self, model_step: ModelStep, index: int = -1):
        """Add an output field and refresh all output rows."""
        updates = self._add_field(model_step, "output", index)
        return *updates, *self.make_output_field_updates(updates[0])

    def delete_input_field(self, model_step: ModelStep, index: int):
        """Delete the input field at ``index`` and refresh all input rows."""
        updates = self._delete_field(model_step, "input", index)
        return *updates, *self.make_input_field_updates(updates[0])

    def delete_output_field(self, model_step: ModelStep, index: int):
        """Delete the output field at ``index`` and refresh all output rows."""
        updates = self._delete_field(model_step, "output", index)
        return *updates, *self.make_output_field_updates(updates[0])

    def move_output_field(
        self, model_step: ModelStep, index: int, direction: DIRECTIONS
    ) -> tuple[Union[ModelStep, dict[str, Any]], ...]:
        """
        Move an output field in the list either up or down.

        Args:
            index: Index of the output field to move
            direction: Direction to move the field ('up' or 'down')

        Returns:
            tuple: ``(updated_state, *field_value_updates)``
        """
        # Deep copy: a shallow model_copy() shares the output_fields list with
        # the caller's step, so the in-place move below would also reorder the
        # original state.
        new_step = model_step.model_copy(deep=True)
        move_item(new_step.output_fields, index, direction)
        # Update all output fields to reflect the new order
        return new_step, *self.make_output_field_updates(new_step)