Maharshi Gor
Squash merge dictify-states into main
9756440
raw
history blame
18.5 kB
import json
from typing import Any
import gradio as gr
from gradio.components import FormComponent
from app_configs import UNSELECTED_VAR_NAME
from components.model_pipeline.state_manager import ModelStepUIState, PipelineStateManager
from components.typed_dicts import PipelineStateDict
from utils import get_full_model_name
from workflows.structs import ModelStep
from .state_manager import ModelStepStateManager
from .ui_components import InputRowButtonGroup, OutputRowButtonGroup
def _make_accordion_label(model_step: ModelStep):
name = model_step.name if model_step.name else "Untitled"
input_field_names = [field.name for field in model_step.input_fields]
inputs_str = ", ".join(input_field_names)
output_field_names = [field.name for field in model_step.output_fields]
outputs_str = ", ".join(output_field_names)
return "{}: {} ({}) → ({})".format(model_step.id, name, inputs_str, outputs_str)
class ModelStepComponent(FormComponent):
    """A custom Gradio component for viewing and editing one ModelStep of a pipeline.

    Everything is rendered inside a collapsible accordion whose label summarises the step
    (built by ``_make_accordion_label``):
      1. Header – step name textbox, model dropdown and temperature slider.
      2. "Model" tab – the system prompt.
      3. "Inputs" tab – ``max_input_fields`` pre-rendered rows of (variable used, name,
         description); rows beyond the step's current input count are hidden.
      4. "Outputs" tab – ``max_output_fields`` pre-rendered rows of (name, type,
         description); rows beyond the step's current output count are hidden.

    Edits are applied through ``ModelStepStateManager`` handlers to two gr.State objects:
      - ``model_step_state`` holds the ModelStep being edited.
      - ``ui_state`` holds a ModelStepUIState (``expanded`` flag and ``active_tab``).

    Other components can react to changes via ``on_model_step_change`` and ``on_ui_change``.
    """

    def __init__(
        self,
        value: ModelStep | gr.State,
        ui_state: ModelStepUIState | gr.State | None = None,
        model_options: list[str] | None = None,
        input_variables: list[str] | None = None,
        max_input_fields=5,
        max_output_fields=5,
        pipeline_state_manager: PipelineStateManager | None = None,
        **kwargs,
    ):
        """Create the component.

        Args:
            value: The step to edit, or a gr.State already holding one (used as-is).
            ui_state: Initial view state, or a gr.State holding one. Defaults to a fresh
                ``ModelStepUIState()``.
            model_options: Model names offered by the model dropdown. ``None`` means none.
            input_variables: Variable names offered by each input row's dropdown;
                ``UNSELECTED_VAR_NAME`` is always prepended. ``None`` means none.
            max_input_fields: Number of input rows pre-rendered.
            max_output_fields: Number of output rows pre-rendered.
            pipeline_state_manager: Optional manager used by ``refresh_variable_dropdowns``.
            **kwargs: Forwarded to ``FormComponent.__init__``.
        """
        self.max_fields = {
            "input": max_input_fields,
            "output": max_output_fields,
        }
        # Normalise the optional lists: concatenating ``None`` onto a list raised TypeError.
        self.model_options = list(model_options or [])
        self.input_variables = [UNSELECTED_VAR_NAME, *(input_variables or [])]
        self.sm = ModelStepStateManager(max_input_fields, max_output_fields)
        self.pipeline_sm: PipelineStateManager | None = pipeline_state_manager

        # Accept a raw value or a pre-built gr.State for both states. Wrapping an existing
        # gr.State in another gr.State would make ``self.model_step`` return a State.
        if not isinstance(value, gr.State):
            value = gr.State(value)
        self.model_step_state: gr.State = value

        if ui_state is None:
            ui_state = ModelStepUIState()
        if not isinstance(ui_state, gr.State):
            ui_state = gr.State(ui_state)
        self.ui_state: gr.State = ui_state

        self.inputs_count_state = gr.State(len(self.model_step.input_fields))
        self.outputs_count_state = gr.State(len(self.model_step.output_fields))

        # UI components; populated by render().
        self.accordion = None
        self.ui = None
        self.step_name_input = None
        self.model_selection = None
        self.temperature_slider = None
        self.system_prompt = None
        self.inputs_column = None
        self.outputs_column = None
        self.input_rows = []
        self.output_rows = []
        self.tabs = {}

        # NOTE: render() is not called explicitly here. Presumably Gradio's Block.__init__
        # (reached through super().__init__) calls self.render() when render=True (the
        # default). That builds the components setup_event_listeners() wires below — confirm
        # before passing render=False.
        super().__init__(**kwargs)
        self.setup_event_listeners()

    @property
    def model_step(self) -> ModelStep:
        """The ModelStep stored in ``model_step_state.value`` (used to build the UI)."""
        return self.model_step_state.value

    @property
    def step_id(self) -> str:
        """The id of the stored step."""
        return self.model_step.id

    def get_step_config(self) -> dict:
        """Return the stored step serialised with ``model_dump()``."""
        return self.model_step.model_dump()

    # UI state accessors
    def is_open(self) -> bool:
        """Whether the accordion should be expanded, per ``ui_state``."""
        return self.ui_state.value.expanded

    def get_active_tab(self) -> str:
        """Get the id of the active tab, per ``ui_state``."""
        return self.ui_state.value.active_tab

    def _render_input_row(self, i: int) -> tuple[gr.Row, tuple, InputRowButtonGroup]:
        """Render the input row at index ``i``.

        The row is visible only if the step has an input field at ``i``. Labels are shown on
        the first row only. The delete button is disabled when row 0 is the sole input.

        Returns:
            ``(row, (name_box, variable_dropdown, description_box), button_group)``. The
            fields tuple order differs from the on-screen order (variable, name, description).
        """
        inputs = self.model_step.input_fields
        is_visible = i < len(inputs)
        label_visible = i == 0
        disable_delete = i == 0 and len(inputs) == 1
        initial_name = inputs[i].name if is_visible else ""
        initial_desc = inputs[i].description if is_visible else ""
        initial_var = (inputs[i].variable or UNSELECTED_VAR_NAME) if is_visible else UNSELECTED_VAR_NAME
        with gr.Row(visible=is_visible, elem_classes="field-row form") as row:
            button_group = InputRowButtonGroup(disable_delete=disable_delete)
            inp_var = gr.Dropdown(
                choices=self.input_variables,
                label="Variable Used",
                value=initial_var,
                elem_classes="field-variable",
                scale=1,
                show_label=label_visible,
            )
            inp_name = gr.Textbox(
                label="Input Name",
                placeholder="Field name",
                value=initial_name,
                elem_classes="field-name",
                scale=1,
                show_label=label_visible,
            )
            inp_desc = gr.Textbox(
                label="Description",
                placeholder="Field description",
                value=initial_desc,
                elem_classes="field-description",
                scale=3,
                show_label=label_visible,
            )
        fields = (inp_name, inp_var, inp_desc)
        return row, fields, button_group

    def _render_output_row(self, i: int) -> tuple[gr.Row, tuple, OutputRowButtonGroup]:
        """Render the output row at index ``i``.

        Visibility, labels and delete-button rules match ``_render_input_row``. A hidden row
        defaults its type to ``"str"``.

        Returns:
            ``(row, (name_box, type_dropdown, description_box), button_group)``.
        """
        outputs = self.model_step.output_fields
        is_visible = i < len(outputs)
        label_visible = i == 0
        disable_delete = i == 0 and len(outputs) == 1
        initial_name = outputs[i].name if is_visible else ""
        initial_desc = outputs[i].description if is_visible else ""
        initial_type = outputs[i].type if is_visible else "str"
        with gr.Row(visible=is_visible, elem_classes="field-row") as row:
            button_group = OutputRowButtonGroup(disable_delete=disable_delete)
            out_name = gr.Textbox(
                label="Output Field",
                placeholder="Variable identifier",
                value=initial_name,
                elem_classes="field-name",
                scale=1,
                show_label=label_visible,
            )
            out_type = gr.Dropdown(
                choices=["str", "int", "float", "bool"],
                allow_custom_value=True,
                label="Type",
                value=initial_type,
                elem_classes="field-type",
                scale=0,
                show_label=label_visible,
                interactive=True,
            )
            out_desc = gr.Textbox(
                label="Description",
                placeholder="Field description",
                value=initial_desc,
                elem_classes="field-description",
                scale=3,
                show_label=label_visible,
            )
        fields = (out_name, out_type, out_desc)
        return row, fields, button_group

    def _render_prompt_tab_content(self):
        """Render the system prompt textbox for the "Model" tab."""
        self.system_prompt = gr.Textbox(
            label="System Prompt",
            placeholder="Enter the system prompt for this step",
            lines=5,
            value=self.model_step.system_prompt,
            elem_classes="system-prompt",
        )

    def _render_inputs_tab_content(self):
        """Render all ``max_fields["input"]`` input rows and record them in ``input_rows``."""
        with gr.Column(variant="panel", elem_classes="fields-panel") as self.inputs_column:
            for i in range(self.max_fields["input"]):
                self.input_rows.append(self._render_input_row(i))

    def _render_outputs_tab_content(self):
        """Render all ``max_fields["output"]`` output rows and record them in ``output_rows``."""
        with gr.Column(variant="panel", elem_classes="fields-panel") as self.outputs_column:
            for i in range(self.max_fields["output"]):
                self.output_rows.append(self._render_output_row(i))

    def _render_tab_content(self, tab_id: str):
        """Render the body of the tab identified by ``tab_id``. Unknown ids render nothing."""
        if tab_id == "model-tab":
            self._render_prompt_tab_content()
        elif tab_id == "inputs-tab":
            self._render_inputs_tab_content()
        elif tab_id == "outputs-tab":
            self._render_outputs_tab_content()

    def _render_header(self, model_options: list[str] | tuple[str, ...] | None):
        """Render the header row: step name, model dropdown and temperature slider.

        Args:
            model_options: Model names for the dropdown. ``"Select Model..."`` is always
                offered first and is selected when the step has no model.
        """
        with gr.Row(elem_classes="step-header-row"):
            self.step_name_input = gr.Textbox(
                label="",
                value=self.model_step.name,
                elem_classes="step-name",
                show_label=False,
                placeholder="Model name...",
            )
            unselected_choice = "Select Model..."
            current_value = (
                get_full_model_name(self.model_step.model, self.model_step.provider)
                if self.model_step.model
                else unselected_choice
            )
            self.model_selection = gr.Dropdown(
                # Unpacking accepts lists or tuples; ``list + tuple`` raised TypeError.
                choices=[unselected_choice, *(model_options or [])],
                label="Model Provider",
                show_label=False,
                value=current_value,
                elem_classes="model-dropdown",
                scale=1,
            )
            self.temperature_slider = gr.Slider(
                value=self.model_step.temperature,
                minimum=0.0,
                maximum=5,
                step=0.05,
                info="Temperature",
                show_label=False,
                show_reset_button=False,
            )

    def render(self):
        """Build the component's UI and return the enclosing accordion.

        Resets ``input_rows``, ``output_rows`` and ``tabs`` first, so they only reference
        components created by this call.
        """
        self.input_rows = []
        self.output_rows = []
        self.tabs = {}
        accordion_label = _make_accordion_label(self.model_step)
        self.accordion = gr.Accordion(label=accordion_label, open=self.is_open(), elem_classes="step-accordion")
        with self.accordion:
            self._render_header(self.model_options)
            selected_tab = self.get_active_tab()
            with gr.Tabs(elem_classes="step-tabs", selected=selected_tab):
                tab_ids = ("model-tab", "inputs-tab", "outputs-tab")
                tab_labels = ("Model", "Inputs", "Outputs")
                for tab_id, label in zip(tab_ids, tab_labels):
                    with gr.TabItem(label, elem_classes="tab-content", id=tab_id) as tab:
                        self._render_tab_content(tab_id)
                    self.tabs[tab_id] = tab
        return self.accordion

    def _setup_event_listeners_for_view_change(self):
        """Record tab selection and accordion expand/collapse in ``ui_state``."""
        for tab_id, tab in self.tabs.items():
            tab.select(
                fn=self.sm.update_ui_state,
                inputs=[self.ui_state, gr.State("active_tab"), gr.State(tab_id)],
                outputs=[self.ui_state],
            )
        self.accordion.collapse(
            fn=self.sm.update_ui_state,
            inputs=[self.ui_state, gr.State("expanded"), gr.State(False)],
            outputs=[self.ui_state],
        )
        self.accordion.expand(
            fn=self.sm.update_ui_state,
            inputs=[self.ui_state, gr.State("expanded"), gr.State(True)],
            outputs=[self.ui_state],
        )

    def _setup_event_listeners_model_tab(self):
        """Wire the header controls and system prompt to ``model_step_state``."""
        # Renaming also refreshes the accordion label.
        self.step_name_input.blur(
            fn=self._update_state_and_label,
            inputs=[self.model_step_state, self.step_name_input],
            outputs=[self.model_step_state, self.accordion],
        )
        self.temperature_slider.release(
            fn=self.sm.update_temperature,
            inputs=[self.model_step_state, self.temperature_slider],
            outputs=[self.model_step_state],
        )
        self.model_selection.input(
            fn=self.sm.update_model_and_provider,
            inputs=[self.model_step_state, self.model_selection],
            outputs=[self.model_step_state],
        )
        self.system_prompt.blur(
            fn=self.sm.update_system_prompt,
            inputs=[self.model_step_state, self.system_prompt],
            outputs=[self.model_step_state],
        )

    def _setup_event_listeners_inputs_tab(self):
        """Wire per-row field edits and add/delete buttons for the Inputs tab."""
        # Add/delete handlers output every row and every row's fields. These lists are the
        # same for every row, so build them once instead of once per iteration.
        rows = [row for row, _, _ in self.input_rows]
        input_fields = [field for _, fields, _ in self.input_rows for field in fields]
        for i, (_, fields, button_group) in enumerate(self.input_rows):
            inp_name, inp_var, inp_desc = fields
            row_index = gr.State(i)
            inp_name.blur(
                fn=self.sm.update_input_field_name,
                inputs=[self.model_step_state, inp_name, row_index],
                outputs=[self.model_step_state],
            )
            inp_var.change(
                fn=self.sm.update_input_field_variable,
                inputs=[self.model_step_state, inp_var, inp_name, row_index],
                outputs=[self.model_step_state],
            )
            inp_desc.blur(
                fn=self.sm.update_input_field_description,
                inputs=[self.model_step_state, inp_desc, row_index],
                outputs=[self.model_step_state],
            )
            button_group.delete(
                fn=self.sm.delete_input_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.inputs_count_state] + rows + input_fields,
            )
            button_group.add(
                fn=self.sm.add_input_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.inputs_count_state] + rows + input_fields,
            )

    def _setup_event_listeners_outputs_tab(self):
        """Wire per-row field edits and add/delete/move buttons for the Outputs tab."""
        # Built once; identical for every row (see _setup_event_listeners_inputs_tab).
        rows = [row for row, _, _ in self.output_rows]
        output_fields = [field for _, fields, _ in self.output_rows for field in fields]
        for i, (_, fields, button_group) in enumerate(self.output_rows):
            out_name, out_type, out_desc = fields
            row_index = gr.State(i)
            out_name.blur(
                fn=self.sm.update_output_field_name,
                inputs=[self.model_step_state, out_name, row_index],
                outputs=[self.model_step_state],
            )
            out_type.change(
                fn=self.sm.update_output_field_type,
                inputs=[self.model_step_state, out_type, row_index],
                outputs=[self.model_step_state],
            )
            out_desc.blur(
                fn=self.sm.update_output_field_description,
                inputs=[self.model_step_state, out_desc, row_index],
                outputs=[self.model_step_state],
            )
            button_group.delete(
                fn=self.sm.delete_output_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.outputs_count_state] + rows + output_fields,
            )
            button_group.add(
                fn=self.sm.add_output_field,
                inputs=[self.model_step_state, row_index],
                outputs=[self.model_step_state, self.outputs_count_state] + rows + output_fields,
            )
            # Moving a field only rewrites field contents; row visibility and count are unchanged.
            button_group.up(
                fn=self.sm.move_output_field,
                inputs=[self.model_step_state, row_index, gr.State("up")],
                outputs=[self.model_step_state] + output_fields,
            )
            button_group.down(
                fn=self.sm.move_output_field,
                inputs=[self.model_step_state, row_index, gr.State("down")],
                outputs=[self.model_step_state] + output_fields,
            )

    def setup_event_listeners(self):
        """Set up all event listeners. Must run after render() has created the components."""
        self._setup_event_listeners_for_view_change()
        self._setup_event_listeners_model_tab()
        self._setup_event_listeners_inputs_tab()
        self._setup_event_listeners_outputs_tab()

    def on_model_step_change(self, fn, inputs, outputs):
        """Register ``fn`` on ``model_step_state.change`` and return the event listener."""
        return self.model_step_state.change(fn, inputs, outputs)

    def on_ui_change(self, fn, inputs, outputs):
        """Register ``fn`` on ``ui_state.change`` and return the event listener."""
        return self.ui_state.change(fn, inputs, outputs)

    def _update_state_and_label(self, model_step: ModelStep, name: str):
        """Rename the step and return ``(new_step, gr.update(label=...))`` for the accordion."""
        new_model_step = self.sm.update_step_name(model_step, name)
        new_label = _make_accordion_label(new_model_step)
        return new_model_step, gr.update(label=new_label)

    def refresh_variable_dropdowns(self, pipeline_state_dict: PipelineStateDict) -> list[dict[str, Any]]:
        """Recompute the choices offered by every input row's "Variable Used" dropdown.

        The choices are ``UNSELECTED_VAR_NAME`` followed by the variables reported by the
        pipeline state manager, or no variables when no manager was given. This matches how
        ``__init__`` builds ``self.input_variables``. That attribute is updated too, so later
        renders use the new list.

        Args:
            pipeline_state_dict: Passed to ``PipelineStateManager.get_all_variables``.

        Returns:
            One ``gr.update(choices=...)`` per input row, in row order. This is suitable for
            returning from an event handler whose outputs are those dropdowns.
        """
        variables: list[str] = []
        if self.pipeline_sm is not None:
            variables = list(self.pipeline_sm.get_all_variables(pipeline_state_dict))
        if UNSELECTED_VAR_NAME not in variables:
            variables.insert(0, UNSELECTED_VAR_NAME)
        self.input_variables = variables
        # Component instances no longer have an .update() method (removed in Gradio 4), so the
        # previous in-place ``inp_var.update(choices=...)`` could not work. Return updates instead.
        return [gr.update(choices=variables) for _ in self.input_rows]

    def _update_model_and_refresh_ui(self, updated_model_step: ModelStep) -> ModelStep:
        """Replace the stored step and make the accordion label match it.

        NOTE(review): assigning ``gr.State.value`` or component attributes changes what Gradio
        uses for subsequent renders/page loads, not an already-open session's state. Confirm
        this is how callers intend to use it.

        Returns:
            ``updated_model_step`` unchanged.
        """
        self.model_step_state.value = updated_model_step
        if self.accordion is not None:
            # Component instances have no .update() method since Gradio 4; set the attribute.
            self.accordion.label = _make_accordion_label(updated_model_step)
        return updated_model_step