Tech-Meld committed
Commit 3cfae8b · verified · 1 parent: 8bb39cb

Update app.py

Files changed (1)
  1. app.py +16 -41
app.py CHANGED
@@ -10,32 +10,37 @@ from io import BytesIO
import base64
import torch
from torch.nn.utils import prune
+import subprocess

# Function to fetch open-weight LLM models
def fetch_open_weight_models():
    models = list_models()
    return models

+# Ensure sentencepiece is installed
+try:
+    import sentencepiece
+except ImportError:
+    subprocess.check_call(["pip", "install", "sentencepiece"])
+
# Function to prune a model using the "merge-kit" approach
def prune_model(llm_model_name, target_size, hf_write_token, repo_name):
    try:
        # Load the LLM model and tokenizer
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
-        # Handle cases where the model is split into multiple safetensors
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
-            torch_dtype=torch.float16,  # Adjust dtype as needed
+            torch_dtype=torch.float16,
        )

        # Get the model config
        config = AutoConfig.from_pretrained(llm_model_name)
-        # Calculate the target number of parameters
        target_num_parameters = int(config.num_parameters * (target_size / 100))

-        # Use merge-kit to prune the model
+        # Prune the model
        pruned_model = merge_kit_prune(llm_model, target_num_parameters)

-        # Save the pruned model to Hugging Face repository
+        # Save the pruned model
        api = HfApi()
        repo_id = f"{hf_write_token}/{repo_name}"
        create_repo(repo_id, token=hf_write_token, private=False, exist_ok=True)
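The hunk above introduces a runtime install guard for sentencepiece. As an aside, a common variant of the same guard (purely an illustrative sketch, not part of this commit) probes for the package with importlib and invokes pip through the interpreter that is actually running the app:

import importlib.util
import subprocess
import sys

# Illustrative variant of the install guard added above: check for the package
# without importing it, and call pip via the interpreter running this app.
if importlib.util.find_spec("sentencepiece") is None:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "sentencepiece"])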
@@ -51,10 +56,11 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name):
        fig.savefig(buf, format="png")
        buf.seek(0)
        image_base64 = base64.b64encode(buf.read()).decode("utf-8")
-        return f"Pruned model saved to Hugging Face Hub in repository {repo_id}", f"data:image/png;base64,{image_base64}"
+
+        return f"Pruned model saved to Hugging Face Hub in repository {repo_id}", f"data:image/png;base64,{image_base64}", None

    except Exception as e:
-        return f"Error: {e}", None
+        return f"Error: {e}", None, None

# Merge-kit Pruning Function (adjust as needed)
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
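The hunk above ends at the signature of merge_kit_prune, whose body lies outside the lines touched by this commit. For a concrete picture of what a function with that signature could do, here is a minimal sketch built on the torch.nn.utils.prune import at the top of the file: global L1 magnitude pruning over the model's Linear layers. It is an illustration under that assumption, not the repository's implementation, and note that unstructured pruning only zeroes weights; it does not by itself shrink the saved checkpoint.

import torch
from torch.nn.utils import prune
from transformers import PreTrainedModel

def magnitude_prune_sketch(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
    # Candidate tensors: the weight matrix of every Linear layer in the model.
    targets = [(m, "weight") for m in model.modules() if isinstance(m, torch.nn.Linear)]

    # Fraction of those weights to zero so that roughly target_num_parameters survive.
    total = sum(m.weight.numel() for m, _ in targets)
    amount = min(1.0, max(0.0, 1.0 - target_num_parameters / total))

    # Zero the smallest-magnitude weights globally across all selected layers.
    prune.global_unstructured(targets, pruning_method=prune.L1Unstructured, amount=amount)

    # Bake the masks into the weights so the model saves and loads normally.
    for module, name in targets:
        prune.remove(module, name)
    return model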
@@ -88,55 +94,24 @@ def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Create a Smaller LLM")

-        # Input for model name
        llm_model_name = gr.Textbox(label="Choose a Large Language Model", placeholder="Enter the model name", interactive=True)
-
-        # Input for target model size
-        target_size = gr.Slider(
-            label="Target Model Size (%)",
-            minimum=1,
-            maximum=100,
-            step=1,
-            value=50,
-            interactive=True,
-        )
-
-        # Input for Hugging Face write token
+        target_size = gr.Slider(label="Target Model Size (%)", minimum=1, maximum=100, step=1, value=50, interactive=True)
        hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password")
-
-        # Input for repository name
        repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True)
-
-        # Output for pruning status
        pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
-
-        # Button to start pruning
        prune_button = gr.Button("Prune Model")
-
-        # Output for visualization
        visualization = gr.Image(label="Model Size Comparison", interactive=False)

-        # Connect components
-        prune_button.click(
-            fn=prune_model,
-            inputs=[llm_model_name, target_size, hf_write_token, repo_name],
-            outputs=[pruning_status, visualization],
-        )
+        prune_button.click(fn=prune_model, inputs=[llm_model_name, target_size, hf_write_token, repo_name], outputs=[pruning_status, visualization])

-        # Example usage of the pruned model (optional)
        text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")
-
-        # Generate text button
        generate_button = gr.Button("Generate Text")

        def generate_text(text, repo_name):
            try:
-                # Load the pruned model and tokenizer
                tokenizer = AutoTokenizer.from_pretrained(repo_name, use_auth_token=hf_write_token)
                model = AutoModelForCausalLM.from_pretrained(repo_name, use_auth_token=hf_write_token)
-
-                # Use the pipeline for text generation
                generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
                generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]["generated_text"]
                return generated_text
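One detail worth keeping in mind when reading the click wiring above: Gradio maps a handler's return values positionally onto the components listed in outputs, and the two counts must match, so the three-element returns introduced in the earlier hunk would need a third output component (or the trailing None dropped) to satisfy that contract. A minimal self-contained sketch of the pattern, with purely illustrative component and function names:

import gradio as gr

# Illustrative only: the handler returns two values, so `outputs` lists two components.
def handler(name):
    status = f"processed {name}"
    length = len(name)
    return status, length

with gr.Blocks() as demo:
    name_box = gr.Textbox(label="Name")
    status_box = gr.Textbox(label="Status")
    length_box = gr.Number(label="Length")
    run_button = gr.Button("Run")
    run_button.click(fn=handler, inputs=[name_box], outputs=[status_box, length_box])

if __name__ == "__main__":
    demo.launch()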
@@ -149,4 +124,4 @@ def create_interface():

# Create and launch the Gradio interface
demo = create_interface()
-demo.launch(share=True)
+demo.launch(share=True)