"""State managers that rebuild a pipeline state from its dict form, edit it, and dump it back."""

import json
from typing import Literal

import yaml
from app_configs import UNSELECTED_VAR_NAME
from components import typed_dicts as td
from components import utils
from components.structs import ModelStepUIState, PipelineState, PipelineUIState, TossupPipelineState
from workflows.factory import create_new_llm_step
from workflows.structs import Buzzer, ModelStep, TossupWorkflow, Workflow


class PipelineStateManager:
    """Manages a pipeline of multiple steps.

    Every public method takes the pipeline state in dictionary form, rebuilds
    a state object through ``make_pipeline_state``, applies one edit, and
    returns the result of ``model_dump()``. Methods that also receive a
    ``pipeline_change`` flag return it alongside the state, negated when the
    edit took effect (presumably so a UI layer can observe the flip -- confirm
    against callers).
    """

    def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> PipelineState:
        """Make a state from a state dictionary."""
        return PipelineState(**state_dict)

    @staticmethod
    def _load_workflow_mapping(yaml_str: str) -> dict:
        """Parse ``yaml_str`` and return its top-level mapping.

        Raises:
            TypeError: If the document is empty or its top level is not a
                mapping. This is the exception type that unpacking the value
                with ``**`` would have raised, but with a readable message.
        """
        data = yaml.safe_load(yaml_str)
        if not isinstance(data, dict):
            raise TypeError(
                f"Workflow YAML must be a mapping at the top level, got {type(data).__name__}"
            )
        return data

    def get_formatted_config(self, state_dict: td.PipelineStateDict, format: Literal["json", "yaml"] = "yaml") -> str:
        """Get the full pipeline configuration as text.

        The workflow is dumped with defaults excluded. For a ``TossupWorkflow``
        the ``buzzer`` entry is replaced by a dump that keeps defaults.

        Args:
            state_dict: Serialized pipeline state.
            format: ``"yaml"`` for YAML output; any other value yields JSON.

        Returns:
            The configuration rendered with 4-space indentation and the
            original key order.
        """
        state = self.make_pipeline_state(state_dict)
        config = state.workflow.model_dump(exclude_defaults=True)
        if isinstance(state.workflow, TossupWorkflow):
            # The buzzer is dumped in full, defaults included.
            config["buzzer"] = state.workflow.buzzer.model_dump(exclude_defaults=False)

        if format == "yaml":
            return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
        return json.dumps(config, indent=4, sort_keys=False)

    def add_step(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name: str = ""
    ) -> tuple[td.PipelineStateDict, bool]:
        """Create a new LLM step and insert it into the pipeline.

        Args:
            state_dict: Serialized pipeline state.
            pipeline_change: Current value of the change flag.
            position: Position handed to ``state.insert_step``.
            name: Step name; when empty, ``"Step <n_steps + 1>"`` is used.

        Returns:
            The updated state dictionary and the negated ``pipeline_change``.
        """
        state = self.make_pipeline_state(state_dict)
        step_id = state.get_new_step_id()
        step_name = name or f"Step {state.n_steps + 1}"
        new_step = create_new_llm_step(step_id=step_id, name=step_name)
        state = state.insert_step(position, new_step)
        return state.model_dump(), not pipeline_change

    def remove_step(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
    ) -> tuple[td.PipelineStateDict, bool]:
        """Remove the step at ``position`` from the pipeline.

        Returns:
            The updated state dictionary and the negated ``pipeline_change``.

        Raises:
            ValueError: If ``position`` is outside ``[0, n_steps)``.
        """
        state = self.make_pipeline_state(state_dict)
        if not 0 <= position < state.n_steps:
            raise ValueError(f"Invalid step position: {position}")
        state = state.remove_step(position)
        return state.model_dump(), not pipeline_change

    def _move_step(
        self, state_dict: td.PipelineStateDict, position: int, direction: Literal["up", "down"]
    ) -> tuple[td.PipelineStateDict, bool]:
        """Move a step within the UI ordering.

        Returns:
            The dumped state and whether ``ui_state.step_ids`` differs from
            its order before the move.
        """
        state = self.make_pipeline_state(state_dict)
        # Snapshot first: move_item is applied to the live step_ids list.
        old_order = list(state.ui_state.step_ids)
        utils.move_item(state.ui_state.step_ids, position, direction)
        return state.model_dump(), old_order != list(state.ui_state.step_ids)

    def move_up(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
    ) -> tuple[td.PipelineStateDict, bool]:
        """Move a step up in the pipeline.

        Returns:
            The updated state dictionary and ``pipeline_change``, negated only
            when the step order actually changed.
        """
        new_state_dict, change = self._move_step(state_dict, position, "up")
        if change:
            pipeline_change = not pipeline_change
        return new_state_dict, pipeline_change

    def move_down(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
    ) -> tuple[td.PipelineStateDict, bool]:
        """Move a step down in the pipeline.

        Returns:
            The updated state dictionary and ``pipeline_change``, negated only
            when the step order actually changed.
        """
        new_state_dict, change = self._move_step(state_dict, position, "down")
        if change:
            pipeline_change = not pipeline_change
        return new_state_dict, pipeline_change

    def update_model_step_state(
        self, state_dict: td.PipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
    ) -> td.PipelineStateDict:
        """Update a particular model step in the pipeline."""
        state = self.make_pipeline_state(state_dict)
        state = state.update_step(model_step, ui_state)
        return state.model_dump()

    def update_output_variables(
        self, state_dict: td.PipelineStateDict, target: str, produced_variable: str
    ) -> td.PipelineStateDict:
        """Update the output variables for a step.

        Args:
            state_dict: Serialized pipeline state.
            target: Key in ``workflow.outputs`` to assign.
            produced_variable: Value to store; ``UNSELECTED_VAR_NAME`` is
                stored as ``None``.

        Returns:
            The updated state dictionary.
        """
        selected = None if produced_variable == UNSELECTED_VAR_NAME else produced_variable
        state = self.make_pipeline_state(state_dict)
        state.workflow.outputs[target] = selected
        return state.model_dump()

    def update_model_step_ui(
        self, state_dict: td.PipelineStateDict, step_ui: ModelStepUIState, step_id: str
    ) -> td.PipelineStateDict:
        """Store a copy of ``step_ui`` as the UI state of step ``step_id``."""
        state = self.make_pipeline_state(state_dict)
        state.ui_state.steps[step_id] = step_ui.model_copy()
        return state.model_dump()

    def get_all_variables(self, state_dict: td.PipelineStateDict, model_step_id: str | None = None) -> list[str]:
        """Get all variables from all steps.

        NOTE(review): despite the name and the ``list[str]`` annotation, this
        returns the state object built from ``state_dict`` and never reads
        ``model_step_id``. Behaviour is kept as-is because the intended
        accessor on the state is not visible here -- confirm and fix.
        """
        return self.make_pipeline_state(state_dict)

    def parse_yaml_workflow(self, yaml_str: str) -> Workflow:
        """Parse a YAML workflow.

        Raises:
            TypeError: If the YAML document is empty or not a mapping.
        """
        return Workflow(**self._load_workflow_mapping(yaml_str))

    def update_workflow_from_code(self, yaml_str: str) -> td.PipelineStateDict:
        """Update a workflow from a YAML string."""
        workflow = self.parse_yaml_workflow(yaml_str)
        return PipelineState.from_workflow(workflow).model_dump()


class TossupPipelineStateManager(PipelineStateManager):
    """Manages a tossup pipeline state."""

    def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> TossupPipelineState:
        """Make a state from a state dictionary."""
        return TossupPipelineState(**state_dict)

    def parse_yaml_workflow(self, yaml_str: str) -> TossupWorkflow:
        """Parse a YAML tossup workflow.

        Raises:
            TypeError: If the YAML document is empty or not a mapping.
        """
        return TossupWorkflow(**self._load_workflow_mapping(yaml_str))

    def update_workflow_from_code(self, yaml_str: str, change_state: bool) -> tuple[td.PipelineStateDict, bool]:
        """Update a workflow from a YAML string.

        Returns:
            The state dictionary built from the parsed workflow and the
            negated ``change_state``.
        """
        workflow = self.parse_yaml_workflow(yaml_str)
        return TossupPipelineState.from_workflow(workflow).model_dump(), not change_state

    def update_buzzer(
        self,
        state_dict: td.TossupPipelineStateDict,
        confidence_threshold: float,
        method: str,
        tokens_prob: float | None,
    ) -> td.TossupPipelineStateDict:
        """Replace the workflow's buzzer with one built from the given settings.

        Args:
            state_dict: Serialized tossup pipeline state.
            confidence_threshold: Passed through to ``Buzzer``.
            method: Passed through to ``Buzzer``.
            tokens_prob: Used as ``prob_threshold`` when positive; ``None``,
                zero and negative values all become ``None``.

        Returns:
            The updated state dictionary.
        """
        state = self.make_pipeline_state(state_dict)
        prob_threshold = float(tokens_prob) if tokens_prob and tokens_prob > 0 else None
        state.workflow.buzzer = Buzzer(
            method=method, confidence_threshold=confidence_threshold, prob_threshold=prob_threshold
        )
        return state.model_dump()