File size: 7,519 Bytes
193db9d e1ce295 193db9d e1ce295 193db9d e1ce295 193db9d e1ce295 193db9d e1ce295 193db9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from typing import Any, Literal, Union
import gradio as gr
from loguru import logger
from app_configs import UNSELECTED_VAR_NAME
from components.model_pipeline.state_manager import ModelStepUIState
from components.utils import DIRECTIONS, move_item
from utils import get_model_and_provider
from workflows.structs import FieldType, ModelStep
class ModelStepStateManager:
    """Event-handler helpers that edit a ``ModelStep`` and build Gradio updates.

    Each ``update_*`` handler delegates to a method on the object it receives
    and returns that method's result. The only state held by the manager is
    ``max_fields``, the number of UI rows available for each field type.
    """

    # Attributes shown per row, in the order the update lists are flattened.
    _INPUT_ROW_ATTRS = ("name", "variable", "description")
    _OUTPUT_ROW_ATTRS = ("name", "type", "description")

    def __init__(self, max_input_fields: int, max_output_fields: int):
        """Record how many input and output field rows the UI provides.

        Args:
            max_input_fields: Number of input field rows.
            max_output_fields: Number of output field rows.
        """
        self.max_fields = {
            "input": max_input_fields,
            "output": max_output_fields,
        }

    # UI state update functions
    def update_ui_state(self, ui_state: ModelStepUIState, key: str, value: Any) -> ModelStepUIState:
        """Return ``ui_state.update(key, value)``."""
        return ui_state.update(key, value)

    # Property update functions
    def update_step_name(self, model_step: ModelStep, value: str) -> ModelStep:
        """Set the step's ``name`` property to ``value``."""
        return model_step.update_property("name", value)

    def update_temperature(self, model_step: ModelStep, value: float) -> ModelStep:
        """Set the step's ``temperature`` property to ``value``."""
        return model_step.update_property("temperature", value)

    def update_model_and_provider(self, model_step: ModelStep, value: str) -> ModelStep:
        """Set the step's ``model`` and ``provider`` from a combined selection.

        ``value`` is split into the two parts by ``get_model_and_provider``.
        """
        model, provider = get_model_and_provider(value)
        return model_step.update({"model": model, "provider": provider})

    def update_system_prompt(self, model_step: ModelStep, value: str) -> ModelStep:
        """Set the step's ``system_prompt`` property to ``value``."""
        return model_step.update_property("system_prompt", value)

    # Field update functions
    def update_input_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``name`` of the input field at ``index``."""
        return model_step.update_field("input", index, "name", value)

    def update_input_field_variable(self, model_step: ModelStep, value: str, name: str, index: int) -> ModelStep:
        """Set the ``variable`` of the input field at ``index``.

        Args:
            model_step: Step whose input field is updated.
            value: Selected variable. ``UNSELECTED_VAR_NAME`` clears the
                variable (it is stored as an empty string).
            name: Current name of the field. When it is empty, the field name
                is also set to the part of ``value`` after its first ``"."``
                (or all of ``value`` if it contains no ``"."``).
            index: Position of the input field.

        Returns:
            The result of the final ``update_field`` call.
        """
        if value == UNSELECTED_VAR_NAME:
            return model_step.update_field("input", index, "variable", "")
        if name == "":
            suggested_name = value.split(".", 1)[-1]
            logger.info(f"Updating input field variable to {value}. Suggested name: {suggested_name}")
            model_step = model_step.update_field("input", index, "name", suggested_name)
        return model_step.update_field("input", index, "variable", value)

    def update_input_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``description`` of the input field at ``index``."""
        return model_step.update_field("input", index, "description", value)

    def update_output_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``name`` of the output field at ``index``."""
        return model_step.update_field("output", index, "name", value)

    def update_output_field_type(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``type`` of the output field at ``index``."""
        return model_step.update_field("output", index, "type", value)

    def update_output_field_variable(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``variable`` of the output field at ``index``."""
        return model_step.update_field("output", index, "variable", value)

    def update_output_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``description`` of the output field at ``index``."""
        return model_step.update_field("output", index, "description", value)

    def _make_field_updates(self, fields, field_type: FieldType, attrs: tuple[str, ...]) -> list[Any]:
        """Build one update per attribute for each of the ``max_fields`` rows.

        Rows backed by an entry in ``fields`` get ``gr.update(value=...)`` for
        each attribute in ``attrs``; the remaining rows get ``gr.skip()``.
        """
        updates: list[Any] = []
        for row in range(self.max_fields[field_type]):
            if row < len(fields):
                field = fields[row]
                updates.extend(gr.update(value=getattr(field, attr)) for attr in attrs)
            else:
                updates.extend(gr.skip() for _ in attrs)
        return updates

    def make_input_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Return name/variable/description updates for every input row.

        The list holds three entries per row, for ``max_fields["input"]`` rows;
        rows beyond ``model_step.input_fields`` are filled with ``gr.skip()``.
        """
        return self._make_field_updates(model_step.input_fields, "input", self._INPUT_ROW_ATTRS)

    def make_output_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Return name/type/description updates for every output row.

        The list holds three entries per row, for ``max_fields["output"]``
        rows; rows beyond ``model_step.output_fields`` are filled with
        ``gr.skip()``.
        """
        return self._make_field_updates(model_step.output_fields, "output", self._OUTPUT_ROW_ATTRS)

    def _row_visibility_updates(self, field_type: FieldType, n_fields: int) -> list[dict[str, Any]]:
        """Return one visibility update per row: visible iff the row index is below ``n_fields``."""
        return [gr.update(visible=row < n_fields) for row in range(self.max_fields[field_type])]

    def _add_field(
        self, model_step: ModelStep, field_type: FieldType, index: int = -1, input_var: str | None = None
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Add a field and return ``(new_step, field_count, *row_visibility_updates)``.

        ``index`` and ``input_var`` are forwarded unchanged to
        ``model_step.add_field``.
        """
        new_step = model_step.add_field(field_type, index, input_var)
        fields = new_step.fields(field_type)
        row_updates = self._row_visibility_updates(field_type, len(fields))
        return new_step, len(fields), *row_updates

    def _delete_field(
        self, model_step: ModelStep, field_type: FieldType, index: int
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Delete a field and return ``(new_step, field_count, *row_visibility_updates)``."""
        new_step = model_step.delete_field(field_type, index)
        fields = new_step.fields(field_type)
        row_updates = self._row_visibility_updates(field_type, len(fields))
        return new_step, len(fields), *row_updates

    # Field add/delete functions
    def add_input_field(self, model_step: ModelStep, index: int = -1):
        """Add an input field whose variable is ``UNSELECTED_VAR_NAME``.

        Returns:
            ``(new_step, field_count, *row_visibility_updates, *value_updates)``.
        """
        new_step, *rest = self._add_field(model_step, "input", index, input_var=UNSELECTED_VAR_NAME)
        # Value updates must describe the step *after* the add; building them
        # from the incoming step would leave the rows showing stale values.
        return new_step, *rest, *self.make_input_field_updates(new_step)

    def add_output_field(self, model_step: ModelStep, index: int = -1):
        """Add an output field.

        Returns:
            ``(new_step, field_count, *row_visibility_updates, *value_updates)``.
        """
        new_step, *rest = self._add_field(model_step, "output", index)
        # Build value updates from the post-add step so rows match the new state.
        return new_step, *rest, *self.make_output_field_updates(new_step)

    def delete_input_field(self, model_step: ModelStep, index: int):
        """Delete the input field at ``index``.

        Returns:
            ``(new_step, field_count, *row_visibility_updates, *value_updates)``.
        """
        new_step, *rest = self._delete_field(model_step, "input", index)
        # Build value updates from the post-delete step so the remaining
        # fields are shown in their shifted rows.
        return new_step, *rest, *self.make_input_field_updates(new_step)

    def delete_output_field(self, model_step: ModelStep, index: int):
        """Delete the output field at ``index``.

        Returns:
            ``(new_step, field_count, *row_visibility_updates, *value_updates)``.
        """
        new_step, *rest = self._delete_field(model_step, "output", index)
        # Build value updates from the post-delete step so the remaining
        # fields are shown in their shifted rows.
        return new_step, *rest, *self.make_output_field_updates(new_step)

    def move_output_field(
        self, model_step: ModelStep, index: int, direction: DIRECTIONS
    ) -> tuple[Any, ...]:
        """
        Move an output field in the list either up or down.

        Args:
            model_step: Step whose output fields are reordered.
            index: Index of the output field to move
            direction: Direction to move the field ('up' or 'down')

        Returns:
            tuple: ``(updated_state, *field_value_updates)``
        """
        # Deep copy: move_item reorders the list in place, and a shallow copy
        # would share ``output_fields`` with the caller's step.
        # NOTE(review): assumes ModelStep.model_copy accepts ``deep=True``
        # (pydantic v2 BaseModel API) - confirm.
        new_step = model_step.model_copy(deep=True)
        move_item(new_step.output_fields, index, direction)

        # Update all output fields to reflect the new order
        updates = self.make_output_field_updates(new_step)
        return new_step, *updates
|