Tech-Meld committed on
Commit dd6c56a
1 Parent(s): 9a6ccf8

Main app file

Files changed (1)
  1. app.py +142 -0
app.py ADDED
@@ -0,0 +1,142 @@
+ import gradio as gr
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
+ from transformers.modeling_utils import PreTrainedModel
+ from huggingface_hub import list_models
+ import matplotlib
+ matplotlib.use("Agg")  # headless backend for server/Spaces environments
+ import matplotlib.pyplot as plt
+ from io import BytesIO
+ from PIL import Image
+ import torch
+ from torch.nn.utils import prune
+
+ # Function to fetch popular open-weight LLM models from the Hugging Face Hub
+ def fetch_open_weight_models():
+     # list_models returns ModelInfo objects; use the .id attribute, not dict access
+     models = list_models(filter="open-weight", sort="downloads", limit=12)
+     return [model.id for model in models]
+
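+ # Note (assumption, not part of the original commit): "open-weight" is not an
+ # official Hub tag, so the filter above may return nothing. A hand-picked
+ # fallback of small seq2seq checkpoints could be substituted, e.g.:
+ #     return [m.id for m in models] or ["google/flan-t5-small", "google/flan-t5-base"]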
+ # Function to prune a model using the "merge-kit" approach
+ def prune_model(llm_model_name, target_size, output_dir):
+     # Load the LLM model and tokenizer
+     llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
+     llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
+
+     # Calculate the target number of parameters (AutoConfig has no
+     # num_parameters attribute; PreTrainedModel.num_parameters() is the API)
+     original_num_parameters = llm_model.num_parameters()
+     target_num_parameters = int(original_num_parameters * (target_size / 100))
+
+     # Use merge-kit to prune the model
+     pruned_model = merge_kit_prune(llm_model, target_num_parameters)
+
+     # Save the pruned model together with its tokenizer, so it can be
+     # reloaded from output_dir later by generate_text
+     pruned_model.save_pretrained(output_dir)
+     llm_tokenizer.save_pretrained(output_dir)
+
+     # Create a visualization; unstructured pruning zeroes weights rather than
+     # deleting them, so compare non-zero (effective) parameter counts
+     pruned_num_parameters = sum(int((p != 0).sum()) for p in pruned_model.parameters())
+     fig, ax = plt.subplots(figsize=(10, 5))
+     ax.bar(["Original", "Pruned"], [original_num_parameters, pruned_num_parameters])
+     ax.set_ylabel("Number of Parameters")
+     ax.set_title("Model Size Comparison")
+     buf = BytesIO()
+     fig.savefig(buf, format="png")
+     plt.close(fig)
+     buf.seek(0)
+     # Return a PIL image, which gr.Image accepts directly (a base64 data URI
+     # is not a reliable gr.Image output type)
+     return f"Pruned model saved to {output_dir}", Image.open(buf)
+
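+ # Hypothetical direct call, bypassing the UI (model name and path are examples):
+ #     status, image = prune_model("google/flan-t5-small", 50, "./pruned-flan-t5")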
+ # Merge-kit Pruning Function
+ def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
+     """Prunes a model using a merge-kit-style approach.
+
+     Despite the name, this is plain unstructured pruning via torch.nn.utils.prune
+     (not the mergekit library): weights are zeroed in place, not deleted.
+
+     Args:
+         model (PreTrainedModel): The model to be pruned.
+         target_num_parameters (int): The target number of parameters after pruning.
+
+     Returns:
+         PreTrainedModel: The pruned model.
+     """
+     # Calculate the fraction of weights to prune (num_parameters is a method,
+     # not an attribute), clamped to the valid [0, 1] range
+     amount = min(max(1 - (target_num_parameters / model.num_parameters()), 0.0), 1.0)
+
+     # Prune each Linear/Conv2d layer with random unstructured pruning
+     for name, module in model.named_modules():
+         if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
+             prune.random_unstructured(module, name="weight", amount=amount)
+
+     # Make the pruning permanent by removing the re-parametrization masks
+     for name, module in model.named_modules():
+         if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
+             prune.remove(module, name="weight")
+
+     return model
+
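+ # A minimal sanity check (an addition, not part of the original commit): since
+ # unstructured pruning only zeroes weights, this hypothetical helper reports
+ # the resulting sparsity rather than a reduced parameter count.
+ def measure_sparsity(model: PreTrainedModel) -> float:
+     """Returns the fraction of exactly-zero weights in Linear/Conv2d layers."""
+     zeros, total = 0, 0
+     for module in model.modules():
+         if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
+             zeros += int((module.weight == 0).sum())
+             total += module.weight.numel()
+     return zeros / max(total, 1)
+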
+ # Function to create a Gradio interface
+ def create_interface():
+     with gr.Blocks() as demo:
+         gr.Markdown("## Create a Smaller LLM")
+
+         # Fetch open-weight models from Hugging Face
+         available_models = gr.Dropdown(
+             label="Choose a Large Language Model",
+             choices=fetch_open_weight_models(),
+             interactive=True,
+         )
+
+         # Input for target model size
+         target_size = gr.Slider(
+             label="Target Model Size (%)",
+             minimum=1,
+             maximum=100,
+             step=1,
+             value=50,
+             interactive=True,
+         )
+
+         # Output for pruning status
+         pruning_status = gr.Textbox(label="Pruning Status")
+
+         # Output for saving the model
+         save_model_path = gr.Textbox(label="Save Model Path", placeholder="Path to save the pruned model", interactive=True)
+
+         # Button to start pruning
+         prune_button = gr.Button("Prune Model")
+
+         # Output for visualization
+         visualization = gr.Image(label="Model Size Comparison")
+
+         # Connect components
+         prune_button.click(
+             fn=prune_model,
+             inputs=[available_models, target_size, save_model_path],
+             outputs=[pruning_status, visualization],
+         )
+
+         # Example usage of the pruned model (optional)
+         text_input = gr.Textbox(label="Input Text")
+         text_output = gr.Textbox(label="Generated Text")
+
+         # Generate text button
+         generate_button = gr.Button("Generate Text")
+
+         def generate_text(text, model_path):
+             # Load the pruned model and tokenizer
+             tokenizer = AutoTokenizer.from_pretrained(model_path)
+             model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+
+             # Seq2seq models use the "text2text-generation" pipeline task
+             # ("text-generation" is for causal LMs and would fail here)
+             generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
+             generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]["generated_text"]
+             return generated_text
+
+         generate_button.click(fn=generate_text, inputs=[text_input, save_model_path], outputs=text_output)
+
+     return demo
+
+ # Create and launch the Gradio interface
+ demo = create_interface()
+ demo.launch(share=True)
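Note: this commit adds only app.py, so dependencies are assumed rather than pinned; running the app locally would first require the imported packages, e.g. pip install gradio transformers torch matplotlib pillow huggingface_hub.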