File size: 18,592 Bytes
193db9d
 
 
 
 
 
e1ce295
9756440
 
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1758388
193db9d
 
 
 
 
 
 
1758388
193db9d
e1ce295
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1ce295
193db9d
 
e1ce295
193db9d
 
e1ce295
193db9d
 
 
 
 
 
 
 
 
e1ce295
 
 
 
 
 
 
 
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1ce295
193db9d
 
 
 
e1ce295
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1758388
193db9d
 
 
3b39b49
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f5d1cb
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1ce295
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f5d1cb
193db9d
 
 
4f5d1cb
193db9d
 
 
 
 
 
 
9756440
193db9d
 
 
 
9756440
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
import json
from typing import Any

import gradio as gr
from gradio.components import FormComponent

from app_configs import UNSELECTED_VAR_NAME
from components.model_pipeline.state_manager import ModelStepUIState, PipelineStateManager
from components.typed_dicts import PipelineStateDict
from utils import get_full_model_name
from workflows.structs import ModelStep

from .state_manager import ModelStepStateManager
from .ui_components import InputRowButtonGroup, OutputRowButtonGroup


def _make_accordion_label(model_step: ModelStep):
    """Build the accordion header text for a pipeline step.

    The label has the shape ``"<id>: <name> (<in1>, <in2>) → (<out1>, ...)"``.
    A falsy step name (``""`` or ``None``) is displayed as ``"Untitled"``.
    """
    display_name = model_step.name or "Untitled"
    joined_inputs = ", ".join(field.name for field in model_step.input_fields)
    joined_outputs = ", ".join(field.name for field in model_step.output_fields)
    return f"{model_step.id}: {display_name} ({joined_inputs}) → ({joined_outputs})"


class ModelStepComponent(FormComponent):
    """
    A custom Gradio component representing a single Step in a pipeline.
    It contains:
      1. Model Provider & System Prompt
      2. Inputs – fields with name, description, and variable used
      3. Outputs – fields with name, description, and variable used

    Listens to events:
        - on_model_step_change
        - on_ui_change
    """

    def __init__(
        self,
        value: ModelStep | gr.State,
        ui_state: ModelStepUIState | gr.State | None = None,
        model_options: list[str] | None = None,
        input_variables: list[str] | None = None,
        max_input_fields=5,
        max_output_fields=5,
        max_temperature=5.0,
        pipeline_state_manager: PipelineStateManager | None = None,
        **kwargs,
    ):
        """Create the step component.

        Args:
            value: The step to edit, or an existing ``gr.State`` holding one.
            ui_state: Initial UI state (read via ``.expanded`` / ``.active_tab``),
                or a ``gr.State`` holding one. Defaults to ``ModelStepUIState()``.
            model_options: Model names offered in the model dropdown.
            input_variables: Variable names offered in each input row's dropdown;
                ``UNSELECTED_VAR_NAME`` is always prepended.
            max_input_fields: Number of input rows pre-rendered (unused ones hidden).
            max_output_fields: Number of output rows pre-rendered (unused ones hidden).
            max_temperature: Upper bound of the temperature slider.
            pipeline_state_manager: Used by ``refresh_variable_dropdowns`` to list
                the variables available in the pipeline.
            **kwargs: Forwarded to ``FormComponent.__init__``.
        """
        self.max_fields = {
            "input": max_input_fields,
            "output": max_output_fields,
        }
        self.max_temperature = max_temperature
        # Normalize the optional lists so later list concatenations never see None.
        self.model_options = list(model_options or [])
        self.input_variables = [UNSELECTED_VAR_NAME] + list(input_variables or [])
        self.sm = ModelStepStateManager(max_input_fields, max_output_fields)
        self.pipeline_sm: PipelineStateManager | None = pipeline_state_manager

        # Accept a pre-built gr.State as the signature advertises, instead of
        # wrapping it inside another State.
        self.model_step_state = value if isinstance(value, gr.State) else gr.State(value)
        if ui_state is None:
            ui_state = ModelStepUIState()
        if not isinstance(ui_state, gr.State):
            ui_state = gr.State(ui_state)
        self.ui_state: gr.State = ui_state

        initial_step = self.model_step_state.value
        self.inputs_count_state = gr.State(len(initial_step.input_fields))
        self.outputs_count_state = gr.State(len(initial_step.output_fields))

        # UI components that will be created in render
        self.accordion = None
        self.ui = None
        self.step_name_input = None
        self.model_selection = None
        self.temperature_slider = None
        self.system_prompt = None
        self.inputs_column = None
        self.outputs_column = None
        self.input_rows = []
        self.output_rows = []
        self.tabs = {}

        # NOTE(review): render() is not called explicitly here; presumably gradio's
        # Block.__init__ invokes self.render() (render=True by default), which builds
        # the widgets that setup_event_listeners() wires up — confirm before ever
        # passing render=False.
        super().__init__(**kwargs)
        self.setup_event_listeners()

    @property
    def model_step(self) -> ModelStep:
        """The step currently held as the default value of ``model_step_state``."""
        return self.model_step_state.value

    @property
    def step_id(self) -> str:
        """The id of the current step."""
        return self.model_step.id

    def get_step_config(self) -> dict:
        """Return the current step serialized with ``model_dump()``."""
        return self.model_step.model_dump()

    # UI state accessors
    def is_open(self) -> bool:
        """Whether the step's accordion should render expanded."""
        return self.ui_state.value.expanded

    def get_active_tab(self) -> str:
        """Get the current active tab."""
        return self.ui_state.value.active_tab

    def _render_input_row(self, i: int) -> tuple[gr.Row, tuple, InputRowButtonGroup]:
        """Render a single input row at index i.

        Rows beyond the step's current input count are rendered hidden so they can
        be revealed later by the add button. Only row 0 shows field labels, and its
        delete button is disabled when it is the only input.

        Returns:
            ``(row, (name_box, variable_dropdown, description_box), button_group)``.
        """
        inputs = self.model_step.input_fields
        is_visible = i < len(inputs)
        label_visible = i == 0
        disable_delete = i == 0 and len(inputs) == 1
        initial_name = inputs[i].name if is_visible else ""
        initial_desc = inputs[i].description if is_visible else ""
        # inputs[i] is only touched when the row is visible (i.e. i is in range).
        initial_var = (inputs[i].variable or UNSELECTED_VAR_NAME) if is_visible else UNSELECTED_VAR_NAME

        with gr.Row(visible=is_visible, elem_classes="field-row form") as row:
            button_group = InputRowButtonGroup(disable_delete=disable_delete)

            inp_var = gr.Dropdown(
                choices=self.input_variables,
                label="Variable Used",
                value=initial_var,
                elem_classes="field-variable",
                scale=1,
                show_label=label_visible,
            )
            inp_name = gr.Textbox(
                label="Input Name",
                placeholder="Field name",
                value=initial_name,
                elem_classes="field-name",
                scale=1,
                show_label=label_visible,
            )
            inp_desc = gr.Textbox(
                label="Description",
                placeholder="Field description",
                value=initial_desc,
                elem_classes="field-description",
                scale=3,
                show_label=label_visible,
            )
            fields = (inp_name, inp_var, inp_desc)
        return row, fields, button_group

    def _render_output_row(self, i: int) -> tuple[gr.Row, tuple, OutputRowButtonGroup]:
        """Render a single output row at index i.

        Rows beyond the step's current output count are rendered hidden. Only row 0
        shows field labels, and its delete button is disabled when it is the only output.

        Returns:
            ``(row, (name_box, type_dropdown, description_box), button_group)``.
        """
        outputs = self.model_step.output_fields
        is_visible = i < len(outputs)
        label_visible = i == 0
        disable_delete = i == 0 and len(outputs) == 1
        initial_name = outputs[i].name if is_visible else ""
        initial_desc = outputs[i].description if is_visible else ""
        initial_type = outputs[i].type if is_visible else "str"
        with gr.Row(visible=is_visible, elem_classes="field-row") as row:
            button_group = OutputRowButtonGroup(disable_delete=disable_delete)

            out_name = gr.Textbox(
                label="Output Field",
                placeholder="Variable identifier",
                value=initial_name,
                elem_classes="field-name",
                scale=1,
                show_label=label_visible,
            )
            out_type = gr.Dropdown(
                choices=["str", "int", "float", "bool"],
                allow_custom_value=True,
                label="Type",
                value=initial_type,
                elem_classes="field-type",
                scale=0,
                show_label=label_visible,
                interactive=True,
            )
            out_desc = gr.Textbox(
                label="Description",
                placeholder="Field description",
                value=initial_desc,
                elem_classes="field-description",
                scale=3,
                show_label=label_visible,
            )

            fields = (out_name, out_type, out_desc)
        return row, fields, button_group

    def _render_prompt_tab_content(self):
        """Render the system prompt editor for the Model tab."""
        self.system_prompt = gr.Textbox(
            label="System Prompt",
            placeholder="Enter the system prompt for this step",
            lines=5,
            value=self.model_step.system_prompt,
            elem_classes="system-prompt",
        )

    def _render_inputs_tab_content(self):
        """Render ``max_fields["input"]`` input rows into ``self.input_rows``."""
        with gr.Column(variant="panel", elem_classes="fields-panel") as self.inputs_column:
            for i in range(self.max_fields["input"]):
                self.input_rows.append(self._render_input_row(i))

    def _render_outputs_tab_content(self):
        """Render ``max_fields["output"]`` output rows into ``self.output_rows``."""
        with gr.Column(variant="panel", elem_classes="fields-panel") as self.outputs_column:
            for i in range(self.max_fields["output"]):
                self.output_rows.append(self._render_output_row(i))

    def _render_tab_content(self, tab_id: str):
        """Dispatch to the renderer for ``tab_id``; unknown ids render nothing."""
        if tab_id == "model-tab":
            self._render_prompt_tab_content()
        elif tab_id == "inputs-tab":
            self._render_inputs_tab_content()
        elif tab_id == "outputs-tab":
            self._render_outputs_tab_content()

    def _render_header(self, model_options: list[str] | None):
        """Render the header row: step name, model dropdown and temperature slider."""
        with gr.Row(elem_classes="step-header-row"):
            self.step_name_input = gr.Textbox(
                label="",
                value=self.model_step.name,
                elem_classes="step-name",
                show_label=False,
                placeholder="Model name...",
            )
            unselected_choice = "Select Model..."
            current_value = (
                get_full_model_name(self.model_step.model, self.model_step.provider)
                if self.model_step.model
                else unselected_choice
            )
            self.model_selection = gr.Dropdown(
                # list() also accepts tuples; `or []` tolerates None.
                choices=[unselected_choice] + list(model_options or []),
                label="Model Provider",
                show_label=False,
                value=current_value,
                elem_classes="model-dropdown",
                scale=1,
            )
            self.temperature_slider = gr.Slider(
                value=self.model_step.temperature,
                minimum=0.0,
                maximum=self.max_temperature,
                step=0.05,
                info="Temperature",
                show_label=False,
                show_reset_button=False,
            )

    def render(self):
        """Render the component UI.

        Builds an accordion containing the header and three tabs (Model, Inputs,
        Outputs), and returns the accordion.
        """
        # Reset UI component lists
        self.input_rows = []
        self.output_rows = []
        self.tabs = {}

        accordion_label = _make_accordion_label(self.model_step)
        self.accordion = gr.Accordion(label=accordion_label, open=self.is_open(), elem_classes="step-accordion")

        with self.accordion:
            self._render_header(self.model_options)

            selected_tab = self.get_active_tab()
            with gr.Tabs(elem_classes="step-tabs", selected=selected_tab):
                tab_ids = ("model-tab", "inputs-tab", "outputs-tab")
                tab_labels = ("Model", "Inputs", "Outputs")
                for tab_id, label in zip(tab_ids, tab_labels):
                    with gr.TabItem(label, elem_classes="tab-content", id=tab_id) as tab:
                        self._render_tab_content(tab_id)
                        self.tabs[tab_id] = tab

        return self.accordion

    def _setup_event_listeners_for_view_change(self):
        """Mirror tab selection and accordion expand/collapse into ``ui_state``."""
        for tab_id, tab in self.tabs.items():
            tab.select(
                fn=self.sm.update_ui_state,
                inputs=[self.ui_state, gr.State("active_tab"), gr.State(tab_id)],
                outputs=[self.ui_state],
            )
        self.accordion.collapse(
            fn=self.sm.update_ui_state,
            inputs=[self.ui_state, gr.State("expanded"), gr.State(False)],
            outputs=[self.ui_state],
        )
        self.accordion.expand(
            fn=self.sm.update_ui_state,
            inputs=[self.ui_state, gr.State("expanded"), gr.State(True)],
            outputs=[self.ui_state],
        )

    def _setup_event_listeners_model_tab(self):
        """Wire the header and system prompt widgets to ``model_step_state``."""
        # Step name change also refreshes the accordion label.
        self.step_name_input.blur(
            fn=self._update_state_and_label,
            inputs=[self.model_step_state, self.step_name_input],
            outputs=[self.model_step_state, self.accordion],
        )

        self.temperature_slider.release(
            fn=self.sm.update_temperature,
            inputs=[self.model_step_state, self.temperature_slider],
            outputs=[self.model_step_state],
        )

        self.model_selection.input(
            fn=self.sm.update_model_and_provider,
            inputs=[self.model_step_state, self.model_selection],
            outputs=[self.model_step_state],
        )

        self.system_prompt.blur(
            fn=self.sm.update_system_prompt,
            inputs=[self.model_step_state, self.system_prompt],
            outputs=[self.model_step_state],
        )

    def _setup_event_listeners_inputs_tab(self):
        """Wire field edits and add/delete buttons for every input row."""
        # Add/delete handlers output every row and every field; these lists are
        # the same for each row, so build them once.
        rows = [row for row, _, _ in self.input_rows]
        input_fields = [field for _, fields, _ in self.input_rows for field in fields]

        for i, (_, fields, button_group) in enumerate(self.input_rows):
            inp_name, inp_var, inp_desc = fields
            row_index = gr.State(i)

            # Field change handlers
            inp_name.blur(
                fn=self.sm.update_input_field_name,
                inputs=[self.model_step_state, inp_name, row_index],
                outputs=[self.model_step_state],
            )

            inp_var.change(
                fn=self.sm.update_input_field_variable,
                inputs=[self.model_step_state, inp_var, inp_name, row_index],
                outputs=[self.model_step_state],
            )

            inp_desc.blur(
                fn=self.sm.update_input_field_description,
                inputs=[self.model_step_state, inp_desc, row_index],
                outputs=[self.model_step_state],
            )

            # Button handlers
            button_group.delete(
                fn=self.sm.delete_input_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.inputs_count_state] + rows + input_fields,
            )

            button_group.add(
                fn=self.sm.add_input_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.inputs_count_state] + rows + input_fields,
            )

    def _setup_event_listeners_outputs_tab(self):
        """Wire field edits and add/delete/up/down buttons for every output row."""
        # Button handlers output every row and/or every field; these lists are
        # the same for each row, so build them once.
        rows = [row for row, _, _ in self.output_rows]
        output_fields = [field for _, fields, _ in self.output_rows for field in fields]

        for i, (_, fields, button_group) in enumerate(self.output_rows):
            out_name, out_type, out_desc = fields

            row_index = gr.State(i)

            # Field change handlers
            out_name.blur(
                fn=self.sm.update_output_field_name,
                inputs=[self.model_step_state, out_name, row_index],
                outputs=[self.model_step_state],
            )

            out_type.change(
                fn=self.sm.update_output_field_type,
                inputs=[self.model_step_state, out_type, row_index],
                outputs=[self.model_step_state],
            )

            out_desc.blur(
                fn=self.sm.update_output_field_description,
                inputs=[self.model_step_state, out_desc, row_index],
                outputs=[self.model_step_state],
            )

            # Button handlers
            button_group.delete(
                fn=self.sm.delete_output_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.outputs_count_state] + rows + output_fields,
            )

            button_group.add(
                fn=self.sm.add_output_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.outputs_count_state] + rows + output_fields,
            )

            button_group.up(
                fn=self.sm.move_output_field,
                inputs=[self.model_step_state, row_index, gr.State("up")],
                outputs=[self.model_step_state] + output_fields,
            )

            button_group.down(
                fn=self.sm.move_output_field,
                inputs=[self.model_step_state, row_index, gr.State("down")],
                outputs=[self.model_step_state] + output_fields,
            )

    def setup_event_listeners(self):
        """Set up all event listeners for this component.

        Must run after render() has created the widgets it references.
        """
        self._setup_event_listeners_for_view_change()
        self._setup_event_listeners_model_tab()
        self._setup_event_listeners_inputs_tab()
        self._setup_event_listeners_outputs_tab()

    def on_model_step_change(self, fn, inputs, outputs):
        """Set up an event listener for the model change event."""
        return self.model_step_state.change(fn, inputs, outputs)

    def on_ui_change(self, fn, inputs, outputs):
        """Set up an event listener for the UI change event."""
        return self.ui_state.change(fn, inputs, outputs)

    def _update_state_and_label(self, model_step: ModelStep, name: str):
        """Rename the step and return ``(new_step, accordion label update)``."""
        new_model_step = self.sm.update_step_name(model_step, name)
        new_label = _make_accordion_label(new_model_step)
        return new_model_step, gr.update(label=new_label)

    def refresh_variable_dropdowns(self, pipeline_state_dict: PipelineStateDict) -> list:
        """Refresh the variable dropdown options in all input rows.

        Gradio 4+ components have no instance ``.update()`` method, so instead of
        mutating the dropdowns this returns one ``gr.update(choices=...)`` payload
        per input row (in row order) for use as an event handler's outputs. The new
        choices are also stored in ``self.input_variables`` so a later ``render()``
        uses them.

        Args:
            pipeline_state_dict: Passed to ``PipelineStateManager.get_all_variables``.

        Returns:
            A list of ``gr.update`` payloads, one per input row.
        """
        # TODO: confirm whether callers still need this (original note: "Not sure why this is needed").
        variable_choices = []
        if self.pipeline_sm is not None:
            variable_choices = list(self.pipeline_sm.get_all_variables(pipeline_state_dict))
        # Keep the "unselected" sentinel available, as __init__ does: it is the
        # default value of every input-row dropdown.
        if UNSELECTED_VAR_NAME not in variable_choices:
            variable_choices = [UNSELECTED_VAR_NAME] + variable_choices
        self.input_variables = variable_choices
        return [gr.update(choices=variable_choices) for _ in self.input_rows]

    def _update_model_and_refresh_ui(self, updated_model_step):
        """Update the model step state and refresh UI elements that depend on it.

        Returns ``updated_model_step`` unchanged.
        """
        self.model_step_state.value = updated_model_step
        if self.accordion is not None:
            # Blocks have no instance .update() in Gradio 4+; set the stored label
            # attribute directly instead.
            self.accordion.label = _make_accordion_label(updated_model_step)
        return updated_model_step