import gradio as gr
import yaml
from loguru import logger

from app_configs import UNSELECTED_VAR_NAME
from components import commons
from components import typed_dicts as td
from components.model_pipeline.state_manager import (
    ModelStepUIState,
    PipelineState,
    PipelineStateManager,
    PipelineUIState,
    TossupPipelineState,
    TossupPipelineStateManager,
)
from components.model_step.model_step import ModelStepComponent
from components.utils import make_state
from workflows.structs import ModelStep, TossupWorkflow, Workflow
from workflows.validators import WorkflowValidator

from .state_manager import get_output_panel_state

DEFAULT_MAX_TEMPERATURE = 5.0


class PipelineInterface:
    """Gradio UI for building and editing a model pipeline (workflow).

    The interface renders one ``ModelStepComponent`` per workflow step, optional
    per-step controls (move up / move down / remove) plus an "Add Step" button,
    a panel mapping the workflow's required output variables to available
    variables, and an editable YAML preview of the workflow configuration.

    ``self.pipeline_state`` holds the serialized pipeline state. The derived
    ``workflow_state``, ``variables_state`` and ``output_panel_state`` are
    recomputed from it on every change. A change of ``self.pipeline_change``
    triggers a full re-render of the step list.
    """

    def __init__(
        self,
        app: gr.Blocks,
        workflow: Workflow,
        ui_state: PipelineUIState | None = None,
        model_options: list[str] | None = None,
        config: dict | None = None,
    ):
        """Create the state holders, wire change listeners, and render the UI.

        Args:
            app: Enclosing ``gr.Blocks``; its ``load`` event triggers the initial render.
            workflow: Initial workflow. A ``TossupWorkflow`` selects the tossup-specific
                pipeline state and state-manager classes.
            ui_state: Initial UI state; built from ``workflow`` when omitted (or falsy).
            model_options: Model choices forwarded to every step component.
            config: Optional settings. Keys read by this class: ``"simple"`` (when true,
                step controls and the add-step button are not rendered) and
                ``"max_temperature"`` (defaults to ``DEFAULT_MAX_TEMPERATURE``).
        """
        self.app = app
        self.model_options = model_options
        # Fresh dict per instance: a shared ``{}`` default would leak mutations across instances.
        self.config = config if config is not None else {}
        self.simple = self.config.get("simple", False)
        ui_state = ui_state or PipelineUIState.from_workflow(workflow)

        # Gradio states derived from the workflow; kept in sync by ``get_aux_states`` below.
        self.workflow_state = make_state(workflow.model_dump())
        self.variables_state = make_state(workflow.get_available_variables())
        self.output_panel_state = make_state(get_output_panel_state(workflow))

        # Flag updated by the structural edit handlers (add / move / remove / code edit);
        # its change event is a trigger for re-rendering the steps in ``render``.
        self.pipeline_change = gr.State(False)

        if isinstance(workflow, TossupWorkflow):
            pipeline_state = TossupPipelineState(workflow=workflow, ui_state=ui_state)
            self.sm = TossupPipelineStateManager()
        else:
            pipeline_state = PipelineState(workflow=workflow, ui_state=ui_state)
            self.sm = PipelineStateManager()
        self.pipeline_state = make_state(pipeline_state.model_dump())

        def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
            """Recompute (workflow dump, available variables, output panel state) from the pipeline state."""
            logger.debug("Pipeline changed! Getting aux states for pipeline state.")
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            return (
                pipeline_state.workflow.model_dump(),
                pipeline_state.workflow.get_available_variables(),
                get_output_panel_state(pipeline_state.workflow),
            )

        # Keep the derived states in sync with every pipeline state change.
        self.pipeline_state.change(
            get_aux_states,
            inputs=[self.pipeline_state],
            outputs=[self.workflow_state, self.variables_state, self.output_panel_state],
        )

        # IO variables of the initial workflow.
        self.input_variables = workflow.inputs
        self.required_output_variables = list(workflow.outputs.keys())

        # UI elements
        self.steps_container = None
        self.components = []

        self.render()

    def _render_step(
        self,
        model_step: ModelStep,
        step_ui_state: ModelStepUIState,
        available_variables: list[str],
        position: int = 0,
        n_steps: int = 1,
    ):
        """Render one pipeline step and wire its events into the pipeline state.

        Args:
            model_step: The step to render.
            step_ui_state: UI state for this step.
            available_variables: Variables offered as inputs to this step.
            position: Index of the step in the pipeline (used by the step controls).
            n_steps: Total number of steps; controls are only visible/interactive when > 1.

        Returns:
            In simple mode, the ``ModelStepComponent`` alone; otherwise a tuple
            ``(step_component, up_button, down_button, remove_button)``.
        """
        with gr.Column(elem_classes="step-container"):
            step_interface = ModelStepComponent(
                value=model_step,
                ui_state=step_ui_state,
                model_options=self.model_options,
                input_variables=available_variables,
                pipeline_state_manager=self.sm,
                max_temperature=self.config.get("max_temperature", DEFAULT_MAX_TEMPERATURE),
            )

            # Propagate edits of the step (and of its UI state) back into the pipeline state.
            step_interface.on_model_step_change(
                self.sm.update_model_step_state,
                inputs=[self.pipeline_state, step_interface.model_step_state, step_interface.ui_state],
                outputs=[self.pipeline_state],
            )
            step_interface.on_ui_change(
                self.sm.update_model_step_ui,
                inputs=[self.pipeline_state, step_interface.ui_state, gr.State(model_step.id)],
                outputs=[self.pipeline_state],
            )

            if self.simple:
                return step_interface

            # Reordering/removal only makes sense when there is more than one step.
            is_multi_step = n_steps > 1
            with gr.Row(elem_classes="step-controls", visible=is_multi_step):
                up_button = gr.Button("⬆️ Move Up", elem_classes="step-control-btn", interactive=is_multi_step)
                down_button = gr.Button("⬇️ Move Down", elem_classes="step-control-btn", interactive=is_multi_step)
                remove_button = gr.Button("🗑️ Remove", elem_classes="step-control-btn", interactive=is_multi_step)

            buttons = (up_button, down_button, remove_button)
            self._assign_step_controls(buttons, position)
            return (step_interface, *buttons)

    def _assign_step_controls(self, buttons: tuple[gr.Button, gr.Button, gr.Button], position: int):
        """Wire the (up, down, remove) buttons of the step at ``position`` to the state manager.

        Each handler receives ``(pipeline_state, pipeline_change, position)`` and writes back
        ``(pipeline_state, pipeline_change)``.
        """
        up_button, down_button, remove_button = buttons
        position_state = gr.State(position)
        handlers = (
            (up_button, self.sm.move_up),
            (down_button, self.sm.move_down),
            (remove_button, self.sm.remove_step),
        )
        for button, handler in handlers:
            button.click(
                handler,
                inputs=[self.pipeline_state, self.pipeline_change, position_state],
                outputs=[self.pipeline_state, self.pipeline_change],
            )

    def _render_add_step_button(self, position: int):
        """Render an "Add Step" button that inserts a step at ``position``.

        Args:
            position: ``0`` renders the button in a header row, ``-1`` in a footer row;
                the value is also forwarded to ``self.sm.add_step``.

        Returns:
            The created ``gr.Button``.

        Raises:
            ValueError: If ``position`` is not ``0`` or ``-1``.
        """
        if position not in {0, -1}:
            raise ValueError("Position must be 0 or -1")
        row_class = "pipeline-header" if position == 0 else "pipeline-footer"
        with gr.Row(elem_classes=row_class):
            add_step_btn = gr.Button("➕ Add Step", elem_classes="add-step-button")
            add_step_btn.click(
                self.sm.add_step,
                inputs=[self.pipeline_state, self.pipeline_change, gr.State(position)],
                outputs=[self.pipeline_state, self.pipeline_change],
            )
        return add_step_btn

    def _render_output_panel(self, pipeline_state: PipelineState):
        """Render one dropdown per required output variable, mapping it to a pipeline variable.

        Choices are ``UNSELECTED_VAR_NAME`` plus every available variable that is not
        a pipeline input. Changing a dropdown calls ``self.sm.update_output_variables``.

        Returns:
            Dict mapping output field name -> its ``gr.Dropdown``.
        """
        dropdowns = {}
        available_variables = pipeline_state.workflow.get_available_variables()
        variable_options = [UNSELECTED_VAR_NAME] + [v for v in available_variables if v not in self.input_variables]

        with gr.Column(elem_classes="step-accordion control-panel"):
            commons.get_panel_header(
                header="Final output variables mapping:",
            )
            with gr.Row(elem_classes="output-fields-row"):
                for output_field in self.required_output_variables:
                    value = pipeline_state.workflow.outputs.get(output_field, UNSELECTED_VAR_NAME)
                    dropdown = gr.Dropdown(
                        label=output_field,
                        value=value,
                        choices=variable_options,
                        interactive=True,
                        elem_classes="output-field-variable",
                    )
                    dropdown.change(
                        self.sm.update_output_variables,
                        inputs=[self.pipeline_state, gr.State(output_field), dropdown],
                        outputs=[self.pipeline_state],
                    )
                    dropdowns[output_field] = dropdown

        def update_choices(available_variables: list[str]):
            """Update the choices for the dropdowns"""
            # NOTE(review): unlike ``variable_options`` above, these choices include input
            # variables and omit UNSELECTED_VAR_NAME, and the selection is cleared — confirm intended.
            return [gr.update(choices=available_variables, value=None, selected=None) for _ in dropdowns.values()]

        self.variables_state.change(
            update_choices,
            inputs=[self.variables_state],
            outputs=list(dropdowns.values()),
        )
        return dropdowns

    def validate_workflow(self, state_dict: td.PipelineStateDict):
        """Validate the workflow contained in a serialized pipeline state.

        Raises:
            gr.Error: If building the pipeline state or ``WorkflowValidator`` raises a
                ``ValueError``; the offending state is logged as YAML first.
        """
        try:
            state = self.sm.make_pipeline_state(state_dict)
            WorkflowValidator().validate(state.workflow)
        except ValueError as e:
            logger.exception(e)
            state_dict_str = yaml.dump(state_dict, default_flow_style=False, indent=2)
            logger.error(f"Could not validate workflow: \n{state_dict_str}")
            # gr.Error expects a message string; keep the original exception as the cause.
            raise gr.Error(str(e)) from e

    def _render_pipeline_header(self):
        """Render the instruction text listing the pipeline's input and output variables."""
        input_variables_str = ", ".join([f"`{variable}`" for variable in self.input_variables])
        output_variables_str = ", ".join([f"`{variable}`" for variable in self.required_output_variables])
        if self.simple:
            instruction = "Create a simple single LLM call pipeline that takes in the following input variables and outputs the following output variables:"
        else:
            instruction = "Create a pipeline with the following input and output variables."
        gr.Markdown(f"### {instruction}")
        gr.Markdown(f"* Input Variables: {input_variables_str}")
        gr.Markdown(f"* Output Variables: {output_variables_str}")

    def render(self):
        """Render the pipeline UI: header, dynamic steps, output panel, export button and YAML preview."""
        # Kept as a public attribute for callers that may read it.
        self.all_components = []

        self._render_pipeline_header()

        # Re-rendered on app load and whenever ``pipeline_change`` changes.
        @gr.render(
            triggers=[self.app.load, self.pipeline_change.change],
            inputs=[self.pipeline_state],
            concurrency_limit=1,
            concurrency_id="render_steps",
        )
        def render_steps(pipeline_state: td.PipelineStateDict, evt: gr.EventData):
            """Render all steps in the pipeline"""
            logger.info(
                f"Rerender triggered! \nInput Pipeline's UI State:{pipeline_state.get('ui_state')}\n Event: {evt.target} {evt._data}"
            )
            pipeline_state = self.sm.make_pipeline_state(pipeline_state)
            ui_state = pipeline_state.ui_state
            workflow = pipeline_state.workflow
            for i, step_id in enumerate(ui_state.step_ids):
                self._render_step(
                    workflow.steps[step_id],
                    ui_state.steps[step_id],
                    pipeline_state.get_available_variables(step_id),
                    i,
                    ui_state.n_steps,
                )

            if not self.simple:
                self._render_add_step_button(-1)

        # Re-rendered on app load and whenever the derived output panel state changes.
        @gr.render(
            triggers=[self.output_panel_state.change, self.app.load],
            inputs=[self.pipeline_state],
            concurrency_limit=1,
            concurrency_id="render_output_fields",
        )
        def render_output_fields(pipeline_state_dict: td.PipelineStateDict):
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            logger.debug(f"Rerendering output panel: {get_output_panel_state(pipeline_state.workflow)}")
            self._render_output_panel(pipeline_state)

        export_btn = gr.Button("Export Pipeline", elem_classes="export-button")

        # Editable YAML view of the workflow configuration.
        with gr.Accordion(
            "Pipeline Preview (click to expand and edit)", open=False, elem_classes="pipeline-preview"
        ) as self.config_accordion:
            self.config_output = gr.Code(
                label="Workflow Configuration",
                show_label=False,
                language="yaml",
                elem_classes="workflow-json",
                interactive=True,
                autocomplete=True,
            )
        # Apply edits made in the code box once it loses focus.
        self.config_output.blur(
            fn=self.sm.update_workflow_from_code,
            inputs=[self.config_output, self.pipeline_change],
            outputs=[self.pipeline_state, self.pipeline_change],
        )

        # Export: fill the YAML preview (and scroll to it), then open the accordion once validation passes.
        self.add_triggers_for_pipeline_export([export_btn.click], self.pipeline_state, scroll=True)
        export_btn.click(self.validate_workflow, inputs=[self.pipeline_state], outputs=[]).success(
            fn=lambda: gr.update(visible=True, open=True), outputs=[self.config_accordion]
        )

    def add_triggers_for_pipeline_export(self, triggers: list, input_pipeline_state: gr.State, scroll: bool = False):
        """On any of ``triggers``, validate ``input_pipeline_state`` and show its YAML config.

        Args:
            triggers: Gradio event triggers (e.g. ``[button.click]``).
            input_pipeline_state: State holding the serialized pipeline to validate and export.
            scroll: If true, run client-side JS that smooth-scrolls ``.pipeline-preview`` into view.
        """
        js = None
        if scroll:
            js = "() => {document.querySelector('.pipeline-preview').scrollIntoView({behavior: 'smooth'})}"
        gr.on(
            triggers,
            self.validate_workflow,
            inputs=[input_pipeline_state],
            outputs=[],
        ).success(
            # Format the same state that was just validated (previously always self.pipeline_state).
            fn=self.sm.get_formatted_config,
            inputs=[input_pipeline_state, gr.State("yaml")],
            outputs=[self.config_output],
            js=js,
        )