Tech-Meld committed on
Commit 5796c7a
1 Parent(s): 90e98a6

Update app.py

Files changed (1):
  1. app.py +49 -46
app.py CHANGED
@@ -1,49 +1,53 @@
 import gradio as gr
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline, AutoConfig
 from huggingface_hub import cached_download, hf_hub_url, list_models
-from transformers.modeling_utils import PreTrainedModel
 import requests
 import json
 import os
 import matplotlib.pyplot as plt
 from io import BytesIO
 import base64
-
-# Choose your backend (PyTorch, TensorFlow, or Flax)
-import torch  # If using PyTorch
+from transformers.models.auto import AutoModel
+from transformers.modeling_utils import PreTrainedModel
+import torch
+from torch.nn.utils import prune
 
 # Function to fetch open-weight LLM models
 def fetch_open_weight_models():
-    models = list_models(filter="open-weight", sort="downloads", limit=10)
-    return [model["id"] for model in models]
+    models = list_models()
+    return models
 
 # Function to prune a model using the "merge-kit" approach
 def prune_model(llm_model_name, target_size, output_dir):
-    # Load the LLM model and tokenizer
-    llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
-    llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
-
-    # Get the model config
-    config = AutoConfig.from_pretrained(llm_model_name)
-    # Calculate the target number of parameters
-    target_num_parameters = int(config.num_parameters * (target_size / 100))
-
-    # Use merge-kit to prune the model
-    pruned_model = merge_kit_prune(llm_model, target_num_parameters)
-
-    # Save the pruned model
-    pruned_model.save_pretrained(output_dir)
-
-    # Create a visualization
-    fig, ax = plt.subplots(figsize=(10, 5))
-    ax.bar(["Original", "Pruned"], [config.num_parameters, pruned_model.num_parameters])
-    ax.set_ylabel("Number of Parameters")
-    ax.set_title("Model Size Comparison")
-    buf = BytesIO()
-    fig.savefig(buf, format="png")
-    buf.seek(0)
-    image_base64 = base64.b64encode(buf.read()).decode("utf-8")
-    return f"Pruned model saved to {output_dir}", f"data:image/png;base64,{image_base64}"
+    try:
+        # Load the LLM model and tokenizer
+        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
+        llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
+
+        # Get the model config
+        config = AutoConfig.from_pretrained(llm_model_name)
+        # Calculate the target number of parameters
+        target_num_parameters = int(config.num_parameters * (target_size / 100))
+
+        # Use merge-kit to prune the model
+        pruned_model = merge_kit_prune(llm_model, target_num_parameters)
+
+        # Save the pruned model
+        pruned_model.save_pretrained(output_dir)
+
+        # Create a visualization
+        fig, ax = plt.subplots(figsize=(10, 5))
+        ax.bar(["Original", "Pruned"], [config.num_parameters, pruned_model.num_parameters])
+        ax.set_ylabel("Number of Parameters")
+        ax.set_title("Model Size Comparison")
+        buf = BytesIO()
+        fig.savefig(buf, format="png")
+        buf.seek(0)
+        image_base64 = base64.b64encode(buf.read()).decode("utf-8")
+        return f"Pruned model saved to {output_dir}", f"data:image/png;base64,{image_base64}"
+
+    except Exception as e:
+        return f"Error: {e}", None
 
 # Merge-kit Pruning Function
 def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
@@ -80,12 +84,8 @@ def create_interface():
     with gr.Blocks() as demo:
         gr.Markdown("## Create a Smaller LLM")
 
-        # Fetch open-weight models from Hugging Face
-        available_models = gr.Dropdown(
-            label="Choose a Large Language Model",
-            choices=fetch_open_weight_models(),
-            interactive=True,
-        )
+        # Input for model name
+        llm_model_name = gr.Textbox(label="Choose a Large Language Model", placeholder="Enter the model name", interactive=True)
 
         # Input for target model size
         target_size = gr.Slider(
@@ -112,7 +112,7 @@ def create_interface():
         # Connect components
         prune_button.click(
             fn=prune_model,
-            inputs=[available_models, target_size, save_model_path],
+            inputs=[llm_model_name, target_size, save_model_path],
             outputs=[pruning_status, visualization],
         )
 
@@ -124,14 +124,17 @@
         generate_button = gr.Button("Generate Text")
 
        def generate_text(text, model_path):
-            # Load the pruned model and tokenizer
-            tokenizer = AutoTokenizer.from_pretrained(model_path)
-            model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
-
-            # Use the pipeline for text generation
-            generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
-            generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]["generated_text"]
-            return generated_text
+            try:
+                # Load the pruned model and tokenizer
+                tokenizer = AutoTokenizer.from_pretrained(model_path)
+                model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+
+                # Use the pipeline for text generation
+                generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
+                generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]["generated_text"]
+                return generated_text
+            except Exception as e:
+                return f"Error: {e}"
 
        generate_button.click(fn=generate_text, inputs=[text_input, save_model_path], outputs=text_output)
 
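
The body of `merge_kit_prune` falls outside the hunks above, so only its signature is visible. Since this commit adds `from torch.nn.utils import prune`, one plausible reading is plain magnitude pruning; the sketch below is an assumption about that body, not code from this repo, and the `amount` heuristic is invented for illustration.

```python
import torch
from torch.nn.utils import prune
from transformers.modeling_utils import PreTrainedModel

# Hypothetical body for merge_kit_prune, assuming global L1 magnitude
# pruning over the model's Linear layers (not the actual Space code).
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
    total = sum(p.numel() for p in model.parameters())
    # Fraction of weights to zero so roughly target_num_parameters survive
    amount = max(0.0, min(1.0, 1.0 - target_num_parameters / total))

    # Collect (module, parameter-name) pairs for every Linear weight
    targets = [
        (m, "weight") for m in model.modules() if isinstance(m, torch.nn.Linear)
    ]
    prune.global_unstructured(targets, pruning_method=prune.L1Unstructured, amount=amount)

    # Fold the masks back into the weights so save_pretrained() writes plain tensors
    for module, name in targets:
        prune.remove(module, name)
    return model
```

Two caveats if that reading is right: unstructured pruning only zeroes weights, so the saved checkpoint keeps the original parameter count and the "Model Size Comparison" chart would not actually shrink; and `AutoConfig` objects have no `num_parameters` attribute, so the counts used in `prune_model` would more reliably come from `llm_model.num_parameters()` on the loaded model.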