Tech-Meld committed
Commit 10fbcfe
1 Parent(s): 8fdbc8b

Update app.py

Files changed (1)
app.py +26 -31
app.py CHANGED
@@ -22,7 +22,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 try:
     import sentencepiece
 except ImportError:
-    subprocess.check_call(["pip", "install", "sentencepiece"])
+    subprocess.check_call(['pip', 'install', 'sentencepiece'])
 
 # Function to fetch open-weight LLM models
 def fetch_open_weight_models():
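
Note on this hunk: shelling out to a bare `pip` can install into a different Python than the one running the app. A more robust variant of the fallback, a sketch of our own rather than what the commit does, pins the install to the current interpreter and retries the import:

import subprocess
import sys

try:
    import sentencepiece
except ImportError:
    # Install into the interpreter that is actually running this app
    subprocess.check_call([sys.executable, "-m", "pip", "install", "sentencepiece"])
    import sentencepiece
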
@@ -40,8 +40,8 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
             torch_dtype=torch.float16,
         )
 
-        log_messages.append("Model and tokenizer loaded successfully.")
-        logging.info("Model and tokenizer loaded successfully.")
+        log_messages.append('Model and tokenizer loaded successfully.')
+        logging.info('Model and tokenizer loaded successfully.')
 
         # Get the model config
         config = AutoConfig.from_pretrained(llm_model_name)
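
The head of this `from_pretrained` call lies outside the hunk; judging from the context (the `llm_model` name used later and the "Model and tokenizer loaded" log line), it is presumably something like the sketch below, where everything except the names taken from the file is our assumption:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# llm_model_name is the function's input; float16 halves memory for inference
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_model_name,
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
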
@@ -50,8 +50,8 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
         # Prune the model
         pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)
 
-        log_messages.append("Model pruned successfully.")
-        logging.info("Model pruned successfully.")
+        log_messages.append('Model pruned successfully.')
+        logging.info('Model pruned successfully.')
 
         # Save the pruned model
         api = HfApi()
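
The save/upload step is cut off by the hunk boundary right after `api = HfApi()`. A common `huggingface_hub` pattern, sketched here under our own assumptions (the helper name `push_pruned_model` is not from the commit), is to serialize locally and upload the folder:

from huggingface_hub import HfApi

def push_pruned_model(pruned_model, tokenizer, repo_id, hf_write_token):
    # Serialize the model and tokenizer locally, then upload the folder to the Hub
    pruned_model.save_pretrained("pruned_model")
    tokenizer.save_pretrained("pruned_model")
    api = HfApi(token=hf_write_token)
    api.create_repo(repo_id=repo_id, exist_ok=True)
    api.upload_folder(folder_path="pruned_model", repo_id=repo_id)
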
@@ -65,21 +65,21 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
 
         # Create a visualization
         fig, ax = plt.subplots(figsize=(10, 5))
-        ax.bar(["Original", "Pruned"], [config.num_parameters, sum(p.numel() for p in pruned_model.parameters())])
-        ax.set_ylabel("Number of Parameters")
-        ax.set_title("Model Size Comparison")
+        ax.bar(['Original', 'Pruned'], [config.num_parameters, sum(p.numel() for p in pruned_model.parameters())])
+        ax.set_ylabel('Number of Parameters')
+        ax.set_title('Model Size Comparison')
         buf = BytesIO()
-        fig.savefig(buf, format="png")
+        fig.savefig(buf, format='png')
         buf.seek(0)
-        image_base64 = base64.b64encode(buf.read()).decode("utf-8")
+        image_base64 = base64.b64encode(buf.read()).decode('utf-8')
 
-        return f"Pruned model saved to Hugging Face Hub in repository {repo_id}", f"data:image/png;base64,{image_base64}", "\n".join(log_messages)
+        return f"Pruned model saved to Hugging Face Hub in repository {repo_id}", f"data:image/png;base64,{image_base64}", '\n'.join(log_messages)
 
     except Exception as e:
         error_message = f"Error: {e}"
         log_messages.append(error_message)
         logging.error(error_message)
-        return error_message, None, "\n".join(log_messages)
+        return error_message, None, '\n'.join(log_messages)
 
 # Merge-kit Pruning Function (adjust as needed)
 def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress: gr.Progress) -> PreTrainedModel:
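
Two observations on the visualization block. First, `config.num_parameters` is not a standard attribute on `transformers` config objects (models expose a `num_parameters()` method instead), so the "Original" bar likely raises an `AttributeError`; counting `p.numel()` over the model before pruning is the dependable source for that number. Second, the PNG-to-data-URI round-trip is worth isolating into a helper that also closes the figure. The sketch below is ours; `fig_to_data_uri` is not a name from the commit:

import base64
from io import BytesIO

import matplotlib.pyplot as plt

def fig_to_data_uri(fig):
    # Render the figure to PNG in memory and base64-encode it for inline display
    buf = BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)  # release the figure so repeated calls do not leak memory
    buf.seek(0)
    return "data:image/png;base64," + base64.b64encode(buf.read()).decode("utf-8")
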
@@ -90,25 +90,20 @@ def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress
     Returns:
         PreTrainedModel: The pruned model.
     """
-    # Define the pruning method
-    pruning_method = "unstructured"
-
-    # Calculate the pruning amount
+
     total_params = sum(p.numel() for p in model.parameters())
-    amount = 1 - (target_num_parameters / total_params)
+    amount = 1 - (target_num_parameters / total_params)
 
-    # Prune the model using the selected method
-    for name, module in tqdm(model.named_modules(), desc="Pruning", file=sys.stdout):
+    for name, module in tqdm(model.named_modules(), desc='Pruning', file=sys.stdout):
         if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
-            prune.random_unstructured(module, name="weight", amount=amount)
-    progress(percent_complete=50)  # Example progress update
+            prune.random_unstructured(module, name='weight', amount=amount)
+    progress(percent_complete=50)
 
-    # Remove the pruned weights
     for name, module in model.named_modules():
         if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
-            prune.remove(module, name="weight")
-    progress(percent_complete=100)  # Example progress update
-
+            prune.remove(module, name='weight')
+    progress(percent_complete=100)
+
     return model
 
 # Function to create a Gradio interface
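
Two caveats on `merge_kit_prune` as committed. `prune.random_unstructured` only masks weights with zeros, and `prune.remove` bakes that mask in, so `p.numel()` is unchanged afterwards; the count of nonzero parameters is what actually shrinks, which also means the "Pruned" bar above will not drop. Separately, in current Gradio releases a `gr.Progress` instance is called with a 0-to-1 fraction, e.g. `progress(0.5, desc='Pruning')`; `percent_complete` does not appear to be a supported keyword, so these calls will likely fail. A self-contained sketch of the pruning mechanics, using a toy model of our own:

import torch
from torch import nn
from torch.nn.utils import prune

# Toy stand-in for the LLM; the real code walks named_modules() the same way
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

total_params = sum(p.numel() for p in model.parameters())
target_num_parameters = total_params // 2
amount = 1 - (target_num_parameters / total_params)  # fraction of weights to zero

for module in model.modules():
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        prune.random_unstructured(module, name="weight", amount=amount)
        prune.remove(module, name="weight")  # make the zeros permanent, drop the mask

nonzero = sum(int((p != 0).sum()) for p in model.parameters())
print(sum(p.numel() for p in model.parameters()) == total_params)  # True: numel unchanged
print(f"nonzero parameters: {nonzero} of {total_params}")
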
@@ -128,7 +123,7 @@ def create_interface():
     progress_bar = gr.Progress()
 
     def show_logs():
-        with open("pruning.log", "r") as log_file:
+        with open('pruning.log', 'r') as log_file:
            logs = log_file.read()
        return logs
 
@@ -143,20 +138,20 @@ def create_interface():
     text_output = gr.Textbox(label="Generated Text")
     generate_button = gr.Button("Generate Text")
 
-    def generate_text(text, repo_name):
+    def generate_text(text, repo_name, hf_write_token):
         try:
             tokenizer = AutoTokenizer.from_pretrained(repo_name, use_auth_token=hf_write_token)
             model = AutoModelForCausalLM.from_pretrained(repo_name, use_auth_token=hf_write_token)
-            generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
-            generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]["generated_text"]
+            generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
+            generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]['generated_text']
             return generated_text
         except Exception as e:
             return f"Error: {e}"
 
-    generate_button.click(fn=generate_text, inputs=[text_input, repo_name], outputs=text_output)
+    generate_button.click(fn=generate_text, inputs=[text_input, repo_name, hf_write_token], outputs=text_output)
 
     return demo
 
 # Create and launch the Gradio interface
 demo = create_interface()
-demo.launch(share=True)
+demo.launch()
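
One last hedge on the new `generate_text`: recent `transformers` releases deprecate `use_auth_token` in favor of `token`, so the loading calls may warn or break on newer versions. A minimal sketch of the same flow with the current keyword; `repo_id` and the token value here are placeholders, not values from the commit:

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

repo_id = "your-username/your-pruned-model"  # placeholder for the repo the app writes
hf_token = "hf_..."  # placeholder; never hard-code a real token

tokenizer = AutoTokenizer.from_pretrained(repo_id, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(repo_id, token=hf_token)

generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
result = generator("Hello, world", max_length=50, num_beams=5, num_return_sequences=1)
print(result[0]["generated_text"])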
 