Tech-Meld committed
Commit 8a2d207
1 Parent(s): ae62561

Update app.py

Files changed (1)
  1. app.py +6 -4
app.py CHANGED
@@ -13,6 +13,7 @@ from torch.nn.utils import prune
 import subprocess
 from tqdm import tqdm
 import logging
+import sys

 # Setup logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -47,7 +48,7 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
     target_num_parameters = int(config.num_parameters * (target_size / 100))

     # Prune the model
-    pruned_model = merge_kit_prune(llm_model, target_num_parameters)
+    pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)

     log_messages.append("Model pruned successfully.")
     logging.info("Model pruned successfully.")
@@ -81,7 +82,7 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
         return error_message, None, "\n".join(log_messages)

 # Merge-kit Pruning Function (adjust as needed)
-def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
+def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress: gr.Progress) -> PreTrainedModel:
     """Prunes a model using a merge-kit approach.
     Args:
         model (PreTrainedModel): The model to be pruned.
@@ -100,11 +101,13 @@ def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
     for name, module in tqdm(model.named_modules(), desc="Pruning", file=sys.stdout):
         if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
             prune.random_unstructured(module, name="weight", amount=amount)
+    progress(percent_complete=50)  # Example progress update

     # Remove the pruned weights
     for name, module in model.named_modules():
         if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
             prune.remove(module, name="weight")
+    progress(percent_complete=100)  # Example progress update

     return model

@@ -132,8 +135,7 @@ def create_interface():
     logs_button.click(fn=show_logs, outputs=logs_output)

     def prune_model_with_progress(llm_model_name, target_size, hf_write_token, repo_name):
-        with progress_bar:
-            return prune_model(llm_model_name, target_size, hf_write_token, repo_name)
+        return prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress_bar)

     prune_button.click(fn=prune_model_with_progress, inputs=[llm_model_name, target_size, hf_write_token, repo_name], outputs=[pruning_status, visualization, logs_output])

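
A note on the progress calls introduced in this change: in Gradio's documented API a gr.Progress instance is called with a float fraction between 0 and 1 (optionally with desc=), so progress(percent_complete=50) may be rejected at runtime. The sketch below is hypothetical, not the app.py implementation: it shows the same two-phase pruning loop reporting fractional per-module progress; merge_kit_prune_sketch and its amount argument are illustrative names.

# Hypothetical sketch (assumption, not app.py): two-phase unstructured pruning
# with per-module progress reporting via Gradio's callable gr.Progress.
import gradio as gr
import torch
from torch.nn.utils import prune


def merge_kit_prune_sketch(model: torch.nn.Module, amount: float, progress: gr.Progress) -> torch.nn.Module:
    prunable = [m for m in model.modules() if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d))]
    n = max(len(prunable), 1)

    # Phase 1 (0% -> 50%): attach random unstructured pruning masks.
    for i, module in enumerate(prunable, start=1):
        prune.random_unstructured(module, name="weight", amount=amount)
        progress(0.5 * i / n, desc="Applying pruning masks")

    # Phase 2 (50% -> 100%): bake the masks into the weights permanently.
    for i, module in enumerate(prunable, start=1):
        prune.remove(module, name="weight")
        progress(0.5 + 0.5 * i / n, desc="Removing reparametrization")

    return model

In recent Gradio versions, progress.tqdm(iterable, desc=...) can also wrap the loop directly, updating the UI the same way tqdm updates the console.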
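The diff references target_num_parameters and an amount fraction but does not show how one is derived from the other. One plausible derivation (an assumption, not taken from app.py) counts the model's weights and prunes the surplus fraction. Two caveats: unstructured pruning only zeroes weights behind a mask, so the saved checkpoint keeps the same parameter count unless the sparsity is later exploited; and config.num_parameters may need to come from model.num_parameters(), since transformers configs do not generally expose a parameter count.

# Hypothetical helper (assumption, not from app.py): turn a parameter budget
# into the `amount` fraction expected by torch.nn.utils.prune.
import torch


def pruning_amount(model: torch.nn.Module, target_num_parameters: int) -> float:
    total = sum(p.numel() for p in model.parameters())
    # Fraction of weights to zero so that roughly `target_num_parameters`
    # non-zero weights remain.
    amount = 1.0 - (target_num_parameters / total)
    return min(max(amount, 0.0), 1.0)  # clamp to the [0, 1] range prune.* expects


# Example: keep roughly 50% of a toy model's weights.
toy = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
half = sum(p.numel() for p in toy.parameters()) // 2
print(pruning_amount(toy, target_num_parameters=half))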
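Finally, on the progress_bar object threaded through prune_model_with_progress: the usual Gradio pattern is to declare progress=gr.Progress() as a default argument of the event handler, which Gradio replaces with a live tracker at call time, so passing a separate progress-bar component is typically unnecessary. A self-contained sketch of that wiring follows; component names mirror the diff, prune_model is stubbed, and everything else is assumed.

# Hypothetical wiring sketch (assumption, not app.py): Gradio injects a live
# gr.Progress when it appears as a default argument of the event handler.
import gradio as gr


def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress):
    # Stand-in for the real prune_model in app.py.
    progress(0.0, desc="Starting")
    progress(1.0, desc="Done")
    return f"Pruned {llm_model_name} to {target_size}%", "no logs"


def prune_model_with_progress(llm_model_name, target_size, hf_write_token, repo_name,
                              progress=gr.Progress()):
    # `progress` is bound to this event by Gradio; pass it down the call chain.
    return prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress)


with gr.Blocks() as demo:
    llm_model_name = gr.Textbox(label="Model name")
    target_size = gr.Slider(1, 100, value=50, label="Target size (%)")
    hf_write_token = gr.Textbox(label="HF write token", type="password")
    repo_name = gr.Textbox(label="Target repo")
    prune_button = gr.Button("Prune")
    pruning_status = gr.Textbox(label="Status")
    logs_output = gr.Textbox(label="Logs")
    prune_button.click(
        fn=prune_model_with_progress,
        inputs=[llm_model_name, target_size, hf_write_token, repo_name],
        outputs=[pruning_status, logs_output],
    )

if __name__ == "__main__":
    demo.launch()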