Tech-Meld committed
Commit 241160e
1 Parent(s): 5796c7a

Update app.py

Files changed (1)
app.py +5 -4
app.py CHANGED
@@ -22,14 +22,14 @@ def prune_model(llm_model_name, target_size, output_dir):
     try:
         # Load the LLM model and tokenizer
         llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
-        llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
+        llm_model = AutoModel.from_pretrained(llm_model_name)  # Load using AutoModel
 
         # Get the model config
         config = AutoConfig.from_pretrained(llm_model_name)
         # Calculate the target number of parameters
         target_num_parameters = int(config.num_parameters * (target_size / 100))
 
-        # Use merge-kit to prune the model
+        # Use merge-kit to prune the model (modify pruning logic for Llama)
         pruned_model = merge_kit_prune(llm_model, target_num_parameters)
 
         # Save the pruned model
@@ -67,9 +67,10 @@ def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
     # Calculate the pruning amount
     amount = 1 - (target_num_parameters / model.num_parameters)
 
-    # Prune the model using the selected method
+    # Prune the model using the selected method (adapt for Llama)
+    # Example: If Llama uses specific layers, adjust the pruning logic here
     for name, module in model.named_modules():
-        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
+        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
             prune.random_unstructured(module, name="weight", amount=amount)
 
     # Remove the pruned weights
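A note on the arithmetic in the first hunk: in transformers, PretrainedConfig does not generally expose a num_parameters field, and on PreTrainedModel, num_parameters is a method rather than an attribute, so both config.num_parameters and the later model.num_parameters would likely fail at runtime. A minimal sketch of deriving the target count from the loaded model instead, assuming transformers is installed (the tiny checkpoint name is illustrative only, not the app's model):

# Sketch: compute the pruning target from the model itself rather than the config.
from transformers import AutoModel

llm_model = AutoModel.from_pretrained("sshleifer/tiny-gpt2")  # illustrative checkpoint

num_params = llm_model.num_parameters()  # a method call in transformers, not an attribute
target_size = 50                         # percent of parameters to keep, as in app.py
target_num_parameters = int(num_params * (target_size / 100))
print(f"{num_params} parameters -> target {target_num_parameters}")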
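The second hunk's loop matches the torch.nn.utils.prune API, and the trailing "# Remove the pruned weights" comment suggests a prune.remove() pass follows it. A self-contained sketch of that pattern on a toy module, with the Sequential model standing in for the loaded LLM (hypothetical, for illustration only):

import torch
import torch.nn.utils.prune as prune

# Toy stand-in for the LLM; only Linear/Conv2d layers get pruned, as in app.py.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 8),
)

target_ratio = 0.5            # keep ~50% of the weights (target_size / 100)
amount = 1 - target_ratio     # fraction of weights to zero, as in merge_kit_prune

for name, module in model.named_modules():
    if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
        prune.random_unstructured(module, name="weight", amount=amount)
        prune.remove(module, "weight")  # make the pruning permanent

total = sum(p.numel() for p in model.parameters())
zeros = sum((p == 0).sum().item() for p in model.parameters())
print(f"{zeros}/{total} weights zeroed")

Note that random unstructured pruning only zeroes weights in place; it does not shrink the tensors, so the saved model keeps its original parameter count and file size.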