import json
from typing import Any

import gradio as gr
from gradio.components import FormComponent

from app_configs import UNSELECTED_VAR_NAME
from components.model_pipeline.state_manager import ModelStepUIState, PipelineStateManager
from components.typed_dicts import PipelineStateDict
from utils import get_full_model_name
from workflows.structs import ModelStep

from .state_manager import ModelStepStateManager
from .ui_components import InputRowButtonGroup, OutputRowButtonGroup


def _make_accordion_label(model_step: ModelStep) -> str:
    """Build the accordion header text for a step.

    Format: ``"<id>: <name> (<input names>) → (<output names>)"``; an empty
    name is shown as ``"Untitled"``.
    """
    name = model_step.name or "Untitled"
    inputs_str = ", ".join(field.name for field in model_step.input_fields)
    outputs_str = ", ".join(field.name for field in model_step.output_fields)
    return f"{model_step.id}: {name} ({inputs_str}) → ({outputs_str})"


class ModelStepComponent(FormComponent):
    """
    A custom Gradio component representing a single Step in a pipeline.

    It contains:
    1. Model Provider & System Prompt
    2. Inputs – fields with name, description, and variable used
    3. Outputs – fields with name, type, and description

    Other code can subscribe to its state changes through:
    - on_model_step_change
    - on_ui_change
    """

    def __init__(
        self,
        value: ModelStep | gr.State,
        ui_state: ModelStepUIState | gr.State | None = None,
        model_options: list[str] | None = None,
        input_variables: list[str] | None = None,
        max_input_fields=5,
        max_output_fields=5,
        max_temperature=5.0,
        pipeline_state_manager: PipelineStateManager | None = None,
        **kwargs,
    ):
        """Create the component.

        Args:
            value: The step being edited, as a ``ModelStep`` or an existing
                ``gr.State`` holding one (an existing state is used as-is).
            ui_state: View state (expanded flag, active tab). Defaults to a new
                ``ModelStepUIState``; wrapped in ``gr.State`` if needed.
            model_options: Choices for the model dropdown (``None`` -> none).
            input_variables: Variables selectable for input fields
                (``None`` -> none); ``UNSELECTED_VAR_NAME`` is always prepended.
            max_input_fields: Number of input rows pre-rendered.
            max_output_fields: Number of output rows pre-rendered.
            max_temperature: Upper bound of the temperature slider.
            pipeline_state_manager: Used by ``refresh_variable_dropdowns`` to
                look up the variables available in the pipeline.
        """
        self.max_fields = {
            "input": max_input_fields,
            "output": max_output_fields,
        }
        self.max_temperature = max_temperature
        # Normalise optional sequences to lists so list concatenation in
        # rendering never hits ``None`` or a tuple.
        self.model_options = list(model_options or [])
        self.input_variables = [UNSELECTED_VAR_NAME] + list(input_variables or [])
        self.sm = ModelStepStateManager(max_input_fields, max_output_fields)
        self.pipeline_sm: PipelineStateManager | None = pipeline_state_manager

        # Accept either a raw ModelStep or an existing gr.State (same as ui_state).
        if not isinstance(value, gr.State):
            value = gr.State(value)
        self.model_step_state: gr.State = value

        if ui_state is None:
            ui_state = ModelStepUIState()
        if not isinstance(ui_state, gr.State):
            ui_state = gr.State(ui_state)
        self.ui_state: gr.State = ui_state

        self.inputs_count_state = gr.State(len(self.model_step.input_fields))
        self.outputs_count_state = gr.State(len(self.model_step.output_fields))

        # UI components; populated by render().
        self.accordion = None
        self.ui = None
        self.step_name_input = None
        self.model_selection = None
        self.temperature_slider = None
        self.system_prompt = None
        self.inputs_column = None
        self.outputs_column = None
        self.input_rows = []
        self.output_rows = []
        self.tabs = {}

        # NOTE(review): render() is not called explicitly here; presumably
        # Gradio's Block.__init__ invokes self.render() when created inside a
        # Blocks context, which the listener setup below relies on — confirm.
        super().__init__(**kwargs)
        self.setup_event_listeners()

    @property
    def model_step(self) -> ModelStep:
        """The ``ModelStep`` currently stored as the state's value."""
        return self.model_step_state.value

    @property
    def step_id(self) -> str:
        """Identifier of the step."""
        return self.model_step.id

    def get_step_config(self) -> dict:
        """Return the step as a plain dict (``model_dump()``)."""
        return self.model_step.model_dump()

    # UI state accessors
    def is_open(self) -> bool:
        """Whether the step's accordion is expanded."""
        return self.ui_state.value.expanded

    def get_active_tab(self) -> str:
        """Get the current active tab."""
        return self.ui_state.value.active_tab

    def _render_input_row(self, i: int) -> tuple[gr.Row, tuple, InputRowButtonGroup]:
        """Render input row ``i``.

        Every one of the ``max_fields["input"]`` rows is pre-rendered; rows past
        the current number of input fields start hidden. Only row 0 shows labels.

        Returns:
            ``(row, (name_box, variable_dropdown, description_box), button_group)``.
        """
        inputs = self.model_step.input_fields
        is_visible = i < len(inputs)
        label_visible = i == 0
        # The last remaining input field cannot be deleted.
        disable_delete = i == 0 and len(inputs) == 1
        initial_name = inputs[i].name if is_visible else ""
        initial_desc = inputs[i].description if is_visible else ""
        initial_var = (inputs[i].variable or UNSELECTED_VAR_NAME) if is_visible else UNSELECTED_VAR_NAME

        with gr.Row(visible=is_visible, elem_classes="field-row form") as row:
            button_group = InputRowButtonGroup(disable_delete=disable_delete)

            inp_var = gr.Dropdown(
                choices=self.input_variables,
                label="Variable Used",
                value=initial_var,
                elem_classes="field-variable",
                scale=1,
                show_label=label_visible,
            )

            inp_name = gr.Textbox(
                label="Input Name",
                placeholder="Field name",
                value=initial_name,
                elem_classes="field-name",
                scale=1,
                show_label=label_visible,
            )

            inp_desc = gr.Textbox(
                label="Description",
                placeholder="Field description",
                value=initial_desc,
                elem_classes="field-description",
                scale=3,
                show_label=label_visible,
            )

        fields = (inp_name, inp_var, inp_desc)
        return row, fields, button_group

    def _render_output_row(self, i: int) -> tuple[gr.Row, tuple, OutputRowButtonGroup]:
        """Render output row ``i``.

        Every one of the ``max_fields["output"]`` rows is pre-rendered; rows past
        the current number of output fields start hidden. Only row 0 shows labels.

        Returns:
            ``(row, (name_box, type_dropdown, description_box), button_group)``.
        """
        outputs = self.model_step.output_fields
        is_visible = i < len(outputs)
        label_visible = i == 0
        # The last remaining output field cannot be deleted.
        disable_delete = i == 0 and len(outputs) == 1
        initial_name = outputs[i].name if is_visible else ""
        initial_desc = outputs[i].description if is_visible else ""
        initial_type = outputs[i].type if is_visible else "str"

        with gr.Row(visible=is_visible, elem_classes="field-row") as row:
            button_group = OutputRowButtonGroup(disable_delete=disable_delete)

            out_name = gr.Textbox(
                label="Output Field",
                placeholder="Variable identifier",
                value=initial_name,
                elem_classes="field-name",
                scale=1,
                show_label=label_visible,
            )

            out_type = gr.Dropdown(
                choices=["str", "int", "float", "bool"],
                allow_custom_value=True,
                label="Type",
                value=initial_type,
                elem_classes="field-type",
                scale=0,
                show_label=label_visible,
                interactive=True,
            )

            out_desc = gr.Textbox(
                label="Description",
                placeholder="Field description",
                value=initial_desc,
                elem_classes="field-description",
                scale=3,
                show_label=label_visible,
            )

        fields = (out_name, out_type, out_desc)
        return row, fields, button_group

    def _render_prompt_tab_content(self):
        """Render the system-prompt textbox of the "Model" tab."""
        self.system_prompt = gr.Textbox(
            label="System Prompt",
            placeholder="Enter the system prompt for this step",
            lines=5,
            value=self.model_step.system_prompt,
            elem_classes="system-prompt",
        )

    def _render_inputs_tab_content(self):
        """Render all (visible and hidden) input rows into a panel column."""
        with gr.Column(variant="panel", elem_classes="fields-panel") as self.inputs_column:
            for i in range(self.max_fields["input"]):
                self.input_rows.append(self._render_input_row(i))

    def _render_outputs_tab_content(self):
        """Render all (visible and hidden) output rows into a panel column."""
        with gr.Column(variant="panel", elem_classes="fields-panel") as self.outputs_column:
            for i in range(self.max_fields["output"]):
                self.output_rows.append(self._render_output_row(i))

    def _render_tab_content(self, tab_id: str):
        """Render the body of the tab identified by ``tab_id`` (unknown ids render nothing)."""
        if tab_id == "model-tab":
            self._render_prompt_tab_content()
        elif tab_id == "inputs-tab":
            self._render_inputs_tab_content()
        elif tab_id == "outputs-tab":
            self._render_outputs_tab_content()

    def _render_header(self, model_options: list[str]):
        """Render the header row: step name, model dropdown and temperature slider."""
        with gr.Row(elem_classes="step-header-row"):
            self.step_name_input = gr.Textbox(
                label="",
                value=self.model_step.name,
                elem_classes="step-name",
                show_label=False,
                placeholder="Model name...",
            )

            unselected_choice = "Select Model..."
            current_value = (
                get_full_model_name(self.model_step.model, self.model_step.provider)
                if self.model_step.model
                else unselected_choice
            )
            self.model_selection = gr.Dropdown(
                # list(...) accepts any sequence (e.g. a tuple) and None-safe input from __init__.
                choices=[unselected_choice] + list(model_options or []),
                label="Model Provider",
                show_label=False,
                value=current_value,
                elem_classes="model-dropdown",
                scale=1,
            )
            self.temperature_slider = gr.Slider(
                value=self.model_step.temperature,
                minimum=0.0,
                maximum=self.max_temperature,
                step=0.05,
                info="Temperature",
                show_label=False,
                show_reset_button=False,
            )

    def render(self):
        """Render the component UI and return its accordion.

        The row and tab registries are reset first so that re-rendering does
        not accumulate stale component references.
        """
        self.input_rows = []
        self.output_rows = []
        self.tabs = {}

        accordion_label = _make_accordion_label(self.model_step)
        self.accordion = gr.Accordion(label=accordion_label, open=self.is_open(), elem_classes="step-accordion")

        with self.accordion:
            self._render_header(self.model_options)

            # Configuration tabs; the initially selected tab comes from the UI state.
            selected_tab = self.get_active_tab()
            with gr.Tabs(elem_classes="step-tabs", selected=selected_tab):
                tab_ids = ("model-tab", "inputs-tab", "outputs-tab")
                tab_labels = ("Model", "Inputs", "Outputs")
                for tab_id, label in zip(tab_ids, tab_labels):
                    with gr.TabItem(label, elem_classes="tab-content", id=tab_id) as tab:
                        self._render_tab_content(tab_id)
                    self.tabs[tab_id] = tab

        return self.accordion

    def _setup_event_listeners_for_view_change(self):
        """Record tab selection and accordion expand/collapse in ``ui_state``."""
        for tab_id, tab in self.tabs.items():
            tab.select(
                fn=self.sm.update_ui_state,
                inputs=[self.ui_state, gr.State("active_tab"), gr.State(tab_id)],
                outputs=[self.ui_state],
            )

        self.accordion.collapse(
            fn=self.sm.update_ui_state,
            inputs=[self.ui_state, gr.State("expanded"), gr.State(False)],
            outputs=[self.ui_state],
        )
        self.accordion.expand(
            fn=self.sm.update_ui_state,
            inputs=[self.ui_state, gr.State("expanded"), gr.State(True)],
            outputs=[self.ui_state],
        )

    def _setup_event_listeners_model_tab(self):
        """Wire the header and system prompt to the model-step state."""
        # Step name change also refreshes the accordion label.
        self.step_name_input.blur(
            fn=self._update_state_and_label,
            inputs=[self.model_step_state, self.step_name_input],
            outputs=[self.model_step_state, self.accordion],
        )

        self.temperature_slider.release(
            fn=self.sm.update_temperature,
            inputs=[self.model_step_state, self.temperature_slider],
            outputs=[self.model_step_state],
        )

        # Model and system prompt
        self.model_selection.input(
            fn=self.sm.update_model_and_provider,
            inputs=[self.model_step_state, self.model_selection],
            outputs=[self.model_step_state],
        )

        self.system_prompt.blur(
            fn=self.sm.update_system_prompt,
            inputs=[self.model_step_state, self.system_prompt],
            outputs=[self.model_step_state],
        )

    def _setup_event_listeners_inputs_tab(self):
        """Wire per-row field edits and add/delete buttons for the input rows."""
        # Add/delete handlers output every row and every field; these lists do
        # not depend on the loop variable, so build them once.
        rows = [row for (row, _, _) in self.input_rows]
        input_fields = [field for (_, fields, _) in self.input_rows for field in fields]

        for i, (_, fields, button_group) in enumerate(self.input_rows):
            inp_name, inp_var, inp_desc = fields
            row_index = gr.State(i)

            # Field change handlers
            inp_name.blur(
                fn=self.sm.update_input_field_name,
                inputs=[self.model_step_state, inp_name, row_index],
                outputs=[self.model_step_state],
            )

            inp_var.change(
                fn=self.sm.update_input_field_variable,
                inputs=[self.model_step_state, inp_var, inp_name, row_index],
                outputs=[self.model_step_state],
            )

            inp_desc.blur(
                fn=self.sm.update_input_field_description,
                inputs=[self.model_step_state, inp_desc, row_index],
                outputs=[self.model_step_state],
            )

            # Button handlers
            button_group.delete(
                fn=self.sm.delete_input_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.inputs_count_state] + rows + input_fields,
            )

            button_group.add(
                fn=self.sm.add_input_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.inputs_count_state] + rows + input_fields,
            )

    def _setup_event_listeners_outputs_tab(self):
        """Wire per-row field edits and add/delete/move buttons for the output rows."""
        # Loop-invariant output lists shared by the structural (add/delete/move) handlers.
        rows = [row for (row, _, _) in self.output_rows]
        output_fields = [field for (_, fields, _) in self.output_rows for field in fields]

        for i, (_, fields, button_group) in enumerate(self.output_rows):
            out_name, out_type, out_desc = fields
            row_index = gr.State(i)

            # Field change handlers
            out_name.blur(
                fn=self.sm.update_output_field_name,
                inputs=[self.model_step_state, out_name, row_index],
                outputs=[self.model_step_state],
            )

            out_type.change(
                fn=self.sm.update_output_field_type,
                inputs=[self.model_step_state, out_type, row_index],
                outputs=[self.model_step_state],
            )

            out_desc.blur(
                fn=self.sm.update_output_field_description,
                inputs=[self.model_step_state, out_desc, row_index],
                outputs=[self.model_step_state],
            )

            # Button handlers
            button_group.delete(
                fn=self.sm.delete_output_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.outputs_count_state] + rows + output_fields,
            )

            button_group.add(
                fn=self.sm.add_output_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.outputs_count_state] + rows + output_fields,
            )

            # Moving a field does not change the count, so only fields are refreshed.
            button_group.up(
                fn=self.sm.move_output_field,
                inputs=[self.model_step_state, row_index, gr.State("up")],
                outputs=[self.model_step_state] + output_fields,
            )

            button_group.down(
                fn=self.sm.move_output_field,
                inputs=[self.model_step_state, row_index, gr.State("down")],
                outputs=[self.model_step_state] + output_fields,
            )

    def setup_event_listeners(self):
        """Set up all event listeners for this component.

        Must run after render(), since it attaches handlers to rendered components.
        """
        self._setup_event_listeners_for_view_change()
        self._setup_event_listeners_model_tab()
        self._setup_event_listeners_inputs_tab()
        self._setup_event_listeners_outputs_tab()

    def on_model_step_change(self, fn, inputs, outputs):
        """Set up an event listener for the model change event."""
        return self.model_step_state.change(fn, inputs, outputs)

    def on_ui_change(self, fn, inputs, outputs):
        """Set up an event listener for the UI change event."""
        return self.ui_state.change(fn, inputs, outputs)

    def _update_state_and_label(self, model_step: ModelStep, name: str):
        """Rename the step and return ``(new_step, accordion label update)``."""
        new_model_step = self.sm.update_step_name(model_step, name)
        new_label = _make_accordion_label(new_model_step)
        return new_model_step, gr.update(label=new_label)

    def refresh_variable_dropdowns(self, pipeline_state_dict: PipelineStateDict) -> list:
        """Compute refreshed variable choices for every input row's dropdown.

        Gradio components are updated by returning update objects from an event
        handler, so this returns one ``gr.update(choices=...)`` per input row,
        in row order; use it as a handler whose outputs are the rows' variable
        dropdowns. Choices always start with ``UNSELECTED_VAR_NAME``, matching
        the choices the dropdowns are created with. Without a pipeline state
        manager only ``UNSELECTED_VAR_NAME`` is offered.
        """
        variables = []
        if self.pipeline_sm is not None:
            variables = self.pipeline_sm.get_all_variables(pipeline_state_dict)
        variable_choices = [UNSELECTED_VAR_NAME] + [v for v in variables if v != UNSELECTED_VAR_NAME]
        return [gr.update(choices=variable_choices) for _ in self.input_rows]

    def _update_model_and_refresh_ui(self, updated_model_step):
        """Store ``updated_model_step`` as the state's value and sync the accordion label.

        NOTE(review): this sets server-side defaults (state value, accordion
        label attribute); it does not push an update to already-open sessions.
        """
        self.model_step_state.value = updated_model_step

        new_label = _make_accordion_label(updated_model_step)
        if self.accordion is not None:
            # Component.update() no longer exists in Gradio 4+; set the attribute instead.
            self.accordion.label = new_label

        return updated_model_step