Tech-Meld committed
Commit 3b422da
1 Parent(s): 241160e

Update app.py

Files changed (1)
  1. app.py +15 -7

app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline, AutoConfig
+from transformers import AutoModel, AutoTokenizer, pipeline, AutoConfig, PreTrainedModel  # PreTrainedModel kept: merge_kit_prune's type hints still need it
 from huggingface_hub import cached_download, hf_hub_url, list_models
 import requests
 import json
@@ -7,10 +7,9 @@ import os
 import matplotlib.pyplot as plt
 from io import BytesIO
 import base64
-from transformers.models.auto import AutoModel
-from transformers.modeling_utils import PreTrainedModel
 import torch
 from torch.nn.utils import prune
+from transformers.models.auto import AutoModelForCausalLM  # Import for CausalLM

 # Function to fetch open-weight LLM models
 def fetch_open_weight_models():
@@ -22,14 +21,23 @@ def prune_model(llm_model_name, target_size, output_dir):
     try:
         # Load the LLM model and tokenizer
         llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
-        llm_model = AutoModel.from_pretrained(llm_model_name)  # Load using AutoModel
+        # Handle cases where the model is split into multiple safetensors
+        if "safetensors" in llm_tokenizer.vocab_files_names:
+            llm_model = AutoModelForCausalLM.from_pretrained(
+                llm_model_name,
+                use_safetensors=True,  # from_pretrained takes use_safetensors, not from_safetensors
+                torch_dtype=torch.float16,  # Adjust dtype as needed
+                use_auth_token=None,
+            )
+        else:
+            llm_model = AutoModel.from_pretrained(llm_model_name)

         # Get the model config
         config = AutoConfig.from_pretrained(llm_model_name)
         # Calculate the target number of parameters
         target_num_parameters = int(config.num_parameters * (target_size / 100))

-        # Use merge-kit to prune the model (modify pruning logic for Llama)
+        # Use merge-kit to prune the model
         pruned_model = merge_kit_prune(llm_model, target_num_parameters)

         # Save the pruned model
@@ -49,7 +57,7 @@ def prune_model(llm_model_name, target_size, output_dir):
     except Exception as e:
         return f"Error: {e}", None

-# Merge-kit Pruning Function
+# Merge-kit Pruning Function (adjust as needed)
 def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
     """Prunes a model using a merge-kit approach.

@@ -128,7 +136,7 @@ def create_interface():
     try:
         # Load the pruned model and tokenizer
         tokenizer = AutoTokenizer.from_pretrained(model_path)
-        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+        model = AutoModelForCausalLM.from_pretrained(model_path)  # Load as CausalLM

         # Use the pipeline for text generation
         generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
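
A note on the new loading branch: vocab_files_names is a tokenizer attribute mapping file keys (for example "vocab_file") to filenames, so the check "safetensors" in llm_tokenizer.vocab_files_names says nothing about the weight files and will effectively never be true, leaving the CausalLM path untaken. A minimal sketch of detecting safetensors weights from the repo file listing instead, assuming a public model id; list_repo_files is an existing huggingface_hub helper, and load_llm is an illustrative name:

import torch
from huggingface_hub import list_repo_files
from transformers import AutoModelForCausalLM

def load_llm(model_name: str):
    # Inspect the repo's file list for .safetensors shards instead of
    # tokenizer metadata, which says nothing about the weight format.
    has_safetensors = any(
        f.endswith(".safetensors") for f in list_repo_files(model_name)
    )
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        use_safetensors=has_safetensors,  # from_pretrained resolves sharded files either way
        torch_dtype=torch.float16,
    )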
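On the target-size calculation: AutoConfig objects do not generally carry a num_parameters attribute, so config.num_parameters will raise AttributeError for most architectures. PreTrainedModel.num_parameters() is the supported counter; a sketch of the same percentage calculation against the loaded model:

def target_parameter_count(model, target_size: float) -> int:
    # target_size is the percentage of parameters to keep, as in prune_model;
    # num_parameters() is a PreTrainedModel method (configs have no such field).
    return int(model.num_parameters() * (target_size / 100))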
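The body of merge_kit_prune is elided in this diff, and the name does not correspond to a public mergekit API, so its behavior is taken on faith here. For orientation, a magnitude-pruning sketch built on the torch.nn.utils.prune module the file already imports; note that unstructured pruning only zeroes weights in place, so the checkpoint does not actually shrink, and the parameter target is approximated as a sparsity fraction over Linear weights:

import torch
from torch.nn.utils import prune
from transformers import PreTrainedModel

def magnitude_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
    # Convert the absolute parameter target into a fraction to remove,
    # clamped to [0, 1] so an oversized target simply prunes nothing.
    amount = max(0.0, 1.0 - target_num_parameters / model.num_parameters())
    targets = [
        (m, "weight") for m in model.modules() if isinstance(m, torch.nn.Linear)
    ]
    # Global L1-magnitude pruning across all Linear weights.
    prune.global_unstructured(
        targets, pruning_method=prune.L1Unstructured, amount=amount
    )
    for module, name in targets:
        prune.remove(module, name)  # bake the pruning mask into the weight tensors
    return model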
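Finally, the switch from AutoModelForSeq2SeqLM to AutoModelForCausalLM lines up with the pipeline task: "text-generation" expects a causal LM head, while encoder-decoder models belong under "text2text-generation". A small usage sketch with an illustrative model_path:

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_path = "path/to/pruned-model"  # wherever prune_model saved its output
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
# Each result dict carries the prompt plus its continuation.
print(generator("Hello, world", max_new_tokens=32)[0]["generated_text"])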