|
import gradio as gr |
|
import yaml |
|
from loguru import logger |
|
|
|
from app_configs import UNSELECTED_VAR_NAME |
|
from components import commons |
|
from components import typed_dicts as td |
|
from components.model_pipeline.state_manager import ( |
|
ModelStepUIState, |
|
PipelineState, |
|
PipelineStateManager, |
|
PipelineUIState, |
|
TossupPipelineState, |
|
TossupPipelineStateManager, |
|
) |
|
from components.model_step.model_step import ModelStepComponent |
|
from components.utils import make_state |
|
from workflows.structs import ModelStep, TossupWorkflow, Workflow |
|
from workflows.validators import WorkflowValidator |
|
|
|
from .state_manager import get_output_panel_state |
|
|
|
# Fallback for the "max_temperature" config key: passed as `max_temperature` to each
# ModelStepComponent when `PipelineInterface.config` does not provide a value.
DEFAULT_MAX_TEMPERATURE = 5.0
|
|
|
|
|
class PipelineInterface:
    """UI for the pipeline.

    Constructing an instance renders the UI into the currently active Gradio
    context (``render()`` is called at the end of ``__init__``). The editable
    pipeline lives in ``self.pipeline_state`` as a serialized
    ``PipelineState``/``TossupPipelineState``; every change to it refreshes the
    derived ``workflow_state``, ``variables_state`` and ``output_panel_state``.
    """

    def __init__(
        self,
        app: gr.Blocks,
        workflow: Workflow,
        ui_state: PipelineUIState | None = None,
        model_options: list[str] | None = None,
        config: dict | None = None,
    ):
        """Create the pipeline states, wire state synchronisation and render the UI.

        Args:
            app: Enclosing ``gr.Blocks``; its ``load`` event triggers the initial render.
            workflow: Workflow to edit. A ``TossupWorkflow`` selects the tossup
                state/state-manager classes; anything else uses the generic ones.
            ui_state: Initial UI state; built from ``workflow`` when omitted.
            model_options: Model names forwarded to every ``ModelStepComponent``.
            config: Optional UI settings. Keys read by this class: ``"simple"``
                (no step controls / add-step button) and ``"max_temperature"``.
        """
        self.app = app
        self.model_options = model_options
        # Never share a mutable default between instances.
        self.config = config if config is not None else {}
        self.simple = self.config.get("simple", False)
        ui_state = ui_state or PipelineUIState.from_workflow(workflow)

        # States derived from the pipeline; kept in sync by `get_aux_states` below.
        self.workflow_state = make_state(workflow.model_dump())
        self.variables_state = make_state(workflow.get_available_variables())
        self.output_panel_state = make_state(get_output_panel_state(workflow))

        # Trigger for re-rendering the steps (`render_steps` listens to its `.change`);
        # the structural state-manager handlers take it as input and return it.
        self.pipeline_change = gr.State(False)

        if isinstance(workflow, TossupWorkflow):
            pipeline_state = TossupPipelineState(workflow=workflow, ui_state=ui_state)
            self.sm = TossupPipelineStateManager()
        else:
            pipeline_state = PipelineState(workflow=workflow, ui_state=ui_state)
            self.sm = PipelineStateManager()
        self.pipeline_state = make_state(pipeline_state.model_dump())

        def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
            """Get the auxiliary states for the pipeline."""
            logger.debug("Pipeline changed! Getting aux states for pipeline state.")
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            return (
                pipeline_state.workflow.model_dump(),
                pipeline_state.workflow.get_available_variables(),
                get_output_panel_state(pipeline_state.workflow),
            )

        self.pipeline_state.change(
            get_aux_states,
            inputs=[self.pipeline_state],
            outputs=[self.workflow_state, self.variables_state, self.output_panel_state],
        )

        # NOTE: taken from the initial workflow only; not refreshed when the pipeline changes.
        self.input_variables = workflow.inputs
        self.required_output_variables = list(workflow.outputs.keys())

        self.steps_container = None
        self.components = []

        self.render()

    def _output_variable_choices(self, available_variables: list[str]) -> list[str]:
        """Return the choices for an output-mapping dropdown.

        The ``UNSELECTED_VAR_NAME`` placeholder comes first, followed by every
        available variable that is not one of the pipeline's input variables.
        """
        return [UNSELECTED_VAR_NAME] + [v for v in available_variables if v not in self.input_variables]

    def _render_step(
        self,
        model_step: ModelStep,
        step_ui_state: ModelStepUIState,
        available_variables: list[str],
        position: int = 0,
        n_steps: int = 1,
    ):
        """Render one model step and wire its change events into ``self.pipeline_state``.

        Args:
            model_step: The step to render.
            step_ui_state: UI state of that step.
            available_variables: Variables the step may use as inputs.
            position: Index of the step in the pipeline (bound to the step controls).
            n_steps: Total number of steps; controls are hidden/disabled when it is 1.

        Returns:
            In simple mode, the ``ModelStepComponent`` alone; otherwise the tuple
            ``(step_interface, up_button, down_button, remove_button)``.
        """
        with gr.Column(elem_classes="step-container"):
            step_interface = ModelStepComponent(
                value=model_step,
                ui_state=step_ui_state,
                model_options=self.model_options,
                input_variables=available_variables,
                pipeline_state_manager=self.sm,
                max_temperature=self.config.get("max_temperature", DEFAULT_MAX_TEMPERATURE),
            )

            step_interface.on_model_step_change(
                self.sm.update_model_step_state,
                inputs=[self.pipeline_state, step_interface.model_step_state, step_interface.ui_state],
                outputs=[self.pipeline_state],
            )

            step_interface.on_ui_change(
                self.sm.update_model_step_ui,
                inputs=[self.pipeline_state, step_interface.ui_state, gr.State(model_step.id)],
                outputs=[self.pipeline_state],
            )

            if self.simple:
                return step_interface

            # Move/remove controls are only meaningful with more than one step.
            is_multi_step = n_steps > 1
            with gr.Row(elem_classes="step-controls", visible=is_multi_step):
                up_button = gr.Button("⬆️ Move Up", elem_classes="step-control-btn", interactive=is_multi_step)
                down_button = gr.Button("⬇️ Move Down", elem_classes="step-control-btn", interactive=is_multi_step)
                remove_button = gr.Button("🗑️ Remove", elem_classes="step-control-btn", interactive=is_multi_step)

            buttons = (up_button, down_button, remove_button)
            self._assign_step_controls(buttons, position)

            return (step_interface, *buttons)

    def _assign_step_controls(self, buttons: tuple[gr.Button, gr.Button, gr.Button], position: int):
        """Bind the (up, down, remove) buttons to the state-manager handlers for ``position``."""
        up_button, down_button, remove_button = buttons
        position_state = gr.State(position)
        handlers = (
            (up_button, self.sm.move_up),
            (down_button, self.sm.move_down),
            (remove_button, self.sm.remove_step),
        )
        for button, handler in handlers:
            button.click(
                handler,
                inputs=[self.pipeline_state, self.pipeline_change, position_state],
                outputs=[self.pipeline_state, self.pipeline_change],
            )

    def _render_add_step_button(self, position: int):
        """Render an "Add Step" button inserting at ``position`` (0 = header, -1 = footer).

        Raises:
            ValueError: If ``position`` is neither 0 nor -1.
        """
        if position not in {0, -1}:
            raise ValueError("Position must be 0 or -1")
        row_class = "pipeline-header" if position == 0 else "pipeline-footer"
        with gr.Row(elem_classes=row_class):
            add_step_btn = gr.Button("➕ Add Step", elem_classes="add-step-button")
        add_step_btn.click(
            self.sm.add_step,
            inputs=[self.pipeline_state, self.pipeline_change, gr.State(position)],
            outputs=[self.pipeline_state, self.pipeline_change],
        )
        return add_step_btn

    def _render_output_panel(self, pipeline_state: PipelineState):
        """Render one dropdown per required output, mapping it to a pipeline variable.

        Returns:
            Dict from output field name to its ``gr.Dropdown``.
        """
        dropdowns = {}
        available_variables = pipeline_state.workflow.get_available_variables()
        variable_options = self._output_variable_choices(available_variables)
        with gr.Column(elem_classes="step-accordion control-panel"):
            commons.get_panel_header(
                header="Final output variables mapping:",
            )
            with gr.Row(elem_classes="output-fields-row"):
                for output_field in self.required_output_variables:
                    value = pipeline_state.workflow.outputs.get(output_field, UNSELECTED_VAR_NAME)
                    dropdown = gr.Dropdown(
                        label=output_field,
                        value=value,
                        choices=variable_options,
                        interactive=True,
                        elem_classes="output-field-variable",
                    )
                    dropdown.change(
                        self.sm.update_output_variables,
                        inputs=[self.pipeline_state, gr.State(output_field), dropdown],
                        outputs=[self.pipeline_state],
                    )
                    dropdowns[output_field] = dropdown

        def update_choices(available_variables: list[str], *current_values: str):
            """Update the choices for the dropdowns.

            Uses the same choice list as the initial render and keeps each current
            selection if it is still a valid choice; otherwise falls back to the
            unselected placeholder instead of clearing the mapping.
            """
            choices = self._output_variable_choices(available_variables)
            return [
                gr.update(choices=choices, value=value if value in choices else UNSELECTED_VAR_NAME)
                for value in current_values
            ]

        self.variables_state.change(
            update_choices,
            inputs=[self.variables_state, *dropdowns.values()],
            outputs=list(dropdowns.values()),
        )
        return dropdowns

    def validate_workflow(self, state_dict: td.PipelineStateDict):
        """Validate the workflow held in ``state_dict``.

        Raises:
            gr.Error: If building the pipeline state or validating its workflow
                raises ``ValueError``; the original error is chained as the cause.
        """
        try:
            state = self.sm.make_pipeline_state(state_dict)
            WorkflowValidator().validate(state.workflow)
        except ValueError as e:
            logger.exception(e)
            state_dict_str = yaml.dump(state_dict, default_flow_style=False, indent=2)
            logger.error(f"Could not validate workflow: \n{state_dict_str}")
            # gr.Error expects a message string, not an exception object.
            raise gr.Error(str(e)) from e

    def _render_pipeline_header(self):
        """Render the instructions listing the pipeline's input and output variables."""
        input_variables_str = ", ".join([f"`{variable}`" for variable in self.input_variables])
        output_variables_str = ", ".join([f"`{variable}`" for variable in self.required_output_variables])
        if self.simple:
            instruction = "Create a simple single LLM call pipeline that takes in the following input variables and outputs the following output variables:"
        else:
            instruction = "Create a pipeline with the following input and output variables."
        gr.Markdown(f"### {instruction}")
        gr.Markdown(f"* Input Variables: {input_variables_str}")
        gr.Markdown(f"* Output Variables: {output_variables_str}")

    def render(self):
        """Render the pipeline UI.

        Lays out the header, the dynamically re-rendered steps and output panel,
        the export button and the editable YAML preview.
        """
        self.all_components = []

        self._render_pipeline_header()

        # Steps are re-rendered on app load and whenever `pipeline_change` changes.
        @gr.render(
            triggers=[self.app.load, self.pipeline_change.change],
            inputs=[self.pipeline_state],
            concurrency_limit=1,
            concurrency_id="render_steps",
        )
        def render_steps(pipeline_state_dict: td.PipelineStateDict, evt: gr.EventData):
            """Render all steps in the pipeline"""
            logger.info(
                f"Rerender triggered! \nInput Pipeline's UI State:{pipeline_state_dict.get('ui_state')}\n Event: {evt.target} {evt._data}"
            )
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            ui_state = pipeline_state.ui_state
            workflow = pipeline_state.workflow

            for i, step_id in enumerate(ui_state.step_ids):
                step_data = workflow.steps[step_id]
                step_ui_state = ui_state.steps[step_id]
                available_variables = pipeline_state.get_available_variables(step_id)
                self._render_step(step_data, step_ui_state, available_variables, i, ui_state.n_steps)

            if not self.simple:
                self._render_add_step_button(-1)

        # The output panel is re-rendered on app load and whenever its derived state changes.
        @gr.render(
            triggers=[self.output_panel_state.change, self.app.load],
            inputs=[self.pipeline_state],
            concurrency_limit=1,
            concurrency_id="render_output_fields",
        )
        def render_output_fields(pipeline_state_dict: td.PipelineStateDict):
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            logger.debug(f"Rerendering output panel: {get_output_panel_state(pipeline_state.workflow)}")
            self._render_output_panel(pipeline_state)

        export_btn = gr.Button("Export Pipeline", elem_classes="export-button")

        with gr.Accordion(
            "Pipeline Preview (click to expand and edit)", open=False, elem_classes="pipeline-preview"
        ) as self.config_accordion:
            self.config_output = gr.Code(
                label="Workflow Configuration",
                show_label=False,
                language="yaml",
                elem_classes="workflow-json",
                interactive=True,
                autocomplete=True,
            )

        # Edits in the YAML preview are applied back to the pipeline when the editor loses focus.
        self.config_output.blur(
            fn=self.sm.update_workflow_from_code,
            inputs=[self.config_output, self.pipeline_change],
            outputs=[self.pipeline_state, self.pipeline_change],
        )

        # Validate once, write the formatted config, then open the preview; wiring a
        # second validation chain on the same click would duplicate error toasts.
        export_event = self.add_triggers_for_pipeline_export([export_btn.click], self.pipeline_state, scroll=True)
        export_event.success(fn=lambda: gr.update(visible=True, open=True), outputs=[self.config_accordion])

    def add_triggers_for_pipeline_export(self, triggers: list, input_pipeline_state: gr.State, scroll: bool = False):
        """On any of ``triggers``, validate ``input_pipeline_state`` and, if valid,
        write its YAML config into ``self.config_output``.

        Args:
            triggers: Gradio event triggers (e.g. ``button.click``).
            input_pipeline_state: The pipeline state to validate and export.
            scroll: If True, smoothly scroll the pipeline preview into view when
                the export step starts.

        Returns:
            The chained success event, so callers can attach follow-up steps.
        """
        js = None
        if scroll:
            js = "() => {document.querySelector('.pipeline-preview').scrollIntoView({behavior: 'smooth'})}"
        return gr.on(
            triggers,
            self.validate_workflow,
            inputs=[input_pipeline_state],
            outputs=[],
        ).success(
            fn=self.sm.get_formatted_config,
            # Export the same state that was validated above.
            inputs=[input_pipeline_state, gr.State("yaml")],
            outputs=[self.config_output],
            js=js,
        )
|
|