# Source-viewer header (not part of the module):
#   Maharshi Gor
#   Commit 1758388: "Restructure output panel rendering. This change fixes logprob issue and ..."
#   raw | history | blame | 13.5 kB
import gradio as gr
import yaml
from loguru import logger
from app_configs import UNSELECTED_VAR_NAME
from components import commons
from components import typed_dicts as td
from components.model_pipeline.state_manager import (
ModelStepUIState,
PipelineState,
PipelineStateManager,
PipelineUIState,
TossupPipelineState,
TossupPipelineStateManager,
)
from components.model_step.model_step import ModelStepComponent
from components.utils import make_state
from workflows.structs import ModelStep, TossupWorkflow, Workflow
from workflows.validators import WorkflowValidator
from .state_manager import get_output_panel_state
# Fallback `max_temperature` handed to each ModelStepComponent when the
# pipeline config has no "max_temperature" entry.
DEFAULT_MAX_TEMPERATURE = 5.0
class PipelineInterface:
    """Gradio UI for viewing and editing a model pipeline.

    The interface keeps the whole pipeline (workflow + UI state) in a single
    Gradio state (`self.pipeline_state`) and derives three auxiliary states
    from it: the workflow dump, the available variables, and the output panel
    state. Steps and the output panel are (re)drawn through `gr.render`
    callbacks.
    """

    def __init__(
        self,
        app: gr.Blocks,
        workflow: Workflow,
        ui_state: PipelineUIState | None = None,
        model_options: list[str] | None = None,
        config: dict | None = None,
    ):
        """Build the Gradio states, wire the state triggers and render the UI.

        Args:
            app: The enclosing Blocks app; its `load` event triggers the renders.
            workflow: The workflow shown initially.
            ui_state: Optional UI state; derived from `workflow` when omitted.
            model_options: Model choices forwarded to every step component.
            config: Optional settings. The keys read here are "simple" and
                "max_temperature".
        """
        self.app = app
        self.model_options = model_options
        # A fresh dict per instance: a `{}` default in the signature would be
        # shared by every instance created without an explicit config.
        self.config = {} if config is None else config
        self.simple = self.config.get("simple", False)
        ui_state = ui_state or PipelineUIState.from_workflow(workflow)

        # Gradio States
        self.workflow_state = make_state(workflow.model_dump())
        self.variables_state = make_state(workflow.get_available_variables())
        self.output_panel_state = make_state(get_output_panel_state(workflow))

        # Maintains the toggle state change for pipeline changes through user input.
        self.pipeline_change = gr.State(False)

        if isinstance(workflow, TossupWorkflow):
            pipeline_state = TossupPipelineState(workflow=workflow, ui_state=ui_state)
            self.sm = TossupPipelineStateManager()
        else:
            pipeline_state = PipelineState(workflow=workflow, ui_state=ui_state)
            self.sm = PipelineStateManager()
        self.pipeline_state = make_state(pipeline_state.model_dump())

        def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
            """Get the auxiliary states for the pipeline."""
            logger.debug("Pipeline changed! Getting aux states for pipeline state.")
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            return (
                pipeline_state.workflow.model_dump(),
                pipeline_state.workflow.get_available_variables(),
                get_output_panel_state(pipeline_state.workflow),
            )

        # Triggers for pipeline state changes
        self.pipeline_state.change(
            get_aux_states,
            inputs=[self.pipeline_state],
            outputs=[self.workflow_state, self.variables_state, self.output_panel_state],
        )

        # IO Variables
        self.input_variables = workflow.inputs
        self.required_output_variables = list(workflow.outputs.keys())

        # UI elements
        self.steps_container = None
        self.components = []

        # Render the pipeline UI
        self.render()

    def _render_step(
        self,
        model_step: ModelStep,
        step_ui_state: ModelStepUIState,
        available_variables: list[str],
        position: int = 0,
        n_steps: int = 1,
    ):
        """Render one step and, unless in simple mode, its move/remove controls.

        Args:
            model_step: The step to display.
            step_ui_state: UI state for that step.
            available_variables: Variables offered as inputs to the step.
            position: Index of the step, forwarded to the control handlers.
            n_steps: Total number of steps; the controls are only visible and
                interactive when there is more than one.

        Returns:
            The step component alone in simple mode, otherwise a tuple
            `(step_component, up_button, down_button, remove_button)`.
        """
        with gr.Column(elem_classes="step-container"):
            # Create the step component
            step_interface = ModelStepComponent(
                value=model_step,
                ui_state=step_ui_state,
                model_options=self.model_options,
                input_variables=available_variables,
                pipeline_state_manager=self.sm,
                max_temperature=self.config.get("max_temperature", DEFAULT_MAX_TEMPERATURE),
            )

            step_interface.on_model_step_change(
                self.sm.update_model_step_state,
                inputs=[self.pipeline_state, step_interface.model_step_state, step_interface.ui_state],
                outputs=[self.pipeline_state],
            )

            step_interface.on_ui_change(
                self.sm.update_model_step_ui,
                inputs=[self.pipeline_state, step_interface.ui_state, gr.State(model_step.id)],
                outputs=[self.pipeline_state],
            )

            if self.simple:
                return step_interface

            is_multi_step = n_steps > 1

            # Add step controls below
            # NOTE(review): the source had lost its indentation; the controls
            # are placed inside the step container here -- confirm.
            with gr.Row(elem_classes="step-controls", visible=is_multi_step):
                up_button = gr.Button("⬆️ Move Up", elem_classes="step-control-btn", interactive=is_multi_step)
                down_button = gr.Button("⬇️ Move Down", elem_classes="step-control-btn", interactive=is_multi_step)
                remove_button = gr.Button("🗑️ Remove", elem_classes="step-control-btn", interactive=is_multi_step)

            buttons = (up_button, down_button, remove_button)
            self._assign_step_controls(buttons, position)

        return (step_interface, *buttons)

    def _assign_step_controls(self, buttons: tuple[gr.Button, gr.Button, gr.Button], position: int):
        """Wire the (up, down, remove) buttons to the state manager handlers.

        Every handler receives the pipeline state, the change toggle and the
        step position, and writes back the pipeline state and the toggle.
        """
        up_button, down_button, remove_button = buttons
        position_state = gr.State(position)
        handler_inputs = [self.pipeline_state, self.pipeline_change, position_state]
        handler_outputs = [self.pipeline_state, self.pipeline_change]

        for button, handler in (
            (up_button, self.sm.move_up),
            (down_button, self.sm.move_down),
            (remove_button, self.sm.remove_step),
        ):
            button.click(handler, inputs=handler_inputs, outputs=handler_outputs)

    def _render_add_step_button(self, position: int):
        """Render an "Add Step" button wired to `self.sm.add_step`.

        Args:
            position: 0 renders the button in a "pipeline-header" row, -1 in a
                "pipeline-footer" row; the value is also passed to the handler.

        Returns:
            The created button.

        Raises:
            ValueError: If `position` is neither 0 nor -1.
        """
        if position not in {0, -1}:
            raise ValueError("Position must be 0 or -1")

        row_class = "pipeline-header" if position == 0 else "pipeline-footer"
        with gr.Row(elem_classes=row_class):
            add_step_btn = gr.Button("➕ Add Step", elem_classes="add-step-button")
            add_step_btn.click(
                self.sm.add_step,
                inputs=[self.pipeline_state, self.pipeline_change, gr.State(position)],
                outputs=[self.pipeline_state, self.pipeline_change],
            )
        return add_step_btn

    def _render_output_panel(self, pipeline_state: PipelineState):
        """Render one dropdown per required output variable.

        Each dropdown maps an output field to a workflow variable and reports
        changes through `self.sm.update_output_variables`.

        Returns:
            A dict from output field name to its dropdown.
        """
        dropdowns = {}
        available_variables = pipeline_state.workflow.get_available_variables()
        # Input variables are not offered as final outputs.
        variable_options = [UNSELECTED_VAR_NAME] + [v for v in available_variables if v not in self.input_variables]

        with gr.Column(elem_classes="step-accordion control-panel"):
            commons.get_panel_header(
                header="Final output variables mapping:",
            )
            with gr.Row(elem_classes="output-fields-row"):
                for output_field in self.required_output_variables:
                    value = pipeline_state.workflow.outputs.get(output_field, UNSELECTED_VAR_NAME)
                    dropdown = gr.Dropdown(
                        label=output_field,
                        value=value,
                        choices=variable_options,
                        interactive=True,
                        elem_classes="output-field-variable",
                    )
                    dropdown.change(
                        self.sm.update_output_variables,
                        inputs=[self.pipeline_state, gr.State(output_field), dropdown],
                        outputs=[self.pipeline_state],
                    )
                    dropdowns[output_field] = dropdown

        def update_choices(available_variables: list[str]):
            """Update the choices for the dropdowns"""
            # NOTE(review): unlike the initial render above, this pushes the
            # unfiltered variable list (no UNSELECTED_VAR_NAME entry, inputs
            # included) and clears the selection -- confirm this is intended.
            return [gr.update(choices=available_variables, value=None, selected=None) for _ in dropdowns.values()]

        self.variables_state.change(
            update_choices,
            inputs=[self.variables_state],
            outputs=list(dropdowns.values()),
        )
        return dropdowns

    def validate_workflow(self, state_dict: td.PipelineStateDict) -> None:
        """Validate the workflow.

        Rebuilds the pipeline state from `state_dict` and runs
        `WorkflowValidator` on its workflow.

        Raises:
            gr.Error: If building or validating the state raises `ValueError`.
                The offending state is logged as YAML first.
        """
        try:
            state = self.sm.make_pipeline_state(state_dict)
            WorkflowValidator().validate(state.workflow)
        except ValueError as e:
            logger.exception(e)
            state_dict_str = yaml.dump(state_dict, default_flow_style=False, indent=2)
            logger.error(f"Could not validate workflow: \n{state_dict_str}")
            # gr.Error expects a message string; keep the original error as the cause.
            raise gr.Error(str(e)) from e

    def _render_pipeline_header(self):
        """Render the instruction heading and the input/output variable lists."""
        input_variables_str = ", ".join([f"`{variable}`" for variable in self.input_variables])
        output_variables_str = ", ".join([f"`{variable}`" for variable in self.required_output_variables])

        if self.simple:
            instruction = "Create a simple single LLM call pipeline that takes in the following input variables and outputs the following output variables:"
        else:
            instruction = "Create a pipeline with the following input and output variables."

        gr.Markdown(f"### {instruction}")
        gr.Markdown(f"* Input Variables: {input_variables_str}")
        gr.Markdown(f"* Output Variables: {output_variables_str}")

    def render(self):
        """Render the pipeline UI."""
        # Create a placeholder for all the step components
        self.all_components = []

        self._render_pipeline_header()

        # Function to render all steps
        @gr.render(
            triggers=[self.app.load, self.pipeline_change.change],
            inputs=[self.pipeline_state],
            concurrency_limit=1,
            concurrency_id="render_steps",
        )
        def render_steps(pipeline_state_dict: td.PipelineStateDict, evt: gr.EventData):
            """Render all steps in the pipeline"""
            logger.info(
                f"Rerender triggered! \nInput Pipeline's UI State:{pipeline_state_dict.get('ui_state')}\n Event: {evt.target} {evt._data}"
            )
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            ui_state = pipeline_state.ui_state
            workflow = pipeline_state.workflow

            for i, step_id in enumerate(ui_state.step_ids):
                step_data = workflow.steps[step_id]
                step_ui_state = ui_state.steps[step_id]
                available_variables = pipeline_state.get_available_variables(step_id)
                self._render_step(step_data, step_ui_state, available_variables, i, ui_state.n_steps)

            # NOTE(review): the source had lost its indentation; the footer
            # button is rendered together with the steps here -- confirm.
            if not self.simple:
                self._render_add_step_button(-1)

        @gr.render(
            triggers=[self.output_panel_state.change, self.app.load],
            inputs=[self.pipeline_state],
            concurrency_limit=1,
            concurrency_id="render_output_fields",
        )
        def render_output_fields(pipeline_state_dict: td.PipelineStateDict):
            """Redraw the output-variable dropdowns from the current state."""
            pipeline_state = self.sm.make_pipeline_state(pipeline_state_dict)
            logger.debug(f"Rerendering output panel: {get_output_panel_state(pipeline_state.workflow)}")
            self._render_output_panel(pipeline_state)

        export_btn = gr.Button("Export Pipeline", elem_classes="export-button")

        # Add a code box to display the workflow configuration
        with gr.Accordion(
            "Pipeline Preview (click to expand and edit)", open=False, elem_classes="pipeline-preview"
        ) as self.config_accordion:
            self.config_output = gr.Code(
                label="Workflow Configuration",
                show_label=False,
                language="yaml",
                elem_classes="workflow-json",
                interactive=True,
                autocomplete=True,
            )

        # Edits made in the code box are applied when it loses focus.
        self.config_output.blur(
            fn=self.sm.update_workflow_from_code,
            inputs=[self.config_output, self.pipeline_change],
            outputs=[self.pipeline_state, self.pipeline_change],
        )

        # Connect the export button to show the workflow configuration
        self.add_triggers_for_pipeline_export([export_btn.click], self.pipeline_state, scroll=True)
        export_btn.click(self.validate_workflow, inputs=[self.pipeline_state], outputs=[]).success(
            fn=lambda: gr.update(visible=True, open=True), outputs=[self.config_accordion]
        )

    def add_triggers_for_pipeline_export(self, triggers: list, input_pipeline_state: gr.State, scroll: bool = False):
        """Validate on the given triggers and, on success, fill the config box.

        Args:
            triggers: Event triggers that start the export.
            input_pipeline_state: State holding the pipeline to validate and export.
            scroll: When True, attach a JS snippet that scrolls the
                `.pipeline-preview` element into view.
        """
        js = None
        if scroll:
            js = "() => {document.querySelector('.pipeline-preview').scrollIntoView({behavior: 'smooth'})}"

        gr.on(
            triggers,
            self.validate_workflow,
            inputs=[input_pipeline_state],
            outputs=[],
        ).success(
            fn=self.sm.get_formatted_config,
            # Format the same state that was just validated.
            inputs=[input_pipeline_state, gr.State("yaml")],
            outputs=[self.config_output],
            js=js,
        )