Tech-Meld committed on
Commit 4484172
1 Parent(s): 10fbcfe

Update app.py

Files changed (1)
  1. app.py +81 -66
app.py CHANGED

@@ -2,9 +2,6 @@ import gradio as gr
 from transformers import AutoModel, AutoTokenizer, pipeline, AutoConfig, AutoModelForCausalLM
 from huggingface_hub import create_repo, HfApi, list_models
 from transformers.modeling_utils import PreTrainedModel
-import requests
-import json
-import os
 import matplotlib.pyplot as plt
 from io import BytesIO
 import base64
@@ -26,11 +23,53 @@ except ImportError:
 
 # Function to fetch open-weight LLM models
 def fetch_open_weight_models():
-    models = list_models()
-    return models
+    try:
+        models = list_models()
+        return models
+    except Exception as e:
+        logging.error(f"Error fetching models: {e}")
+        return []
 
-# Function to prune a model using the "merge-kit" approach
-def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress=gr.Progress(track_tqdm=True)):
+# Custom function to retrieve just names from models list
+def get_model_names():
+    models = fetch_open_weight_models()
+    model_names = [model.modelId for model in models if model.modelId is not None]
+    return model_names
+
+# Full merge-kit Pruning Function
+def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress: gr.Progress) -> PreTrainedModel:
+    """Prunes a model using a merge-kit approach.
+    Args:
+        model (PreTrainedModel): The model to be pruned.
+        target_num_parameters (int): The target number of parameters after pruning.
+        progress (gr.Progress): The progress object for visual feedback.
+    Returns:
+        PreTrainedModel: The pruned model.
+    """
+
+    total_params = sum(p.numel() for p in model.parameters())
+    amount = 1 - (target_num_parameters / total_params)
+
+    try:
+        # Prune the model
+        for i, (name, module) in enumerate(tqdm(model.named_modules(), desc="Pruning", file=sys.stdout)):
+            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
+                prune.random_unstructured(module, name="weight", amount=amount)
+            progress(percent_complete=50 * (i + 1) / len(list(model.named_modules())))  # Progress update
+
+        # Remove the pruned weights
+        for i, (name, module) in enumerate(model.named_modules()):
+            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
+                prune.remove(module, name="weight")
+            progress(percent_complete=50 + 50 * (i + 1) / len(list(model.named_modules())))  # Progress update
+
+        return model
+    except Exception as e:
+        logging.error(f"Error during pruning: {e}")
+        raise e
+
+# Function to prune a model
+def prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_model_name=None, progress=gr.Progress(track_tqdm=True)):
     log_messages = []
     try:
         # Load the LLM model and tokenizer
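Note on the new helpers above: calling list_models() with no arguments iterates over essentially the entire Hub, which is slow and far too large to feed a dropdown. Below is a minimal sketch of a bounded query, assuming a recent huggingface_hub client; the task/sort/limit filters and the helper name are illustrative and not part of this commit.

from huggingface_hub import list_models

# Hypothetical bounded variant of get_model_names(); not part of the commit.
def get_popular_text_generation_models(limit=50):
    models = list_models(task="text-generation", sort="downloads", direction=-1, limit=limit)
    return [m.modelId for m in models if m.modelId is not None]

print(get_popular_text_generation_models(5))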
@@ -39,105 +78,80 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
             llm_model_name,
             torch_dtype=torch.float16,
         )
-
+
         log_messages.append('Model and tokenizer loaded successfully.')
         logging.info('Model and tokenizer loaded successfully.')
-
-        # Get the model config
-        config = AutoConfig.from_pretrained(llm_model_name)
-        target_num_parameters = int(config.num_parameters * (target_size / 100))
-
+
+        total_params = sum(p.numel() for p in llm_model.parameters())
+        target_num_parameters = int(total_params * (target_size / 100))
+
         # Prune the model
         pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)
-
+
         log_messages.append('Model pruned successfully.')
         logging.info('Model pruned successfully.')
-
+
         # Save the pruned model
         api = HfApi()
         repo_id = f"{hf_write_token}/{repo_name}"
         create_repo(repo_id, token=hf_write_token, private=False, exist_ok=True)
         pruned_model.push_to_hub(repo_id, use_auth_token=hf_write_token)
         llm_tokenizer.push_to_hub(repo_id, use_auth_token=hf_write_token)
-
+
         log_messages.append(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")
         logging.info(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")
-
+
         # Create a visualization
         fig, ax = plt.subplots(figsize=(10, 5))
-        ax.bar(['Original', 'Pruned'], [config.num_parameters, sum(p.numel() for p in pruned_model.parameters())])
+        ax.bar(['Original', 'Pruned'], [total_params, sum(p.numel() for p in pruned_model.parameters())])
         ax.set_ylabel('Number of Parameters')
         ax.set_title('Model Size Comparison')
         buf = BytesIO()
         fig.savefig(buf, format='png')
         buf.seek(0)
         image_base64 = base64.b64encode(buf.read()).decode('utf-8')
-
+
         return f"Pruned model saved to Hugging Face Hub in repository {repo_id}", f"data:image/png;base64,{image_base64}", '\n'.join(log_messages)
-
+
     except Exception as e:
-        error_message = f"Error: {e}"
+        error_message = f"Detailed error: {repr(e)}"
         log_messages.append(error_message)
         logging.error(error_message)
         return error_message, None, '\n'.join(log_messages)
 
-# Merge-kit Pruning Function (adjust as needed)
-def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress: gr.Progress) -> PreTrainedModel:
-    """Prunes a model using a merge-kit approach.
-    Args:
-        model (PreTrainedModel): The model to be pruned.
-        target_num_parameters (int): The target number of parameters after pruning.
-    Returns:
-        PreTrainedModel: The pruned model.
-    """
-
-    total_params = sum(p.numel() for p in model.parameters())
-    amount = 1 - (target_num_parameters / total_params)
-
-    for name, module in tqdm(model.named_modules(), desc='Pruning', file=sys.stdout):
-        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
-            prune.random_unstructured(module, name='weight', amount=amount)
-    progress(percent_complete=50)
-
-    for name, module in model.named_modules():
-        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
-            prune.remove(module, name='weight')
-    progress(percent_complete=100)
-
-    return model
-
 # Function to create a Gradio interface
 def create_interface():
     with gr.Blocks() as demo:
         gr.Markdown("## Create a Smaller LLM")
-
-        llm_model_name = gr.Textbox(label="Choose a Large Language Model", placeholder="Enter the model name", interactive=True)
+
+        # Fetch available model names
+        model_names = get_model_names()
+
+        # Input components
+        llm_model_name = gr.Dropdown(label="Choose a Large Language Model", choices=model_names, interactive=True)
         target_size = gr.Slider(label="Target Model Size (%)", minimum=1, maximum=100, step=1, value=50, interactive=True)
         hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password")
         repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True)
+        pruned_func_choice = gr.Radio(label="Pruning Function", choices=["merge-kit"], value="merge-kit", interactive=True)
+        base_model_name = gr.Dropdown(label="Base Model Name (if required)", choices=model_names, interactive=True, visible=False)
+
         pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
         prune_button = gr.Button("Prune Model")
         visualization = gr.Image(label="Model Size Comparison", interactive=False)
-        logs_button = gr.Button("Show Logs")
-        logs_output = gr.Textbox(label="Logs", interactive=False)
         progress_bar = gr.Progress()
-
-        def show_logs():
-            with open('pruning.log', 'r') as log_file:
-                logs = log_file.read()
-            return logs
-
-        logs_button.click(fn=show_logs, outputs=logs_output)
-
-        def prune_model_with_progress(llm_model_name, target_size, hf_write_token, repo_name):
-            return prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress_bar)
-
-        prune_button.click(fn=prune_model_with_progress, inputs=[llm_model_name, target_size, hf_write_token, repo_name], outputs=[pruning_status, visualization, logs_output])
-
+
+        def prune_model_with_progress(llm_model_name, target_size, hf_write_token, repo_name, pruned_func_choice, base_model_name):
+            if pruned_func_choice == "merge-kit":
+                return prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_model_name, progress_bar)
+            else:
+                return f"Pruning function '{pruned_func_choice}' not implemented.", None, None
+
+        prune_button.click(fn=prune_model_with_progress, inputs=[llm_model_name, target_size, hf_write_token, repo_name, pruned_func_choice, base_model_name], outputs=[pruning_status, visualization])
+
         text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")
         generate_button = gr.Button("Generate Text")
-
+
         def generate_text(text, repo_name, hf_write_token):
             try:
                 tokenizer = AutoTokenizer.from_pretrained(repo_name, use_auth_token=hf_write_token)
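A caveat on the size-comparison chart in the hunk above: merge_kit_prune relies on torch.nn.utils.prune, which performs unstructured pruning, i.e. it zeroes weights in place without deleting them, so the 'Original' and 'Pruned' bars computed from p.numel() will come out equal. A self-contained sketch of the two primitives used (illustrative, not part of the commit):

import torch
import torch.nn.utils.prune as prune

layer = torch.nn.Linear(256, 256)
params_before = sum(p.numel() for p in layer.parameters())

prune.random_unstructured(layer, name="weight", amount=0.5)  # zero ~50% of the weights
prune.remove(layer, "weight")                                 # make the pruning permanent

params_after = sum(p.numel() for p in layer.parameters())
sparsity = (layer.weight == 0).float().mean().item()

print(params_before == params_after)  # True: parameter count is unchanged
print(f"sparsity: {sparsity:.2f}")    # ~0.50: half the weights are now zero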
@@ -146,8 +160,9 @@ def create_interface():
                 generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]['generated_text']
                 return generated_text
             except Exception as e:
-                return f"Error: {e}"
-
+                logging.error(f"Error during text generation: {e}")
+                return f"Error: {repr(e)}"
+
         generate_button.click(fn=generate_text, inputs=[text_input, repo_name, hf_write_token], outputs=text_output)
 
     return demo
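One detail worth flagging in the hunks above: repo_id is built as f"{hf_write_token}/{repo_name}", i.e. the write token itself is used as the namespace, whereas Hub repo ids take the form <namespace>/<repo_name>. A minimal sketch of resolving the namespace from the token, assuming huggingface_hub's HfApi.whoami; the helper name is illustrative and not part of the commit:

from huggingface_hub import HfApi

# Hypothetical helper; not part of the commit.
def resolve_repo_id(hf_write_token: str, repo_name: str) -> str:
    namespace = HfApi().whoami(token=hf_write_token)["name"]
    return f"{namespace}/{repo_name}"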
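The diff never shows how the Blocks app returned by create_interface() is started. A typical Gradio entry point, assuming app.py exposes create_interface() exactly as above, would look like this (illustrative, not part of the commit):

if __name__ == "__main__":
    demo = create_interface()
    demo.launch()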