Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -47,7 +47,7 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
|
|
47 |
target_num_parameters = int(config.num_parameters * (target_size / 100))
|
48 |
|
49 |
# Prune the model
|
50 |
-
pruned_model = merge_kit_prune(llm_model, target_num_parameters
|
51 |
|
52 |
log_messages.append("Model pruned successfully.")
|
53 |
logging.info("Model pruned successfully.")
|
@@ -81,7 +81,7 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
|
|
81 |
return error_message, None, "\n".join(log_messages)
|
82 |
|
83 |
# Merge-kit Pruning Function (adjust as needed)
|
84 |
-
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int
|
85 |
"""Prunes a model using a merge-kit approach.
|
86 |
Args:
|
87 |
model (PreTrainedModel): The model to be pruned.
|
@@ -120,9 +120,9 @@ def create_interface():
|
|
120 |
pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
|
121 |
prune_button = gr.Button("Prune Model")
|
122 |
visualization = gr.Image(label="Model Size Comparison", interactive=False)
|
123 |
-
progress_bar = gr.Progress()
|
124 |
logs_button = gr.Button("Show Logs")
|
125 |
logs_output = gr.Textbox(label="Logs", interactive=False)
|
|
|
126 |
|
127 |
def show_logs():
|
128 |
with open("pruning.log", "r") as log_file:
|
@@ -131,7 +131,11 @@ def create_interface():
|
|
131 |
|
132 |
logs_button.click(fn=show_logs, outputs=logs_output)
|
133 |
|
134 |
-
|
|
|
|
|
|
|
|
|
135 |
|
136 |
text_input = gr.Textbox(label="Input Text")
|
137 |
text_output = gr.Textbox(label="Generated Text")
|
|
|
47 |
target_num_parameters = int(config.num_parameters * (target_size / 100))
|
48 |
|
49 |
# Prune the model
|
50 |
+
pruned_model = merge_kit_prune(llm_model, target_num_parameters)
|
51 |
|
52 |
log_messages.append("Model pruned successfully.")
|
53 |
logging.info("Model pruned successfully.")
|
|
|
81 |
return error_message, None, "\n".join(log_messages)
|
82 |
|
83 |
# Merge-kit Pruning Function (adjust as needed)
|
84 |
+
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
|
85 |
"""Prunes a model using a merge-kit approach.
|
86 |
Args:
|
87 |
model (PreTrainedModel): The model to be pruned.
|
|
|
120 |
pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
|
121 |
prune_button = gr.Button("Prune Model")
|
122 |
visualization = gr.Image(label="Model Size Comparison", interactive=False)
|
|
|
123 |
logs_button = gr.Button("Show Logs")
|
124 |
logs_output = gr.Textbox(label="Logs", interactive=False)
|
125 |
+
progress_bar = gr.Progress()
|
126 |
|
127 |
def show_logs():
|
128 |
with open("pruning.log", "r") as log_file:
|
|
|
131 |
|
132 |
logs_button.click(fn=show_logs, outputs=logs_output)
|
133 |
|
134 |
+
def prune_model_with_progress(llm_model_name, target_size, hf_write_token, repo_name):
|
135 |
+
with progress_bar:
|
136 |
+
return prune_model(llm_model_name, target_size, hf_write_token, repo_name)
|
137 |
+
|
138 |
+
prune_button.click(fn=prune_model_with_progress, inputs=[llm_model_name, target_size, hf_write_token, repo_name], outputs=[pruning_status, visualization, logs_output])
|
139 |
|
140 |
text_input = gr.Textbox(label="Input Text")
|
141 |
text_output = gr.Textbox(label="Generated Text")
|