import gradio as gr
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
from transformers.modeling_utils import PreTrainedModel
from huggingface_hub import create_repo, HfApi, list_models
import matplotlib.pyplot as plt
from io import BytesIO
import base64
import torch
from torch.nn.utils import prune
import subprocess
from tqdm import tqdm
import logging
import sys
# Setup logging to both the console and a file so the UI's "Show Logs" button can read it
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.FileHandler("pruning.log"), logging.StreamHandler(sys.stdout)],
)

# Ensure sentencepiece is installed (required by many tokenizers)
try:
    import sentencepiece  # noqa: F401
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "sentencepiece"])
# Function to fetch open-weight LLM models from the Hugging Face Hub
def fetch_open_weight_models():
    models = list_models()
    return models
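
# Note: fetch_open_weight_models() is not wired into the UI below; one possible use
# (not implemented here) would be to populate a gr.Dropdown of model names instead
# of the free-text model-name box.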

# Function to prune a model using the "merge-kit" approach
def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress=gr.Progress(track_tqdm=True)):
    log_messages = []
    try:
        # Load the LLM model and tokenizer
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            torch_dtype=torch.float16,
        )
        log_messages.append("Model and tokenizer loaded successfully.")
        logging.info("Model and tokenizer loaded successfully.")

        # Work out how many parameters the pruned model should keep
        original_num_parameters = llm_model.num_parameters()
        target_num_parameters = int(original_num_parameters * (target_size / 100))

        # Prune the model
        pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)
        log_messages.append("Model pruned successfully.")
        logging.info("Model pruned successfully.")

        # Save the pruned model to the Hub under the authenticated user's namespace
        api = HfApi()
        username = api.whoami(token=hf_write_token)["name"]
        repo_id = f"{username}/{repo_name}"
        create_repo(repo_id, token=hf_write_token, private=False, exist_ok=True)
        pruned_model.push_to_hub(repo_id, token=hf_write_token)
        llm_tokenizer.push_to_hub(repo_id, token=hf_write_token)
        log_messages.append(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")
        logging.info(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")

        # Create a visualization comparing the original parameter count to the
        # non-zero parameters remaining after pruning
        remaining_parameters = sum(int(torch.count_nonzero(p)) for p in pruned_model.parameters())
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.bar(["Original", "Pruned (non-zero)"], [original_num_parameters, remaining_parameters])
        ax.set_ylabel("Number of Parameters")
        ax.set_title("Model Size Comparison")
        buf = BytesIO()
        fig.savefig(buf, format="png")
        plt.close(fig)
        buf.seek(0)
        image_base64 = base64.b64encode(buf.read()).decode("utf-8")

        return (
            f"Pruned model saved to Hugging Face Hub in repository {repo_id}",
            f"data:image/png;base64,{image_base64}",
            "\n".join(log_messages),
        )
    except Exception as e:
        error_message = f"Error: {e}"
        log_messages.append(error_message)
        logging.error(error_message)
        return error_message, None, "\n".join(log_messages)

# Merge-kit pruning function (adjust as needed)
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress: gr.Progress) -> PreTrainedModel:
    """Prunes a model using a merge-kit-style approach.

    Args:
        model (PreTrainedModel): The model to be pruned.
        target_num_parameters (int): The target number of parameters after pruning.
        progress (gr.Progress): Gradio progress tracker for UI updates.

    Returns:
        PreTrainedModel: The pruned model.
    """
    # Calculate the fraction of weights to remove, clamped to a valid range
    total_params = sum(p.numel() for p in model.parameters())
    amount = min(max(1 - (target_num_parameters / total_params), 0.0), 0.99)

    # Apply unstructured random pruning to every Linear and Conv2d layer
    for name, module in tqdm(model.named_modules(), desc="Pruning", file=sys.stdout):
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            prune.random_unstructured(module, name="weight", amount=amount)
    progress(0.5, desc="Pruning applied")

    # Make the pruning permanent by removing the reparametrization masks
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            prune.remove(module, name="weight")
    progress(1.0, desc="Pruning finalized")

    return model
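
# Note: unstructured pruning zeroes out weights but does not shrink the tensors,
# so the saved checkpoint keeps its original shapes and file size; only the number
# of non-zero weights drops, which is what the visualization above reports.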

# Function to create a Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Create a Smaller LLM")

        # Inputs for the pruning job
        llm_model_name = gr.Textbox(label="Choose a Large Language Model", placeholder="Enter the model name", interactive=True)
        target_size = gr.Slider(label="Target Model Size (%)", minimum=1, maximum=100, step=1, value=50, interactive=True)
        hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password")
        repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True)

        # Outputs and controls
        pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
        prune_button = gr.Button("Prune Model")
        visualization = gr.Image(label="Model Size Comparison", interactive=False)
        logs_button = gr.Button("Show Logs")
        logs_output = gr.Textbox(label="Logs", interactive=False)

        def show_logs():
            try:
                with open("pruning.log", "r") as log_file:
                    return log_file.read()
            except FileNotFoundError:
                return "No logs yet."

        logs_button.click(fn=show_logs, outputs=logs_output)

        # prune_model's gr.Progress(track_tqdm=True) default handles progress reporting
        prune_button.click(
            fn=prune_model,
            inputs=[llm_model_name, target_size, hf_write_token, repo_name],
            outputs=[pruning_status, visualization, logs_output],
        )

        # Simple text-generation test of the pruned model
        text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")
        generate_button = gr.Button("Generate Text")

        def generate_text(text, repo_name, hf_write_token):
            try:
                # Resolve the full repo id (username/repo_name) the pruned model was pushed to
                username = HfApi().whoami(token=hf_write_token)["name"]
                repo_id = f"{username}/{repo_name}"
                tokenizer = AutoTokenizer.from_pretrained(repo_id, token=hf_write_token)
                model = AutoModelForCausalLM.from_pretrained(repo_id, token=hf_write_token)
                generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
                generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]["generated_text"]
                return generated_text
            except Exception as e:
                return f"Error: {e}"

        generate_button.click(fn=generate_text, inputs=[text_input, repo_name, hf_write_token], outputs=text_output)

    return demo

# Create and launch the Gradio interface
demo = create_interface()
demo.launch(share=True)
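
# Example (hypothetical values): calling the pruning pipeline without the UI.
# "openai-community/gpt2" and "my-pruned-gpt2" are placeholder names, and "hf_xxx"
# stands in for a real Hugging Face write token.
# status, chart, logs = prune_model(
#     llm_model_name="openai-community/gpt2",
#     target_size=50,
#     hf_write_token="hf_xxx",
#     repo_name="my-pruned-gpt2",
# )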