Spaces:
Restarting
Restarting
from typing import Any, Literal, Union | |
import gradio as gr | |
from components.model_pipeline.state_manager import ModelStepUIState | |
from components.utils import DIRECTIONS, move_item | |
from utils import get_model_and_provider | |
from workflows.structs import FieldType, ModelStep | |
class ModelStepStateManager:
    """Callbacks that edit a ``ModelStep`` and build the matching Gradio updates.

    The manager itself only stores the maximum number of input and output
    field rows; every method receives the current step (or UI state) as an
    argument and returns the updated object, optionally followed by
    ``gr.update`` / ``gr.skip`` values for the field rows.
    """

    def __init__(self, max_input_fields: int, max_output_fields: int):
        """Remember how many input/output field rows exist.

        Args:
            max_input_fields: Number of input-field rows to emit updates for.
            max_output_fields: Number of output-field rows to emit updates for.
        """
        # Keyed by field type so the generic helpers can look the limit up.
        self.max_fields = {
            "input": max_input_fields,
            "output": max_output_fields,
        }

    # UI state update functions
    def update_ui_state(self, ui_state: ModelStepUIState, key: str, value: Any) -> ModelStepUIState:
        """Return ``ui_state.update(key, value)``."""
        return ui_state.update(key, value)

    # Property update functions
    def update_step_name(self, model_step: ModelStep, value: str) -> ModelStep:
        """Set the step's ``name`` property and return the resulting step."""
        return model_step.update_property("name", value)

    def update_temperature(self, model_step: ModelStep, value: float) -> ModelStep:
        """Set the step's ``temperature`` property and return the resulting step."""
        return model_step.update_property("temperature", value)

    def update_model_and_provider(self, model_step: ModelStep, value: str) -> ModelStep:
        """Set the step's ``model`` and ``provider`` from a combined value.

        ``value`` is resolved into a ``(model, provider)`` pair by
        ``get_model_and_provider``.
        """
        model, provider = get_model_and_provider(value)
        return model_step.update({"model": model, "provider": provider})

    def update_system_prompt(self, model_step: ModelStep, value: str) -> ModelStep:
        """Set the step's ``system_prompt`` property and return the resulting step."""
        return model_step.update_property("system_prompt", value)

    # Field update functions
    def update_input_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``name`` on the input field at ``index``."""
        return model_step.update_field("input", index, "name", value)

    def update_input_field_variable(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``variable`` on the input field at ``index``."""
        return model_step.update_field("input", index, "variable", value)

    def update_input_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``description`` on the input field at ``index``."""
        return model_step.update_field("input", index, "description", value)

    def update_output_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``name`` on the output field at ``index``."""
        return model_step.update_field("output", index, "name", value)

    def update_output_field_type(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``type`` on the output field at ``index``."""
        return model_step.update_field("output", index, "type", value)

    def update_output_field_variable(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``variable`` on the output field at ``index``."""
        return model_step.update_field("output", index, "variable", value)

    def update_output_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``description`` on the output field at ``index``."""
        return model_step.update_field("output", index, "description", value)

    # Field row refresh helpers
    def _make_field_updates(self, fields: Any, attrs: tuple[str, ...], max_rows: int) -> list[Any]:
        """Build ``len(attrs)`` updates per row for ``max_rows`` rows.

        Rows backed by an existing field get one ``gr.update(value=...)`` per
        attribute, in ``attrs`` order; rows past the end of ``fields`` get
        ``gr.skip()`` in every slot.
        """
        updates: list[Any] = []
        for row in range(max_rows):
            if row < len(fields):
                field = fields[row]
                updates.extend(gr.update(value=getattr(field, attr)) for attr in attrs)
            else:
                updates.extend(gr.skip() for _ in attrs)
        return updates

    def make_input_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Return name/variable/description updates for every input-field row."""
        return self._make_field_updates(
            model_step.input_fields,
            ("name", "variable", "description"),
            self.max_fields["input"],
        )

    def make_output_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Return name/type/description updates for every output-field row."""
        return self._make_field_updates(
            model_step.output_fields,
            ("name", "type", "description"),
            self.max_fields["output"],
        )

    def _row_visibility_updates(self, field_type: FieldType, field_count: int) -> list[dict[str, Any]]:
        """Return one ``gr.update(visible=...)`` per row: visible iff row index < ``field_count``."""
        return [gr.update(visible=row < field_count) for row in range(self.max_fields[field_type])]

    def _add_field(
        self, model_step: ModelStep, field_type: FieldType, index: int = -1, input_var: str | None = None
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Add a field via ``model_step.add_field`` and compute row visibility.

        Returns:
            ``(new_step, field_count, *row_visibility_updates)``.
        """
        new_step = model_step.add_field(field_type, index, input_var)
        field_count = len(new_step.fields(field_type))
        return new_step, field_count, *self._row_visibility_updates(field_type, field_count)

    def _delete_field(
        self, model_step: ModelStep, field_type: FieldType, index: int
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Delete a field via ``model_step.delete_field`` and compute row visibility.

        Returns:
            ``(new_step, field_count, *row_visibility_updates)``.
        """
        new_step = model_step.delete_field(field_type, index)
        field_count = len(new_step.fields(field_type))
        return new_step, field_count, *self._row_visibility_updates(field_type, field_count)

    # Field add/delete functions
    #
    # Each of these returns
    #   (new_step, field_count, *row_visibility_updates, *row_value_updates).
    # The value updates are built from the step returned by the add/delete
    # helper, not from the incoming ``model_step``, so the rows reflect the
    # state *after* the change.
    def add_input_field(self, model_step: ModelStep, index: int = -1):
        """Add an input field (with ``input_var="question_text"``) and refresh the input rows."""
        new_step, *rest = self._add_field(model_step, "input", index, input_var="question_text")
        return new_step, *rest, *self.make_input_field_updates(new_step)

    def add_output_field(self, model_step: ModelStep, index: int = -1):
        """Add an output field and refresh the output rows."""
        new_step, *rest = self._add_field(model_step, "output", index)
        return new_step, *rest, *self.make_output_field_updates(new_step)

    def delete_input_field(self, model_step: ModelStep, index: int):
        """Delete the input field at ``index`` and refresh the input rows."""
        new_step, *rest = self._delete_field(model_step, "input", index)
        return new_step, *rest, *self.make_input_field_updates(new_step)

    def delete_output_field(self, model_step: ModelStep, index: int):
        """Delete the output field at ``index`` and refresh the output rows."""
        new_step, *rest = self._delete_field(model_step, "output", index)
        return new_step, *rest, *self.make_output_field_updates(new_step)

    def move_output_field(
        self, model_step: ModelStep, index: int, direction: DIRECTIONS
    ) -> tuple[Union[ModelStep, dict[str, Any]], ...]:
        """Move an output field within the step's output-field list.

        Args:
            model_step: Step whose output fields are reordered.
            index: Index of the output field to move.
            direction: Direction passed through to ``move_item``.

        Returns:
            A tuple ``(updated_step, *field_value_updates)``.
        """
        # NOTE(review): if ``model_copy()`` is a shallow copy (as with pydantic's
        # default), ``new_step.output_fields`` is the same list object as
        # ``model_step.output_fields`` and ``move_item`` reorders both - confirm
        # whether callers rely on the incoming step being left untouched.
        new_step = model_step.model_copy()
        move_item(new_step.output_fields, index, direction)
        # Re-emit every output row so the UI reflects the new order.
        return new_step, *self.make_output_field_updates(new_step)