Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -13,6 +13,7 @@ from torch.nn.utils import prune
|
|
13 |
import subprocess
|
14 |
from tqdm import tqdm
|
15 |
import logging
|
|
|
16 |
|
17 |
# Setup logging
|
18 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
@@ -47,7 +48,7 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
|
|
47 |
target_num_parameters = int(config.num_parameters * (target_size / 100))
|
48 |
|
49 |
# Prune the model
|
50 |
-
pruned_model = merge_kit_prune(llm_model, target_num_parameters)
|
51 |
|
52 |
log_messages.append("Model pruned successfully.")
|
53 |
logging.info("Model pruned successfully.")
|
@@ -81,7 +82,7 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
|
|
81 |
return error_message, None, "\n".join(log_messages)
|
82 |
|
83 |
# Merge-kit Pruning Function (adjust as needed)
|
84 |
-
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
|
85 |
"""Prunes a model using a merge-kit approach.
|
86 |
Args:
|
87 |
model (PreTrainedModel): The model to be pruned.
|
@@ -100,11 +101,13 @@ def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTr
|
|
100 |
for name, module in tqdm(model.named_modules(), desc="Pruning", file=sys.stdout):
|
101 |
if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
|
102 |
prune.random_unstructured(module, name="weight", amount=amount)
|
|
|
103 |
|
104 |
# Remove the pruned weights
|
105 |
for name, module in model.named_modules():
|
106 |
if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
|
107 |
prune.remove(module, name="weight")
|
|
|
108 |
|
109 |
return model
|
110 |
|
@@ -132,8 +135,7 @@ def create_interface():
|
|
132 |
logs_button.click(fn=show_logs, outputs=logs_output)
|
133 |
|
134 |
def prune_model_with_progress(llm_model_name, target_size, hf_write_token, repo_name):
|
135 |
-
|
136 |
-
return prune_model(llm_model_name, target_size, hf_write_token, repo_name)
|
137 |
|
138 |
prune_button.click(fn=prune_model_with_progress, inputs=[llm_model_name, target_size, hf_write_token, repo_name], outputs=[pruning_status, visualization, logs_output])
|
139 |
|
|
|
13 |
import subprocess
|
14 |
from tqdm import tqdm
|
15 |
import logging
|
16 |
+
import sys
|
17 |
|
18 |
# Setup logging
|
19 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
48 |
target_num_parameters = int(config.num_parameters * (target_size / 100))
|
49 |
|
50 |
# Prune the model
|
51 |
+
pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)
|
52 |
|
53 |
log_messages.append("Model pruned successfully.")
|
54 |
logging.info("Model pruned successfully.")
|
|
|
82 |
return error_message, None, "\n".join(log_messages)
|
83 |
|
84 |
# Merge-kit Pruning Function (adjust as needed)
|
85 |
+
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress: gr.Progress) -> PreTrainedModel:
|
86 |
"""Prunes a model using a merge-kit approach.
|
87 |
Args:
|
88 |
model (PreTrainedModel): The model to be pruned.
|
|
|
101 |
for name, module in tqdm(model.named_modules(), desc="Pruning", file=sys.stdout):
|
102 |
if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
|
103 |
prune.random_unstructured(module, name="weight", amount=amount)
|
104 |
+
progress(percent_complete=50) # Example progress update
|
105 |
|
106 |
# Remove the pruned weights
|
107 |
for name, module in model.named_modules():
|
108 |
if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
|
109 |
prune.remove(module, name="weight")
|
110 |
+
progress(percent_complete=100) # Example progress update
|
111 |
|
112 |
return model
|
113 |
|
|
|
135 |
logs_button.click(fn=show_logs, outputs=logs_output)
|
136 |
|
137 |
def prune_model_with_progress(llm_model_name, target_size, hf_write_token, repo_name):
|
138 |
+
return prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress_bar)
|
|
|
139 |
|
140 |
prune_button.click(fn=prune_model_with_progress, inputs=[llm_model_name, target_size, hf_write_token, repo_name], outputs=[pruning_status, visualization, logs_output])
|
141 |
|