Maharshi Gor commited on
Commit
e00ec4e
·
1 Parent(s): 3b39b49

BugFix: Step Creation and removal.

Browse files

* disallowed removing when there is only step
* reliable new step id creation

src/components/model_pipeline/model_pipeline.py CHANGED
@@ -1,7 +1,6 @@
1
- import json
2
-
3
  import gradio as gr
4
  import yaml
 
5
 
6
  from components.model_pipeline.state_manager import (
7
  ModelStepUIState,
@@ -89,6 +88,7 @@ class PipelineInterface:
89
  step_ui_state: ModelStepUIState,
90
  available_variables: list[str],
91
  position: int = 0,
 
92
  ):
93
  with gr.Column(elem_classes="step-container"):
94
  # Create the step component
@@ -115,11 +115,14 @@ class PipelineInterface:
115
  if self.simple:
116
  return step_interface
117
 
 
 
 
118
  # Add step controls below
119
- with gr.Row(elem_classes="step-controls"):
120
- up_button = gr.Button("⬆️ Move Up", elem_classes="step-control-btn")
121
- down_button = gr.Button("⬇️ Move Down", elem_classes="step-control-btn")
122
- remove_button = gr.Button("🗑️ Remove", elem_classes="step-control-btn")
123
 
124
  buttons = (up_button, down_button, remove_button)
125
  self._assign_step_controls(buttons, position)
@@ -232,10 +235,10 @@ class PipelineInterface:
232
 
233
  # Function to render all steps
234
  @gr.render(inputs=[self.pipeline_state, self.ui_state])
235
- def render_steps(state, ui_state):
236
  """Render all steps in the pipeline"""
 
237
  workflow = state.workflow
238
- print(f"\nRerender triggered! Current UI State:{ui_state}")
239
  components = []
240
 
241
  step_objects = [] # Reset step objects list
@@ -243,7 +246,7 @@ class PipelineInterface:
243
  step_data = workflow.steps[step_id]
244
  step_ui_state = ui_state.steps[step_id]
245
  available_variables = self.sm.get_all_variables(state, step_id)
246
- sub_components = self._render_step(step_data, step_ui_state, available_variables, i)
247
  step_objects.append(sub_components)
248
 
249
  components.append(step_objects)
 
 
 
1
  import gradio as gr
2
  import yaml
3
+ from loguru import logger
4
 
5
  from components.model_pipeline.state_manager import (
6
  ModelStepUIState,
 
88
  step_ui_state: ModelStepUIState,
89
  available_variables: list[str],
90
  position: int = 0,
91
+ n_steps: int = 1,
92
  ):
93
  with gr.Column(elem_classes="step-container"):
94
  # Create the step component
 
115
  if self.simple:
116
  return step_interface
117
 
118
+ is_multi_step = n_steps > 1
119
+ logger.debug(f"Rendering step {position} of {n_steps}")
120
+
121
  # Add step controls below
122
+ with gr.Row(elem_classes="step-controls", visible=is_multi_step):
123
+ up_button = gr.Button("⬆️ Move Up", elem_classes="step-control-btn", interactive=is_multi_step)
124
+ down_button = gr.Button("⬇️ Move Down", elem_classes="step-control-btn", interactive=is_multi_step)
125
+ remove_button = gr.Button("🗑️ Remove", elem_classes="step-control-btn", interactive=is_multi_step)
126
 
127
  buttons = (up_button, down_button, remove_button)
128
  self._assign_step_controls(buttons, position)
 
235
 
236
  # Function to render all steps
237
  @gr.render(inputs=[self.pipeline_state, self.ui_state])
238
+ def render_steps(state: PipelineState, ui_state: PipelineUIState):
239
  """Render all steps in the pipeline"""
240
+ logger.info(f"\nRerender triggered! Current UI State:{ui_state.model_dump()}")
241
  workflow = state.workflow
 
242
  components = []
243
 
244
  step_objects = [] # Reset step objects list
 
246
  step_data = workflow.steps[step_id]
247
  step_ui_state = ui_state.steps[step_id]
248
  available_variables = self.sm.get_all_variables(state, step_id)
249
+ sub_components = self._render_step(step_data, step_ui_state, available_variables, i, ui_state.n_steps)
250
  step_objects.append(sub_components)
251
 
252
  components.append(step_objects)
src/components/model_pipeline/state_manager.py CHANGED
@@ -3,6 +3,7 @@ from typing import Any, Literal
3
 
4
  import gradio as gr
5
  import yaml
 
6
  from pydantic import BaseModel, Field
7
 
8
  from components import utils
@@ -10,17 +11,25 @@ from workflows.factory import create_new_llm_step
10
  from workflows.structs import ModelStep, Workflow
11
 
12
 
13
- def make_step_id(step_id: int):
14
  """Make a step id from a step name."""
15
- if step_id < 26:
16
- return chr(ord("A") + step_id)
17
  else:
18
  # For more than 26 steps, use AA, AB, AC, etc.
19
- first_char = chr(ord("A") + (step_id // 26) - 1)
20
- second_char = chr(ord("A") + (step_id % 26))
21
  return f"{first_char}{second_char}"
22
 
23
 
 
 
 
 
 
 
 
 
24
  class ModelStepUIState(BaseModel):
25
  """Represents the UI state for a model step component."""
26
 
@@ -48,6 +57,11 @@ class PipelineUIState(BaseModel):
48
  """Get the position of a step in the pipeline."""
49
  return next((i for i, step in enumerate(self.step_ids) if step == step_id), None)
50
 
 
 
 
 
 
51
  @classmethod
52
  def from_workflow(cls, workflow: Workflow):
53
  """Create a pipeline UI state from a workflow."""
@@ -63,7 +77,7 @@ class PipelineState(BaseModel):
63
  workflow: Workflow
64
  ui_state: PipelineUIState
65
 
66
- def insert_step(self, position: int, step: ModelStep):
67
  if step.id in self.workflow.steps:
68
  raise ValueError(f"Step {step.id} already exists in pipeline")
69
 
@@ -81,14 +95,15 @@ class PipelineState(BaseModel):
81
  self.ui_state.step_ids.insert(position, step.id)
82
  return self
83
 
84
- def remove_step(self, position: int):
85
  step_id = self.ui_state.step_ids.pop(position)
86
  self.workflow.steps.pop(step_id)
87
  self.ui_state = self.ui_state.model_copy()
88
  self.ui_state.steps.pop(step_id)
89
  self.update_output_variables_mapping()
 
90
 
91
- def update_output_variables_mapping(self):
92
  available_variables = set(self.available_variables)
93
  for output_field in self.workflow.outputs:
94
  if self.workflow.outputs[output_field] not in available_variables:
@@ -96,13 +111,21 @@ class PipelineState(BaseModel):
96
  return self
97
 
98
  @property
99
- def available_variables(self):
100
  return self.workflow.get_available_variables()
101
 
102
  @property
103
- def n_steps(self):
104
  return len(self.workflow.steps)
105
 
 
 
 
 
 
 
 
 
106
 
107
  class PipelineStateManager:
108
  """Manages a pipeline of multiple steps."""
@@ -120,7 +143,7 @@ class PipelineStateManager:
120
 
121
  def add_step(self, state: PipelineState, position: int = -1, name=""):
122
  """Create a new step and return its state."""
123
- step_id = make_step_id(state.n_steps)
124
  step_name = name or f"Step {state.n_steps + 1}"
125
  new_step = create_new_llm_step(step_id=step_id, name=step_name)
126
  state = state.insert_step(position, new_step)
 
3
 
4
  import gradio as gr
5
  import yaml
6
+ from loguru import logger
7
  from pydantic import BaseModel, Field
8
 
9
  from components import utils
 
11
  from workflows.structs import ModelStep, Workflow
12
 
13
 
14
+ def make_step_id(step_number: int):
15
  """Make a step id from a step name."""
16
+ if step_number < 26:
17
+ return chr(ord("A") + step_number)
18
  else:
19
  # For more than 26 steps, use AA, AB, AC, etc.
20
+ first_char = chr(ord("A") + (step_number // 26) - 1)
21
+ second_char = chr(ord("A") + (step_number % 26))
22
  return f"{first_char}{second_char}"
23
 
24
 
25
+ def make_step_number(step_id: str):
26
+ """Make a step number from a step id."""
27
+ if len(step_id) == 1:
28
+ return ord(step_id) - ord("A")
29
+ else:
30
+ return (ord(step_id[0]) - ord("A")) * 26 + (ord(step_id[1]) - ord("A")) + 1
31
+
32
+
33
  class ModelStepUIState(BaseModel):
34
  """Represents the UI state for a model step component."""
35
 
 
57
  """Get the position of a step in the pipeline."""
58
  return next((i for i, step in enumerate(self.step_ids) if step == step_id), None)
59
 
60
+ @property
61
+ def n_steps(self) -> int:
62
+ """Get the number of steps in the pipeline."""
63
+ return len(self.step_ids)
64
+
65
  @classmethod
66
  def from_workflow(cls, workflow: Workflow):
67
  """Create a pipeline UI state from a workflow."""
 
77
  workflow: Workflow
78
  ui_state: PipelineUIState
79
 
80
+ def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
81
  if step.id in self.workflow.steps:
82
  raise ValueError(f"Step {step.id} already exists in pipeline")
83
 
 
95
  self.ui_state.step_ids.insert(position, step.id)
96
  return self
97
 
98
+ def remove_step(self, position: int) -> "PipelineState":
99
  step_id = self.ui_state.step_ids.pop(position)
100
  self.workflow.steps.pop(step_id)
101
  self.ui_state = self.ui_state.model_copy()
102
  self.ui_state.steps.pop(step_id)
103
  self.update_output_variables_mapping()
104
+ return self
105
 
106
+ def update_output_variables_mapping(self) -> "PipelineState":
107
  available_variables = set(self.available_variables)
108
  for output_field in self.workflow.outputs:
109
  if self.workflow.outputs[output_field] not in available_variables:
 
111
  return self
112
 
113
  @property
114
+ def available_variables(self) -> list[str]:
115
  return self.workflow.get_available_variables()
116
 
117
  @property
118
+ def n_steps(self) -> int:
119
  return len(self.workflow.steps)
120
 
121
+ def get_new_step_id(self) -> str:
122
+ """Get a step ID for a new step."""
123
+ if not self.workflow.steps:
124
+ return "A"
125
+ else:
126
+ last_step_number = max(map(make_step_number, self.workflow.steps.keys()))
127
+ return make_step_id(last_step_number + 1)
128
+
129
 
130
  class PipelineStateManager:
131
  """Manages a pipeline of multiple steps."""
 
143
 
144
  def add_step(self, state: PipelineState, position: int = -1, name=""):
145
  """Create a new step and return its state."""
146
+ step_id = state.get_new_step_id()
147
  step_name = name or f"Step {state.n_steps + 1}"
148
  new_step = create_new_llm_step(step_id=step_id, name=step_name)
149
  state = state.insert_step(position, new_step)
src/workflows/structs.py CHANGED
@@ -193,7 +193,8 @@ class Workflow(BaseModel):
193
 
194
  def model_dump(self, *args, **kwargs):
195
  data = super().model_dump(*args, **kwargs)
196
- data["steps"] = list(data["steps"].values())
 
197
  return data
198
 
199
  @model_validator(mode="before")
 
193
 
194
  def model_dump(self, *args, **kwargs):
195
  data = super().model_dump(*args, **kwargs)
196
+ if "steps" in data:
197
+ data["steps"] = list(data["steps"].values())
198
  return data
199
 
200
  @model_validator(mode="before")