|
import json |
|
|
|
import gradio as gr |
|
import yaml |
|
|
|
from components.model_pipeline.state_manager import ( |
|
ModelStepUIState, |
|
PipelineState, |
|
PipelineStateManager, |
|
PipelineUIState, |
|
) |
|
from components.model_step.model_step import ModelStepComponent |
|
from components.utils import make_state |
|
from workflows.structs import ModelStep, Workflow |
|
from workflows.validators import WorkflowValidator |
|
|
|
|
|
def validate_simple_workflow(workflow: Workflow, required_output_variables: list[str]) -> Workflow:
    """Validate a simple (single LLM call) workflow.

    Only the first step of the workflow is inspected: it must declare output
    fields, and every name in ``required_output_variables`` must be among them.

    Args:
        workflow: The workflow to validate.
        required_output_variables: Output field names the step must produce.

    Returns:
        The same ``workflow`` object, unchanged.

    Raises:
        ValueError: If the workflow has no steps, the first step has no output
            fields, or any required output variable is missing.
    """
    # Default of None avoids leaking StopIteration when the workflow has no steps.
    step = next(iter(workflow.steps.values()), None)
    if step is None:
        raise ValueError("No steps found in the workflow")

    if not step.output_fields:
        raise ValueError("No output fields found in the workflow")

    output_field_names = {output.name for output in step.output_fields}
    # required_output_variables is a list; convert it before the set difference
    # (list - set raises TypeError instead of the intended ValueError).
    missing_vars = set(required_output_variables) - output_field_names
    if missing_vars:
        raise ValueError(f"Missing required output variables: {missing_vars}")

    return workflow
|
|
|
|
|
def validate_complex_workflow(workflow: Workflow, required_output_variables: list[str]) -> Workflow:
    """Validate a multi-step ("complex") workflow.

    Currently this performs no checks and returns the workflow unchanged.

    TODO(review): decide on the real rules for multi-step pipelines, e.g. that
    every name in ``required_output_variables`` is mapped to a variable in
    ``workflow.outputs``. Checking the first step's output fields, as the simple
    validator does, does not fit here: in multi-step pipelines the final outputs
    are mapped explicitly rather than taken from one step.

    Args:
        workflow: The workflow to validate.
        required_output_variables: Output names the pipeline must expose
            (currently unused).

    Returns:
        The same ``workflow`` object, unchanged.
    """
    return workflow
|
|
|
|
|
def parse_yaml_workflow(yaml_str: str) -> Workflow:
    """Parse a YAML document into a ``Workflow``.

    ``yaml.safe_load`` is used (not ``yaml.load``) because the text may come
    from the editable code box in the UI.

    Args:
        yaml_str: YAML text whose top level is a mapping of ``Workflow`` fields.

    Returns:
        A ``Workflow`` built from the mapping's key/value pairs.

    Raises:
        ValueError: If the document is empty or its top level is not a mapping.
        yaml.YAMLError: If the text is not valid YAML.
    """
    data = yaml.safe_load(yaml_str)
    # safe_load returns None for an empty document and may return a list or a
    # scalar; Workflow(**data) would then fail with an opaque TypeError.
    if not isinstance(data, dict):
        raise ValueError(f"Workflow YAML must be a mapping at the top level, got {type(data).__name__}")
    return Workflow(**data)
|
|
|
|
|
def update_workflow_from_code(yaml_str: str, ui_state: PipelineUIState) -> PipelineState:
    """Build a fresh ``PipelineState`` from YAML edited in the code box.

    Args:
        yaml_str: YAML text describing the workflow.
        ui_state: The current UI state. Ignored: a new UI state is always
            derived from the parsed workflow. The parameter is kept because the
            event wiring passes it.

    Returns:
        A ``PipelineState`` holding the parsed workflow and a UI state built by
        ``PipelineUIState.from_workflow``.

    Raises:
        gr.Error: If the YAML cannot be parsed or does not describe a valid
            workflow, so the message is shown to the user in the UI.
    """
    try:
        workflow = parse_yaml_workflow(yaml_str)
    except (yaml.YAMLError, TypeError, ValueError) as e:
        # Malformed user input: surface a readable message instead of a raw traceback.
        raise gr.Error(f"Invalid workflow YAML: {e}") from e
    new_ui_state = PipelineUIState.from_workflow(workflow)
    return PipelineState(workflow=workflow, ui_state=new_ui_state)
|
|
|
|
|
class PipelineInterface:
    """Gradio UI for building, editing, validating and exporting a pipeline workflow.

    The workflow lives in Gradio state holders (``pipeline_state``, ``ui_state``,
    ``variables_state``), and state transitions are delegated to a
    ``PipelineStateManager``. When ``simple`` is true, the add/move/remove step
    controls are not rendered and export uses ``validate_simple_workflow``.
    Otherwise export uses ``validate_complex_workflow``.
    """

    def __init__(
        self,
        workflow: Workflow,
        ui_state: PipelineUIState | None = None,
        model_options: list[str] | None = None,
        simple: bool = False,
    ):
        """Create the UI state holders and render the interface immediately.

        Args:
            workflow: Initial workflow to display.
            ui_state: Initial UI state; derived from ``workflow`` when not given.
            model_options: Model choices passed to every ``ModelStepComponent``.
            simple: Render the single-LLM-call variant (no step add/move/remove controls).
        """
        self.model_options = model_options
        self.simple = simple

        if not ui_state:
            ui_state = PipelineUIState.from_workflow(workflow)

        self.ui_state = make_state(ui_state)
        self.pipeline_state = make_state(PipelineState(workflow=workflow, ui_state=ui_state))
        self.variables_state = make_state(workflow.get_available_variables())

        self.sm = PipelineStateManager()
        self.input_variables = workflow.inputs
        self.required_output_variables = list(workflow.outputs.keys())

        # Not read anywhere in this class; kept for code that may access them externally.
        self.steps_container = None
        self.components = []

        self.render()

    def _render_step(
        self,
        model_step: ModelStep,
        step_ui_state: ModelStepUIState,
        available_variables: list[str],
        position: int = 0,
    ):
        """Render one step editor and wire its change events to the state manager.

        Args:
            model_step: The step to edit.
            step_ui_state: UI state for this step.
            available_variables: Variables offered to the step as inputs.
            position: Index of the step in the pipeline, passed to the move/remove handlers.

        Returns:
            The ``ModelStepComponent`` in simple mode; otherwise a tuple of the
            component followed by its up, down and remove buttons.
        """
        with gr.Column(elem_classes="step-container"):
            step_interface = ModelStepComponent(
                value=model_step,
                ui_state=step_ui_state,
                model_options=self.model_options,
                input_variables=available_variables,
                pipeline_state_manager=self.sm,
            )

            # Content edits may change the variables a step exposes, so the
            # available-variables state is refreshed along with pipeline and UI state.
            step_interface.on_model_step_change(
                self.sm.update_model_step_state,
                inputs=[self.pipeline_state, step_interface.model_step_state, step_interface.ui_state],
                outputs=[self.pipeline_state, self.ui_state, self.variables_state],
            )

            step_interface.on_ui_change(
                self.sm.update_model_step_ui,
                inputs=[self.pipeline_state, step_interface.ui_state, gr.State(model_step.id)],
                outputs=[self.pipeline_state, self.ui_state],
            )

            if self.simple:
                return step_interface

            with gr.Row(elem_classes="step-controls"):
                up_button = gr.Button("⬆️ Move Up", elem_classes="step-control-btn")
                down_button = gr.Button("⬇️ Move Down", elem_classes="step-control-btn")
                remove_button = gr.Button("🗑️ Remove", elem_classes="step-control-btn")

            buttons = (up_button, down_button, remove_button)
            self._assign_step_controls(buttons, position)

        return (step_interface, *buttons)

    def _assign_step_controls(self, buttons: tuple[gr.Button, gr.Button, gr.Button], position: int):
        """Wire a step's up/down/remove buttons to the state manager.

        The move buttons update only the UI state. The remove button also updates
        the pipeline state and the available variables.
        """
        up_button, down_button, remove_button = buttons
        position_state = gr.State(position)
        up_button.click(self.sm.move_up, inputs=[self.ui_state, position_state], outputs=self.ui_state)
        down_button.click(self.sm.move_down, inputs=[self.ui_state, position_state], outputs=self.ui_state)
        remove_button.click(
            self.sm.remove_step,
            inputs=[self.pipeline_state, position_state],
            outputs=[self.pipeline_state, self.ui_state, self.variables_state],
        )

    def _render_add_step_button(self, position: int):
        """Render an "Add Step" button that asks the state manager to add a step at ``position``.

        Args:
            position: ``0`` for a header row or ``-1`` for a footer row.

        Returns:
            The created ``gr.Button``.

        Raises:
            ValueError: If ``position`` is neither 0 nor -1.
        """
        if position not in {0, -1}:
            raise ValueError("Position must be 0 or -1")
        row_class = "pipeline-header" if position == 0 else "pipeline-footer"
        with gr.Row(elem_classes=row_class):
            add_step_btn = gr.Button("➕ Add Step", elem_classes="add-step-button")
            add_step_btn.click(
                self.sm.add_step,
                inputs=[self.pipeline_state, gr.State(position)],
                outputs=[self.pipeline_state, self.ui_state, self.variables_state],
            )
        return add_step_btn

    def _render_output_fields(self, available_variables: list[str], pipeline_state: PipelineState):
        """Render one dropdown per required output, mapping it to a pipeline variable.

        Pipeline input variables are excluded from the choices. An output with no
        mapping yet shows a placeholder entry.

        Returns:
            Dict from output name to its ``gr.Dropdown``.
        """
        dropdowns = {}
        UNSET_VALUE = "Choose variable..."
        variable_options = [UNSET_VALUE] + [v for v in available_variables if v not in self.input_variables]

        with gr.Column(elem_classes="step-accordion"):
            with gr.Row(elem_classes="output-fields-header"):
                gr.Markdown("#### Final output variables mapping:")
            with gr.Row(elem_classes="output-fields-row"):
                for output_field in self.required_output_variables:
                    value = pipeline_state.workflow.outputs[output_field] or UNSET_VALUE
                    dropdown = gr.Dropdown(
                        label=output_field,
                        value=value,
                        choices=variable_options,
                        interactive=True,
                        elem_classes="output-field-variable",
                    )
                    dropdown.change(
                        self.sm.update_output_variables,
                        inputs=[self.pipeline_state, gr.State(output_field), dropdown],
                        outputs=[self.pipeline_state],
                    )
                    dropdowns[output_field] = dropdown

        # No separate variables_state.change listener is needed: the enclosing
        # gr.render takes variables_state as an input, so these dropdowns are rebuilt
        # with fresh choices whenever the available variables change. A listener that
        # reset each dropdown's value would also fire update_output_variables above
        # and could clear the saved output mapping.
        return dropdowns

    def validate_workflow(self, state: PipelineState) -> PipelineState:
        """Validate the current workflow with the validator for this UI mode.

        Returns:
            ``state`` with its workflow replaced by the validated workflow.

        Raises:
            gr.Error: If validation fails. This also fails the event, so chained
                ``.success`` handlers do not run.
        """
        validate = validate_simple_workflow if self.simple else validate_complex_workflow
        try:
            workflow = validate(state.workflow, self.required_output_variables)
        except ValueError as e:
            raise gr.Error(str(e)) from e
        state.workflow = workflow
        return state

    def _render_pipeline_header(self):
        """Render the task instruction and the expected input/output variable names."""
        input_variables_str = ", ".join([f"`{variable}`" for variable in self.input_variables])
        output_variables_str = ", ".join([f"`{variable}`" for variable in self.required_output_variables])
        if self.simple:
            instruction = "Create a simple single LLM call pipeline that takes in the following input variables and outputs the following output variables:"
        else:
            instruction = "Create a pipeline that takes in the following input variables and outputs the following output variables:"
        gr.Markdown(f"### {instruction}")
        gr.Markdown(f"Input Variables: {input_variables_str}")
        gr.Markdown(f"Output Variables: {output_variables_str}")

    def _apply_code_edit(self, yaml_str: str, ui_state: PipelineUIState):
        """Apply YAML edited in the preview box to every state derived from the workflow.

        Returns:
            A tuple of the new pipeline state, its UI state, and the workflow's
            available variables. Returning all three keeps the step renderer (which
            reads ``self.ui_state``) and the output-field renderer (which reads
            ``self.variables_state``) consistent with the new workflow.
        """
        state = update_workflow_from_code(yaml_str, ui_state)
        return state, state.ui_state, state.workflow.get_available_variables()

    def render(self):
        """Render the pipeline UI and wire its event handlers."""
        self.all_components = []

        self._render_pipeline_header()

        @gr.render(inputs=[self.pipeline_state, self.ui_state])
        def render_steps(state, ui_state):
            """Render all steps in ``ui_state.step_ids`` order, plus the add-step footer."""
            workflow = state.workflow
            for i, step_id in enumerate(ui_state.step_ids):
                step_data = workflow.steps[step_id]
                step_ui_state = ui_state.steps[step_id]
                available_variables = self.sm.get_all_variables(state, step_id)
                self._render_step(step_data, step_ui_state, available_variables, i)

            if not self.simple:
                self._render_add_step_button(-1)

        @gr.render(inputs=[self.variables_state, self.pipeline_state])
        def render_output_fields(available_variables, pipeline_state):
            return self._render_output_fields(available_variables, pipeline_state)

        export_btn = gr.Button("Export Pipeline", elem_classes="export-button")

        with gr.Accordion("Pipeline Preview", open=False, elem_classes="pipeline-preview") as config_accordion:
            config_output = gr.Code(
                label="Workflow Configuration",
                language="yaml",
                elem_classes="workflow-json",
                interactive=True,
                autocomplete=True,
            )

        # A code edit replaces the workflow wholesale. The UI state and available
        # variables are refreshed too; otherwise render_steps would look up step ids
        # from the old UI state in the new workflow.
        config_output.blur(
            fn=self._apply_code_edit,
            inputs=[config_output, self.ui_state],
            outputs=[self.pipeline_state, self.ui_state, self.variables_state],
        )

        # Validate first. The config is exported and revealed only when validation
        # succeeds, and the export reads the validated pipeline state.
        export_btn.click(
            self.validate_workflow,
            inputs=[self.pipeline_state],
            outputs=[self.pipeline_state],
        ).success(
            fn=self.sm.get_formatted_config,
            inputs=[self.pipeline_state, gr.State("yaml")],
            outputs=[config_output],
        ).success(
            fn=lambda: gr.update(visible=True, open=True),
            outputs=[config_accordion],
            js="() => {document.querySelector('.pipeline-preview').scrollIntoView({behavior: 'smooth'})}",
        )
|
|
|
|
|
|