Tech-Meld's picture
Update app.py
3cfae8b verified
raw
history blame
5.35 kB
import gradio as gr
from transformers import AutoModel, AutoTokenizer, pipeline, AutoConfig, AutoModelForCausalLM
from huggingface_hub import cached_download, hf_hub_url, list_models, create_repo, HfApi
from transformers.modeling_utils import PreTrainedModel
import requests
import json
import os
import matplotlib.pyplot as plt
from io import BytesIO
import base64
import torch
from torch.nn.utils import prune
import subprocess
# Fetch the model listing from the Hugging Face Hub
def fetch_open_weight_models():
    """Return whatever ``huggingface_hub.list_models()`` yields when called without filters.

    No filter is passed, so despite the function's name the listing is not
    restricted to open-weight LLMs. Depending on the installed huggingface_hub
    version the result may be a lazy iterator rather than a list (TODO confirm
    against the pinned version).
    """
    listing = list_models()
    return listing
# Ensure sentencepiece is installed (presumably for SentencePiece-based
# tokenizers loaded via AutoTokenizer below).
try:
    import sentencepiece  # noqa: F401  -- imported only to test availability
except ImportError:
    import sys

    # Install into the interpreter that is running this app: a bare "pip" found
    # on PATH can belong to a different Python environment.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "sentencepiece"])
# Prune a causal LM and publish the result to the Hugging Face Hub.
def _resolve_repo_id(repo_name, token):
    """Return a full ``namespace/name`` repo id for *repo_name*.

    A name that already contains ``/`` is used as-is; a bare name is placed
    under the namespace of the account that owns *token*.
    """
    repo_name = repo_name.strip()
    if "/" in repo_name:
        return repo_name
    namespace = HfApi().whoami(token=token)["name"]
    return f"{namespace}/{repo_name}"


def _count_nonzero_parameters(model):
    """Return how many parameter entries of *model* are non-zero."""
    return sum(int(torch.count_nonzero(p)) for p in model.parameters())


def _render_size_chart(original_params, remaining_params):
    """Render the before/after bar chart and return it as a base64 PNG data URL."""
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.bar(["Original", "Pruned"], [original_params, remaining_params])
        ax.set_ylabel("Number of Parameters")
        ax.set_title("Model Size Comparison")
        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps every figure it creates registered until it is closed;
        # without this, each button click leaks one figure.
        plt.close(fig)
    image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{image_base64}"


def prune_model(llm_model_name, target_size, hf_write_token, repo_name):
    """Prune a causal LM, push it and its tokenizer to the Hub, and chart the result.

    Args:
        llm_model_name: Hub id (or local path) of the model to prune.
        target_size: Share of the original parameter count to keep non-zero,
            as a percentage (the UI slider supplies 1-100).
        hf_write_token: Hugging Face token with write access. It is used only
            for authentication and is never embedded in the repository id.
        repo_name: Destination repository: either a bare ``"name"`` (created
            under the token owner's namespace) or a full ``"namespace/name"``.

    Returns:
        A 3-tuple ``(status_message, image_data_url, None)``. On any failure the
        message is ``"Error: ..."`` and the image slot is ``None``. The trailing
        ``None`` is kept for backward compatibility; the Gradio click handler
        binds only the first two values.
    """
    try:
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            torch_dtype=torch.float16,
        )
        # Count from the loaded weights: PretrainedConfig has no
        # ``num_parameters`` attribute, so the previous config lookup always
        # raised AttributeError.
        original_params = sum(p.numel() for p in llm_model.parameters())
        target_num_parameters = int(original_params * (target_size / 100))

        pruned_model = merge_kit_prune(llm_model, target_num_parameters)
        # Pruning zeroes weights in place instead of shrinking tensors, so the
        # meaningful "after" size is the number of non-zero parameters.
        remaining_params = _count_nonzero_parameters(pruned_model)

        repo_id = _resolve_repo_id(repo_name, hf_write_token)
        create_repo(repo_id, token=hf_write_token, private=False, exist_ok=True)
        pruned_model.push_to_hub(repo_id, token=hf_write_token)
        llm_tokenizer.push_to_hub(repo_id, token=hf_write_token)

        image_data_url = _render_size_chart(original_params, remaining_params)
        return (
            f"Pruned model saved to Hugging Face Hub in repository {repo_id}",
            image_data_url,
            None,
        )
    except Exception as e:
        # UI boundary: surface the failure in the status box rather than
        # crashing the Gradio worker.
        return f"Error: {e}", None, None
# Pruning helper. Despite its name this does not use mergekit: it applies
# torch.nn.utils.prune unstructured pruning to Linear/Conv2d weights.
def merge_kit_prune(
    model: torch.nn.Module,
    target_num_parameters: int,
    method: str = "random",
) -> torch.nn.Module:
    """Zero Linear/Conv2d weights so that about ``target_num_parameters`` stay non-zero.

    Only ``torch.nn.Linear`` and ``torch.nn.Conv2d`` weights are pruned. Biases
    and every other parameter are left untouched but still count toward the
    target. Pruning is made permanent with ``prune.remove``, so tensors keep
    their shapes and the pruned entries simply become zeros.

    Args:
        model: Module to prune (e.g. a transformers ``PreTrainedModel``). It is
            modified in place.
        target_num_parameters: Desired number of parameter entries left
            non-zero. A value at or above the current total prunes nothing; a
            value too small to reach zeroes every prunable weight.
        method: ``"random"`` (default, the original behaviour) uses
            ``prune.random_unstructured``; ``"l1"`` uses magnitude-based
            ``prune.l1_unstructured``.

    Returns:
        The same ``model`` object, pruned.

    Raises:
        ValueError: If ``method`` is not ``"random"`` or ``"l1"``.
    """
    pruners = {
        "random": prune.random_unstructured,
        "l1": prune.l1_unstructured,
    }
    try:
        prune_fn = pruners[method]
    except KeyError:
        raise ValueError(f"Unknown pruning method: {method!r}") from None

    prunable = [
        module
        for module in model.modules()
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d))
    ]
    prunable_total = sum(module.weight.numel() for module in prunable)
    if prunable_total == 0:
        # Nothing this function knows how to prune.
        return model

    total = sum(p.numel() for p in model.parameters())
    # Only the prunable weights can be zeroed, so size the fraction against
    # them rather than against the whole model. Clamp because prune rejects
    # float amounts outside [0, 1] (e.g. when the target exceeds the total);
    # keep it a float, since an int amount means "absolute count" to prune.
    amount = float(min(max((total - target_num_parameters) / prunable_total, 0.0), 1.0))

    for module in prunable:
        prune_fn(module, name="weight", amount=amount)
        # Bake the mask into the weight and drop the reparametrization.
        prune.remove(module, name="weight")
    return model
# Function to create a Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for pruning a model and sampling from it.

    "Prune Model" calls ``prune_model`` with the model name, target size, write
    token and repository name, and shows the returned status text and chart.
    "Generate Text" loads the model from the repository named in the
    "Repository Name" box and runs a text-generation pipeline on the input.

    Returns:
        The ``gr.Blocks`` app; launching it is left to the caller.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Create a Smaller LLM")
        llm_model_name = gr.Textbox(label="Choose a Large Language Model", placeholder="Enter the model name", interactive=True)
        target_size = gr.Slider(label="Target Model Size (%)", minimum=1, maximum=100, step=1, value=50, interactive=True)
        hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password")
        repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True)
        pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
        prune_button = gr.Button("Prune Model")
        visualization = gr.Image(label="Model Size Comparison", interactive=False)
        prune_button.click(
            fn=prune_model,
            inputs=[llm_model_name, target_size, hf_write_token, repo_name],
            outputs=[pruning_status, visualization],
        )

        text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")
        generate_button = gr.Button("Generate Text")

        def generate_text(text, repo_name, hf_token=None):
            """Generate a continuation of *text* with the model stored in *repo_name*.

            Args:
                text: Prompt to continue.
                repo_name: ``"namespace/name"`` or a bare name; with a token, a
                    bare name is looked up under the token owner's namespace.
                hf_token: Value of the token textbox; empty means anonymous.

            Returns:
                The generated text, or ``"Error: ..."`` on any failure.
            """
            try:
                # The token must come in as an input value: the component
                # object itself is not a valid credential.
                token = hf_token or None
                repo_id = (repo_name or "").strip()
                if token and "/" not in repo_id:
                    namespace = HfApi().whoami(token=token)["name"]
                    repo_id = f"{namespace}/{repo_id}"
                tokenizer = AutoTokenizer.from_pretrained(repo_id, token=token)
                model = AutoModelForCausalLM.from_pretrained(repo_id, token=token)
                generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
                # max_new_tokens bounds only the continuation; max_length also
                # counted the prompt, leaving no room for long inputs.
                outputs = generator(text, max_new_tokens=50, num_beams=5, num_return_sequences=1)
                return outputs[0]["generated_text"]
            except Exception as e:
                # UI boundary: show the error in the output box.
                return f"Error: {e}"

        generate_button.click(
            fn=generate_text,
            inputs=[text_input, repo_name, hf_write_token],
            outputs=text_output,
        )
    return demo
# Build the Gradio app at module level so `demo` stays importable, but only
# start the server when this file is executed directly (not on import).
demo = create_interface()
if __name__ == "__main__":
    demo.launch(share=True)