import base64
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import torch
from huggingface_hub import list_models
from torch.nn.utils import prune
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel, pipeline
# Function to fetch open-weight models from the Hugging Face Hub
def fetch_open_weight_models():
    models = list_models()
    return models
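
# Note: list_models() with no arguments iterates over every model on the Hub,
# which is slow. A filtered call is usually preferable; a hedged sketch
# (argument availability depends on the installed huggingface_hub version):
#
#     models = list_models(filter="text-generation", sort="downloads", limit=20)
#     for m in models:
#         print(m.modelId)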
# Function to prune a model using the "merge-kit" style helper defined below
def prune_model(llm_model_name, target_size, output_dir):
    try:
        # Load the LLM model and tokenizer
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)

        # Calculate the target number of parameters from the requested percentage
        original_num_parameters = llm_model.num_parameters()
        target_num_parameters = int(original_num_parameters * (target_size / 100))

        # Prune the model
        pruned_model = merge_kit_prune(llm_model, target_num_parameters)

        # Save the pruned model together with its tokenizer so it can be reloaded later
        pruned_model.save_pretrained(output_dir)
        llm_tokenizer.save_pretrained(output_dir)

        # Unstructured pruning zeroes weights rather than removing them, so count
        # the remaining non-zero parameters for the comparison chart
        pruned_num_parameters = sum(int((p != 0).sum()) for p in pruned_model.parameters())

        # Create a visualization of the size reduction
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.bar(["Original", "Pruned"], [original_num_parameters, pruned_num_parameters])
        ax.set_ylabel("Number of Parameters")
        ax.set_title("Model Size Comparison")
        buf = BytesIO()
        fig.savefig(buf, format="png")
        plt.close(fig)
        buf.seek(0)
        image_base64 = base64.b64encode(buf.read()).decode("utf-8")

        return f"Pruned model saved to {output_dir}", f"data:image/png;base64,{image_base64}"
    except Exception as e:
        return f"Error: {e}", None
# "Merge-kit" style pruning function (random unstructured pruning via torch.nn.utils.prune)
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
    """Prunes a model by randomly zeroing weights in its Linear and Conv2d layers.

    Note that unstructured pruning zeroes weights in place; it does not shrink
    the weight tensors, so the total parameter count of the returned model is
    unchanged -- only the number of non-zero parameters is reduced.

    Args:
        model (PreTrainedModel): The model to be pruned.
        target_num_parameters (int): The target number of non-zero parameters after pruning.

    Returns:
        PreTrainedModel: The pruned model.
    """
    # Fraction of weights to prune
    amount = 1 - (target_num_parameters / model.num_parameters())

    # Apply random unstructured pruning to every Linear and Conv2d layer
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            prune.random_unstructured(module, name="weight", amount=amount)

    # Make the pruning permanent (removes the pruning re-parametrization)
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            prune.remove(module, name="weight")

    return model
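
# A minimal standalone sketch (not wired into the Gradio app) illustrating what
# merge_kit_prune does to a single layer: random unstructured pruning zeroes a
# fraction of the weights in place, leaving the tensor shape and parameter
# count unchanged.
def _pruning_sketch():
    layer = torch.nn.Linear(8, 4)
    prune.random_unstructured(layer, name="weight", amount=0.5)  # zero ~50% of the weights
    prune.remove(layer, name="weight")  # bake the zeros into layer.weight
    total = layer.weight.numel()
    zeroed = int((layer.weight == 0).sum())
    print(f"{zeroed}/{total} weights are now zero; shape is still {tuple(layer.weight.shape)}")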
# Function to create the Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Create a Smaller LLM")

        # Input for model name
        llm_model_name = gr.Textbox(
            label="Choose a Large Language Model",
            placeholder="Enter the model name",
            interactive=True,
        )

        # Input for target model size
        target_size = gr.Slider(
            label="Target Model Size (%)",
            minimum=1,
            maximum=100,
            step=1,
            value=50,
            interactive=True,
        )

        # Output for pruning status
        pruning_status = gr.Textbox(label="Pruning Status")

        # Input for where to save the pruned model
        save_model_path = gr.Textbox(
            label="Save Model Path",
            placeholder="Path to save the pruned model",
            interactive=True,
        )

        # Button to start pruning
        prune_button = gr.Button("Prune Model")

        # Output for visualization
        visualization = gr.Image(label="Model Size Comparison")

        # Connect components
        prune_button.click(
            fn=prune_model,
            inputs=[llm_model_name, target_size, save_model_path],
            outputs=[pruning_status, visualization],
        )

        # Example usage of the pruned model (optional)
        text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")

        # Generate text button
        generate_button = gr.Button("Generate Text")
        def generate_text(text, model_path):
            try:
                # Load the pruned model and tokenizer from the save path
                tokenizer = AutoTokenizer.from_pretrained(model_path)
                model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

                # Seq2seq models use the "text2text-generation" pipeline task
                generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
                generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]["generated_text"]
                return generated_text
            except Exception as e:
                return f"Error: {e}"

        generate_button.click(fn=generate_text, inputs=[text_input, save_model_path], outputs=text_output)
    return demo


# Create and launch the Gradio interface
demo = create_interface()
demo.launch(share=True)