Tech-Meld commited on
Commit
ae62561
1 Parent(s): 0445e3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -47,7 +47,7 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
47
  target_num_parameters = int(config.num_parameters * (target_size / 100))
48
 
49
  # Prune the model
50
- pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)
51
 
52
  log_messages.append("Model pruned successfully.")
53
  logging.info("Model pruned successfully.")
@@ -81,7 +81,7 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
81
  return error_message, None, "\n".join(log_messages)
82
 
83
  # Merge-kit Pruning Function (adjust as needed)
84
- def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress) -> PreTrainedModel:
85
  """Prunes a model using a merge-kit approach.
86
  Args:
87
  model (PreTrainedModel): The model to be pruned.
@@ -120,9 +120,9 @@ def create_interface():
120
  pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
121
  prune_button = gr.Button("Prune Model")
122
  visualization = gr.Image(label="Model Size Comparison", interactive=False)
123
- progress_bar = gr.Progress()
124
  logs_button = gr.Button("Show Logs")
125
  logs_output = gr.Textbox(label="Logs", interactive=False)
 
126
 
127
  def show_logs():
128
  with open("pruning.log", "r") as log_file:
@@ -131,7 +131,11 @@ def create_interface():
131
 
132
  logs_button.click(fn=show_logs, outputs=logs_output)
133
 
134
- prune_button.click(fn=prune_model, inputs=[llm_model_name, target_size, hf_write_token, repo_name, progress_bar], outputs=[pruning_status, visualization, logs_output])
 
 
 
 
135
 
136
  text_input = gr.Textbox(label="Input Text")
137
  text_output = gr.Textbox(label="Generated Text")
 
47
  target_num_parameters = int(config.num_parameters * (target_size / 100))
48
 
49
  # Prune the model
50
+ pruned_model = merge_kit_prune(llm_model, target_num_parameters)
51
 
52
  log_messages.append("Model pruned successfully.")
53
  logging.info("Model pruned successfully.")
 
81
  return error_message, None, "\n".join(log_messages)
82
 
83
  # Merge-kit Pruning Function (adjust as needed)
84
+ def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
85
  """Prunes a model using a merge-kit approach.
86
  Args:
87
  model (PreTrainedModel): The model to be pruned.
 
120
  pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
121
  prune_button = gr.Button("Prune Model")
122
  visualization = gr.Image(label="Model Size Comparison", interactive=False)
 
123
  logs_button = gr.Button("Show Logs")
124
  logs_output = gr.Textbox(label="Logs", interactive=False)
125
+ progress_bar = gr.Progress()
126
 
127
  def show_logs():
128
  with open("pruning.log", "r") as log_file:
 
131
 
132
  logs_button.click(fn=show_logs, outputs=logs_output)
133
 
134
+ def prune_model_with_progress(llm_model_name, target_size, hf_write_token, repo_name):
135
+ with progress_bar:
136
+ return prune_model(llm_model_name, target_size, hf_write_token, repo_name)
137
+
138
+ prune_button.click(fn=prune_model_with_progress, inputs=[llm_model_name, target_size, hf_write_token, repo_name], outputs=[pruning_status, visualization, logs_output])
139
 
140
  text_input = gr.Textbox(label="Input Text")
141
  text_output = gr.Textbox(label="Generated Text")