Tech-Meld committed
Commit 9ab273b
1 Parent(s): 4484172

Update app.py

Files changed (1)
  1. app.py +27 -28
app.py CHANGED
@@ -8,7 +8,6 @@ import base64
 import torch
 from torch.nn.utils import prune
 import subprocess
-from tqdm import tqdm
 import logging
 import sys
 
@@ -46,7 +45,6 @@ def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress
     Returns:
         PreTrainedModel: The pruned model.
     """
-
     total_params = sum(p.numel() for p in model.parameters())
     amount = 1 - (target_num_parameters / total_params)
 
@@ -93,13 +91,12 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_mod
 
         # Save the pruned model
         api = HfApi()
-        repo_id = f"{hf_write_token}/{repo_name}"
-        create_repo(repo_id, token=hf_write_token, private=False, exist_ok=True)
-        pruned_model.push_to_hub(repo_id, use_auth_token=hf_write_token)
-        llm_tokenizer.push_to_hub(repo_id, use_auth_token=hf_write_token)
+        create_repo(repo_name, token=hf_write_token, private=False, exist_ok=True)
+        pruned_model.push_to_hub(repo_name, use_auth_token=hf_write_token)
+        llm_tokenizer.push_to_hub(repo_name, use_auth_token=hf_write_token)
 
-        log_messages.append(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")
-        logging.info(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")
+        log_messages.append(f"Pruned model saved to Hugging Face Hub in repository {repo_name}")
+        logging.info(f"Pruned model saved to Hugging Face Hub in repository {repo_name}")
 
         # Create a visualization
         fig, ax = plt.subplots(figsize=(10, 5))
@@ -111,7 +108,7 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_mod
         buf.seek(0)
         image_base64 = base64.b64encode(buf.read()).decode('utf-8')
 
-        return f"Pruned model saved to Hugging Face Hub in repository {repo_id}", f"data:image/png;base64,{image_base64}", '\n'.join(log_messages)
+        return f"Pruned model saved to Hugging Face Hub in repository {repo_name}", f"data:image/png;base64,{image_base64}", '\n'.join(log_messages)
 
     except Exception as e:
         error_message = f"Detailed error: {repr(e)}"
@@ -119,6 +116,18 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_mod
         logging.error(error_message)
         return error_message, None, '\n'.join(log_messages)
 
+# Define function to generate text
+def generate_text(text, repo_name, hf_write_token):
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(repo_name, use_auth_token=hf_write_token)
+        model = AutoModelForCausalLM.from_pretrained(repo_name, use_auth_token=hf_write_token)
+        generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
+        generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]['generated_text']
+        return generated_text
+    except Exception as e:
+        logging.error(f"Error during text generation: {e}")
+        return f"Error: {repr(e)}"
+
 # Function to create a Gradio interface
 def create_interface():
     with gr.Blocks() as demo:
@@ -126,43 +135,33 @@ def create_interface():
 
         # Fetch available model names
        model_names = get_model_names()
-
+
         # Input components
         llm_model_name = gr.Dropdown(label="Choose a Large Language Model", choices=model_names, interactive=True)
+        base_model_name = gr.Dropdown(label="Base Model Name (if required)", choices=model_names, interactive=True, visible=False)
         target_size = gr.Slider(label="Target Model Size (%)", minimum=1, maximum=100, step=1, value=50, interactive=True)
         hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password")
         repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True)
         pruned_func_choice = gr.Radio(label="Pruning Function", choices=["merge-kit"], value="merge-kit", interactive=True)
-        base_model_name = gr.Dropdown(label="Base Model Name (if required)", choices=model_names, interactive=True, visible=False)
 
         pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
         prune_button = gr.Button("Prune Model")
         visualization = gr.Image(label="Model Size Comparison", interactive=False)
         progress_bar = gr.Progress()
-
-        def prune_model_with_progress(llm_model_name, target_size, hf_write_token, repo_name, pruned_func_choice, base_model_name):
+
+        # Define function for pruning model with progress
+        def prune_model_with_progress(llm_model_name, base_model_name, target_size, hf_write_token, repo_name, pruned_func_choice):
             if pruned_func_choice == "merge-kit":
                 return prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_model_name, progress_bar)
             else:
                 return f"Pruning function '{pruned_func_choice}' not implemented.", None, None
-
-        prune_button.click(fn=prune_model_with_progress, inputs=[llm_model_name, target_size, hf_write_token, repo_name, pruned_func_choice, base_model_name], outputs=[pruning_status, visualization])
-
+
+        prune_button.click(fn=prune_model_with_progress, inputs=[llm_model_name, base_model_name, target_size, hf_write_token, repo_name, pruned_func_choice], outputs=[pruning_status, visualization])
+
         text_input = gr.Textbox(label="Input Text")
         text_output = gr.Textbox(label="Generated Text")
         generate_button = gr.Button("Generate Text")
-
-        def generate_text(text, repo_name, hf_write_token):
-            try:
-                tokenizer = AutoTokenizer.from_pretrained(repo_name, use_auth_token=hf_write_token)
-                model = AutoModelForCausalLM.from_pretrained(repo_name, use_auth_token=hf_write_token)
-                generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
-                generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]['generated_text']
-                return generated_text
-            except Exception as e:
-                logging.error(f"Error during text generation: {e}")
-                return f"Error: {repr(e)}"
-
+
         generate_button.click(fn=generate_text, inputs=[text_input, repo_name, hf_write_token], outputs=text_output)
 
     return demo
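
Note: in merge_kit_prune (context above), `amount = 1 - (target_num_parameters / total_params)` is the fraction of weights to zero out so that roughly `target_num_parameters` survive. A minimal sketch of how such a fraction can be applied with torch.nn.utils.prune is below; restricting the prune to Linear weights, using global L1 unstructured pruning, and folding the masks in with prune.remove() are illustrative assumptions, not necessarily what app.py does.

import torch.nn as nn
from torch.nn.utils import prune

def prune_to_target(model: nn.Module, target_num_parameters: int) -> nn.Module:
    # Fraction of weights to remove so that ~target_num_parameters remain.
    total_params = sum(p.numel() for p in model.parameters())
    amount = 1 - (target_num_parameters / total_params)

    # Assumption: prune all Linear-layer weights globally by L1 magnitude.
    to_prune = [(m, "weight") for m in model.modules() if isinstance(m, nn.Linear)]
    prune.global_unstructured(to_prune, pruning_method=prune.L1Unstructured, amount=amount)

    # Make the pruning permanent so the model can be saved and pushed normally.
    for module, name in to_prune:
        prune.remove(module, name)
    return model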
 
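Note: the commit only shows the functions above; a typical entry point for serving the Blocks UI would look like the sketch below (the __main__ guard is an assumption and is not part of this diff).

if __name__ == "__main__":
    demo = create_interface()
    demo.launch()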