File size: 7,519 Bytes
193db9d
 
 
e1ce295
193db9d
e1ce295
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1ce295
193db9d
e1ce295
 
 
 
 
 
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1ce295
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from typing import Any, Literal, Union

import gradio as gr
from loguru import logger

from app_configs import UNSELECTED_VAR_NAME
from components.model_pipeline.state_manager import ModelStepUIState
from components.utils import DIRECTIONS, move_item
from utils import get_model_and_provider
from workflows.structs import FieldType, ModelStep


class ModelStepStateManager:
    """Event-handler helpers that translate Gradio UI events into ``ModelStep`` updates.

    The UI has a fixed maximum number of input/output field rows
    (``max_input_fields`` / ``max_output_fields``). Handlers that change the
    number or order of fields therefore return a fixed-length sequence of
    updates: one per row (visibility) and one per row component (values),
    up to that maximum.
    """

    # Field attributes mirrored into each UI row, in the order the row's
    # value updates are emitted.
    _INPUT_FIELD_ATTRS = ("name", "variable", "description")
    _OUTPUT_FIELD_ATTRS = ("name", "type", "description")

    def __init__(self, max_input_fields: int, max_output_fields: int):
        # Maximum number of UI rows per field type, keyed by FieldType.
        self.max_fields = {
            "input": max_input_fields,
            "output": max_output_fields,
        }

    # UI state update functions
    def update_ui_state(self, ui_state: ModelStepUIState, key: str, value: Any) -> ModelStepUIState:
        """Set ``key`` to ``value`` on the UI state via ``ModelStepUIState.update``."""
        return ui_state.update(key, value)

    # Property update functions
    def update_step_name(self, model_step: ModelStep, value: str) -> ModelStep:
        """Set the step's ``name`` property."""
        return model_step.update_property("name", value)

    def update_temperature(self, model_step: ModelStep, value: float) -> ModelStep:
        """Set the step's ``temperature`` property."""
        return model_step.update_property("temperature", value)

    def update_model_and_provider(self, model_step: ModelStep, value: str) -> ModelStep:
        """Parse a combined model selector value and set both ``model`` and ``provider``."""
        model, provider = get_model_and_provider(value)
        return model_step.update({"model": model, "provider": provider})

    def update_system_prompt(self, model_step: ModelStep, value: str) -> ModelStep:
        """Set the step's ``system_prompt`` property."""
        return model_step.update_property("system_prompt", value)

    # Field update functions
    def update_input_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``name`` of the input field at ``index``."""
        return model_step.update_field("input", index, "name", value)

    def update_input_field_variable(self, model_step: ModelStep, value: str, name: str, index: int) -> ModelStep:
        """Set ``variable`` of the input field at ``index``.

        If ``value`` is the ``UNSELECTED_VAR_NAME`` placeholder, the stored
        variable is cleared to ``""``. Otherwise, when the row's current
        ``name`` is empty, the name is auto-filled with the part of ``value``
        after the first ``"."`` (or the whole value if it contains no dot).
        """
        if value == UNSELECTED_VAR_NAME:
            return model_step.update_field("input", index, "variable", "")
        if name == "":
            suggested_name = value.split(".", 1)[-1]
            logger.info(f"Updating input field variable to {value}. Suggested name: {suggested_name}")
            model_step = model_step.update_field("input", index, "name", suggested_name)
        return model_step.update_field("input", index, "variable", value)

    def update_input_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``description`` of the input field at ``index``."""
        return model_step.update_field("input", index, "description", value)

    def update_output_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``name`` of the output field at ``index``."""
        return model_step.update_field("output", index, "name", value)

    def update_output_field_type(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``type`` of the output field at ``index``."""
        return model_step.update_field("output", index, "type", value)

    def update_output_field_variable(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``variable`` of the output field at ``index``."""
        return model_step.update_field("output", index, "variable", value)

    def update_output_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``description`` of the output field at ``index``."""
        return model_step.update_field("output", index, "description", value)

    def _make_field_updates(
        self, fields: list, field_type: FieldType, attrs: tuple[str, ...]
    ) -> list[gr.State | dict[str, Any]]:
        """Build value updates for every UI row of ``field_type``.

        Rows backed by an existing field get ``gr.update(value=...)`` for each
        attribute in ``attrs`` (in that order). Rows at or beyond
        ``len(fields)`` get ``gr.skip()`` per component, leaving them untouched.
        """
        updates: list[gr.State | dict[str, Any]] = []
        for i in range(self.max_fields[field_type]):
            if i < len(fields):
                field = fields[i]
                updates.extend(gr.update(value=getattr(field, attr)) for attr in attrs)
            else:
                updates.extend(gr.skip() for _ in attrs)
        return updates

    def make_input_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Return (name, variable, description) value updates for every input row."""
        return self._make_field_updates(model_step.input_fields, "input", self._INPUT_FIELD_ATTRS)

    def make_output_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Return (name, type, description) value updates for every output row."""
        return self._make_field_updates(model_step.output_fields, "output", self._OUTPUT_FIELD_ATTRS)

    def _row_visibility_updates(self, field_type: FieldType, n_fields: int) -> list[dict[str, Any]]:
        """Show the first ``n_fields`` rows of ``field_type`` and hide the rest."""
        return [gr.update(visible=i < n_fields) for i in range(self.max_fields[field_type])]

    def _add_field(
        self, model_step: ModelStep, field_type: FieldType, index: int = -1, input_var: str | None = None
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Add a field and return ``(new_step, field_count, *row_visibility_updates)``.

        ``index`` and ``input_var`` are forwarded to ``ModelStep.add_field``
        (presumably ``-1`` means append — confirm against ``ModelStep``).
        """
        new_step = model_step.add_field(field_type, index, input_var)
        n_fields = len(new_step.fields(field_type))
        return new_step, n_fields, *self._row_visibility_updates(field_type, n_fields)

    def _delete_field(
        self, model_step: ModelStep, field_type: FieldType, index: int
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Delete the field at ``index`` and return ``(new_step, field_count, *row_visibility_updates)``."""
        new_step = model_step.delete_field(field_type, index)
        n_fields = len(new_step.fields(field_type))
        return new_step, n_fields, *self._row_visibility_updates(field_type, n_fields)

    # Field add/delete functions
    # Each returns (new_step, field_count, *row_visibility_updates, *row_value_updates).
    # Row values are rendered from the NEW step so that shifted, inserted, or
    # removed rows display the data that is actually stored at each index.
    def add_input_field(self, model_step: ModelStep, index: int = -1):
        """Add an input field (variable preset to ``UNSELECTED_VAR_NAME``) and refresh all input rows."""
        updates = self._add_field(model_step, "input", index, input_var=UNSELECTED_VAR_NAME)
        new_step = updates[0]
        return *updates, *self.make_input_field_updates(new_step)

    def add_output_field(self, model_step: ModelStep, index: int = -1):
        """Add an output field and refresh all output rows."""
        updates = self._add_field(model_step, "output", index)
        new_step = updates[0]
        return *updates, *self.make_output_field_updates(new_step)

    def delete_input_field(self, model_step: ModelStep, index: int):
        """Delete the input field at ``index`` and refresh all input rows."""
        updates = self._delete_field(model_step, "input", index)
        new_step = updates[0]
        return *updates, *self.make_input_field_updates(new_step)

    def delete_output_field(self, model_step: ModelStep, index: int):
        """Delete the output field at ``index`` and refresh all output rows."""
        updates = self._delete_field(model_step, "output", index)
        new_step = updates[0]
        return *updates, *self.make_output_field_updates(new_step)

    def move_output_field(
        self, model_step: ModelStep, index: int, direction: DIRECTIONS
    ) -> tuple[ModelStep | dict[str, Any], ...]:
        """
        Move an output field in the list either up or down.

        Args:
            model_step: Current step; it is not modified.
            index: Index of the output field to move
            direction: Direction to move the field ('up' or 'down')

        Returns:
            tuple: ``(updated_step, *field_value_updates)``
        """
        # move_item reorders the list in place, so copy deeply: a shallow
        # model_copy() would share output_fields with (and thus mutate) the
        # caller's original step.
        new_step = model_step.model_copy(deep=True)
        move_item(new_step.output_fields, index, direction)

        # Update all output fields to reflect the new order
        updates = self.make_output_field_updates(new_step)

        return new_step, *updates