import gradio as gr
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
from huggingface_hub import create_repo, HfApi, list_models
from transformers.modeling_utils import PreTrainedModel
import matplotlib.pyplot as plt
from io import BytesIO
import base64
import torch
from torch.nn.utils import prune
import subprocess
from tqdm import tqdm
import logging
import sys
# Set up logging to both stdout and a file so the "Show Logs" button can read it
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler('pruning.log')],
)

# Ensure sentencepiece is installed (required by some tokenizers)
try:
    import sentencepiece  # noqa: F401
except ImportError:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'sentencepiece'])
# Fetch open-weight LLM models from the Hugging Face Hub
def fetch_open_weight_models():
    return list_models()
# Prune a model to a target size (percentage of the original parameter count)
def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress=gr.Progress(track_tqdm=True)):
    log_messages = []
    try:
        # Load the LLM model and tokenizer
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            torch_dtype=torch.float16,
        )
        log_messages.append('Model and tokenizer loaded successfully.')
        logging.info('Model and tokenizer loaded successfully.')

        # Compute the target parameter count from the loaded model
        # (AutoConfig does not expose a num_parameters attribute)
        original_num_parameters = sum(p.numel() for p in llm_model.parameters())
        target_num_parameters = int(original_num_parameters * (target_size / 100))
        # Prune the model
        pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)
        log_messages.append('Model pruned successfully.')
        logging.info('Model pruned successfully.')

        # Save the pruned model to the Hub under the authenticated user's namespace
        api = HfApi()
        username = api.whoami(token=hf_write_token)['name']
        repo_id = f"{username}/{repo_name}"
        create_repo(repo_id, token=hf_write_token, private=False, exist_ok=True)
        pruned_model.push_to_hub(repo_id, token=hf_write_token)
        llm_tokenizer.push_to_hub(repo_id, token=hf_write_token)
        log_messages.append(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")
        logging.info(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")
        # Create a visualization comparing effective (nonzero) parameter counts;
        # unstructured pruning zeroes weights rather than removing them, so
        # p.numel() on the pruned model would equal the original count
        pruned_num_parameters = sum(int((p != 0).sum()) for p in pruned_model.parameters())
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.bar(['Original', 'Pruned'], [original_num_parameters, pruned_num_parameters])
        ax.set_ylabel('Number of Parameters')
        ax.set_title('Model Size Comparison')
        buf = BytesIO()
        fig.savefig(buf, format='png')
        plt.close(fig)
        buf.seek(0)
        image_base64 = base64.b64encode(buf.read()).decode('utf-8')

        return (
            f"Pruned model saved to Hugging Face Hub in repository {repo_id}",
            f"data:image/png;base64,{image_base64}",
            '\n'.join(log_messages),
        )
    except Exception as e:
        error_message = f"Error: {e}"
        log_messages.append(error_message)
        logging.error(error_message)
        return error_message, None, '\n'.join(log_messages)
# Pruning function (labelled "merge-kit" in the UI; implemented here with
# torch's random unstructured pruning, adjust as needed)
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress: gr.Progress) -> PreTrainedModel:
    """Prunes a model down to roughly a target number of effective parameters.

    Args:
        model (PreTrainedModel): The model to be pruned.
        target_num_parameters (int): The target number of nonzero parameters after pruning.
        progress (gr.Progress): Gradio progress tracker used to update the UI.

    Returns:
        PreTrainedModel: The pruned model (weights are zeroed, not removed).
    """
    total_params = sum(p.numel() for p in model.parameters())
    amount = 1 - (target_num_parameters / total_params)

    # Apply random unstructured pruning masks to all Linear and Conv2d layers
    for name, module in tqdm(model.named_modules(), desc='Pruning', file=sys.stdout):
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            prune.random_unstructured(module, name='weight', amount=amount)
    progress(0.5, desc='Pruning masks applied')

    # Make the pruning permanent by removing the re-parametrization
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            prune.remove(module, name='weight')
    progress(1.0, desc='Pruning finalized')

    return model
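
# Illustrative sanity check (not part of the original app flow; assumes a
# `pruned_model` returned by merge_kit_prune above): after pruning, the
# fraction of zeroed weights in each Linear layer should roughly match `amount`.
#
#   for name, module in pruned_model.named_modules():
#       if isinstance(module, torch.nn.Linear):
#           sparsity = (module.weight == 0).float().mean().item()
#           print(f"{name}: {sparsity:.1%} of weights zeroed")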
# Build the Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Create a Smaller LLM")
        llm_model_name = gr.Textbox(label="Choose a Large Language Model", placeholder="Enter the model name", interactive=True)
        target_size = gr.Slider(label="Target Model Size (%)", minimum=1, maximum=100, step=1, value=50, interactive=True)
        hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password")
        repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True)
        pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
        prune_button = gr.Button("Prune Model")
        visualization = gr.Image(label="Model Size Comparison", interactive=False)
        logs_button = gr.Button("Show Logs")
        logs_output = gr.Textbox(label="Logs", interactive=False)

        def show_logs():
            # Read the log file written by the logging FileHandler configured above
            try:
                with open('pruning.log', 'r') as log_file:
                    return log_file.read()
            except FileNotFoundError:
                return 'No logs yet.'

        logs_button.click(fn=show_logs, outputs=logs_output)

        # prune_model declares progress=gr.Progress(track_tqdm=True) as a default
        # argument, so Gradio injects the progress tracker automatically
        prune_button.click(
            fn=prune_model,
            inputs=[llm_model_name, target_size, hf_write_token, repo_name],
            outputs=[pruning_status, visualization, logs_output],
        )
        text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")
        generate_button = gr.Button("Generate Text")

        def generate_text(text, repo_name, hf_write_token):
            try:
                # Resolve the full repo id under the authenticated user's namespace,
                # matching where prune_model pushed the pruned model
                username = HfApi().whoami(token=hf_write_token)['name']
                repo_id = f"{username}/{repo_name}"
                tokenizer = AutoTokenizer.from_pretrained(repo_id, token=hf_write_token)
                model = AutoModelForCausalLM.from_pretrained(repo_id, token=hf_write_token)
                generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
                generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]['generated_text']
                return generated_text
            except Exception as e:
                return f"Error: {e}"

        generate_button.click(fn=generate_text, inputs=[text_input, repo_name, hf_write_token], outputs=text_output)

    return demo
# Create and launch the Gradio interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()