Spaces:
Runtime error
Runtime error
File size: 5,000 Bytes
dd6c56a 3e5d73a dd6c56a 3e5d73a dd6c56a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline, AutoConfig
from huggingface_hub import cached_download, hf_hub_url, list_models
import requests
import json
import os
import matplotlib.pyplot as plt
from io import BytesIO
import base64
# PyTorch is required: the pruning code below uses torch.nn modules and torch.nn.utils.prune.
import torch # If using PyTorch
# Function to fetch open-weight LLM models
def fetch_open_weight_models(filter_tag="open-weight", limit=10):
    """Return ids of the most-downloaded Hub models matching ``filter_tag``.

    Args:
        filter_tag: Filter passed to ``huggingface_hub.list_models``; defaults
            to the original "open-weight" filter.
        limit: Maximum number of model ids to return.

    Returns:
        A list of model id strings, ordered by download count, highest first.
    """
    # ``direction=-1`` requests descending order; without it ``sort="downloads"``
    # returns the least-downloaded models first.
    models = list_models(filter=filter_tag, sort="downloads", direction=-1, limit=limit)
    # ``list_models`` yields ``ModelInfo`` objects, which expose the id as an
    # attribute; they are not subscriptable like dicts.
    return [model.id for model in models]
# Prune a model by zeroing weights with torch.nn.utils.prune (despite the "merge-kit" naming, mergekit itself is not used)
def prune_model(llm_model_name, target_size, output_dir):
    """Prune a seq2seq model to roughly ``target_size`` percent of its parameters.

    Loads the model and tokenizer, zeroes weights via ``merge_kit_prune``,
    saves BOTH model and tokenizer to ``output_dir`` (the "Generate Text"
    handler reloads the tokenizer from that same path), and renders a bar
    chart comparing parameter counts before and after pruning.

    Args:
        llm_model_name: Hub id or local path loadable by AutoModelForSeq2SeqLM.
        target_size: Desired size as a percentage (1-100) of the original
            parameter count.
        output_dir: Directory the pruned model and tokenizer are written to.

    Returns:
        A ``(status_message, image_path)`` tuple, where ``image_path`` is a PNG
        file with the size comparison chart (a file path is a value
        ``gr.Image`` can display; a ``data:`` URI string is not).

    Raises:
        ValueError: If ``output_dir`` is empty.
    """
    if not output_dir:
        raise ValueError("Please provide a path to save the pruned model.")

    llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
    llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)

    # A config object carries no parameter count, so count on the loaded model.
    original_count = sum(p.numel() for p in llm_model.parameters())
    target_num_parameters = int(original_count * (target_size / 100))

    pruned_model = merge_kit_prune(llm_model, target_num_parameters)
    pruned_model.save_pretrained(output_dir)
    llm_tokenizer.save_pretrained(output_dir)

    # Unstructured pruning zeroes weights instead of removing them, so the
    # tensor sizes are unchanged; the meaningful "size" is the non-zero count.
    pruned_count = sum(int(torch.count_nonzero(p)) for p in pruned_model.parameters())

    image_path = _plot_size_comparison(original_count, pruned_count)
    return f"Pruned model saved to {output_dir}", image_path


def _plot_size_comparison(original_count, pruned_count):
    """Draw an original-vs-pruned bar chart into a temporary PNG and return its path."""
    import tempfile

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.bar(["Original", "Pruned"], [original_count, pruned_count])
        ax.set_ylabel("Number of Parameters")
        ax.set_title("Model Size Comparison")
        # delete=False: the file must outlive this function so Gradio can serve it.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            fig.savefig(tmp, format="png")
            return tmp.name
    finally:
        # Close the figure so repeated button clicks don't accumulate open figures.
        plt.close(fig)
# Pruning helper: zeroes Linear/Conv2d weights with torch.nn.utils.prune (no mergekit involved)
def merge_kit_prune(model: "PreTrainedModel", target_num_parameters: int) -> "PreTrainedModel":
    """Zero out weights so that about ``target_num_parameters`` remain non-zero.

    Despite the name, this does not use mergekit. It applies PyTorch's
    ``random_unstructured`` pruning to the ``weight`` of every ``nn.Linear``
    and ``nn.Conv2d`` submodule. It then makes the masks permanent with
    ``prune.remove``. Pruned weights are set to zero, not deleted, so tensor
    shapes and the total parameter count do not change.

    Args:
        model: The model to prune; it is modified in place. Any
            ``torch.nn.Module`` works. The annotations are strings, so
            ``PreTrainedModel`` need not be imported when this is defined.
        target_num_parameters: Desired number of non-zero parameters after
            pruning. Only Linear/Conv2d weights are pruned. If the target is
            below the count of all other parameters, every prunable weight is
            zeroed and the target cannot be reached.

    Returns:
        The same ``model`` object, pruned.
    """
    from torch.nn.utils import prune

    prunable = [
        module
        for module in model.modules()
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d))
    ]
    prunable_count = sum(module.weight.numel() for module in prunable)
    if prunable_count == 0:
        return model

    total_count = sum(p.numel() for p in model.parameters())
    # Zero enough of the *prunable* weights for the whole model to reach the
    # target. Clamp to [0, 1] because prune rejects fractions outside that
    # range. Keep a float: prune treats an int amount as an absolute count.
    amount = float(min(1.0, max(0.0, (total_count - target_num_parameters) / prunable_count)))
    if amount == 0.0:
        return model

    # NOTE(review): tied weights (e.g. an lm_head sharing the embedding matrix)
    # end up as separate Parameters after prune.remove — confirm this is acceptable.
    for module in prunable:
        prune.random_unstructured(module, name="weight", amount=amount)
    # Fold each mask into the weight and drop the weight_orig/weight_mask reparametrization.
    for module in prunable:
        prune.remove(module, name="weight")
    return model
# Function to create a Gradio interface
def create_interface():
    """Build the Gradio UI for pruning a model and generating text with it.

    The layout has three inputs:
    - a model dropdown populated by ``fetch_open_weight_models``;
    - a target-size slider;
    - a save path.

    "Prune Model" runs ``prune_model``. "Generate Text" reloads the model
    stored at the save path and runs the input text through it.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Create a Smaller LLM")

        # Fetch open-weight models from Hugging Face (runs once, when the UI is built).
        available_models = gr.Dropdown(
            label="Choose a Large Language Model",
            choices=fetch_open_weight_models(),
            interactive=True,
        )
        # Target model size, as a percentage of the original parameter count.
        target_size = gr.Slider(
            label="Target Model Size (%)",
            minimum=1,
            maximum=100,
            step=1,
            value=50,
            interactive=True,
        )
        # Output for pruning status.
        pruning_status = gr.Textbox(label="Pruning Status")
        # Where the pruned model is written; "Generate Text" also reads from here.
        save_model_path = gr.Textbox(
            label="Save Model Path",
            placeholder="Path to save the pruned model",
            interactive=True,
        )
        prune_button = gr.Button("Prune Model")
        visualization = gr.Image(label="Model Size Comparison")

        prune_button.click(
            fn=prune_model,
            inputs=[available_models, target_size, save_model_path],
            outputs=[pruning_status, visualization],
        )

        # Example usage of the pruned model (optional).
        text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")
        generate_button = gr.Button("Generate Text")

        def generate_text(text, model_path):
            """Run ``text`` through the seq2seq model saved at ``model_path``."""
            if not model_path:
                raise gr.Error("Enter the path of a saved model before generating text.")
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
            # AutoModelForSeq2SeqLM loads an encoder-decoder model. The matching
            # pipeline task is "text2text-generation"; "text-generation" is meant
            # for decoder-only (causal LM) models.
            generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
            outputs = generator(text, max_length=50, num_beams=5, num_return_sequences=1)
            return outputs[0]["generated_text"]

        generate_button.click(fn=generate_text, inputs=[text_input, save_model_path], outputs=text_output)
    return demo
# Create and launch the Gradio interface
# Build the app at module level so ``demo`` stays importable; launch it only
# when this file is executed as a script.
demo = create_interface()

if __name__ == "__main__":
    demo.launch(share=True)