import gradio as gr
from transformers import AutoModel, AutoTokenizer, pipeline, AutoConfig, AutoModelForCausalLM
from huggingface_hub import create_repo, HfApi, list_models
from transformers.modeling_utils import PreTrainedModel
import requests
import json
import os
import matplotlib.pyplot as plt
from io import BytesIO
import base64
from PIL import Image
import torch
from torch.nn.utils import prune
import subprocess
from tqdm import tqdm
import logging
import sys

# Setup logging: write to pruning.log so the "Show Logs" button can read it back
logging.basicConfig(
    filename='pruning.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Ensure sentencepiece is installed (required by many tokenizers)
try:
    import sentencepiece
except ImportError:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'sentencepiece'])

# Function to fetch open-weight LLM models from the Hugging Face Hub
def fetch_open_weight_models():
    return list_models()

# Function to prune a model using the "merge-kit" approach
def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress=gr.Progress(track_tqdm=True)):
    log_messages = []
    try:
        # Load the LLM model and tokenizer
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            torch_dtype=torch.float16,
        )
        log_messages.append('Model and tokenizer loaded successfully.')
        logging.info('Model and tokenizer loaded successfully.')

        # Compute the target parameter count from the model's actual size
        original_num_parameters = sum(p.numel() for p in llm_model.parameters())
        target_num_parameters = int(original_num_parameters * (target_size / 100))

        # Prune the model
        pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)
        log_messages.append('Model pruned successfully.')
        logging.info('Model pruned successfully.')

        # Save the pruned model to the Hub under the authenticated user's namespace
        api = HfApi(token=hf_write_token)
        username = api.whoami()['name']
        repo_id = f"{username}/{repo_name}"
        create_repo(repo_id, token=hf_write_token, private=False, exist_ok=True)
        pruned_model.push_to_hub(repo_id, token=hf_write_token)
        llm_tokenizer.push_to_hub(repo_id, token=hf_write_token)
        log_messages.append(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")
        logging.info(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")

        # Create a visualization of the parameter-count reduction
        fig, ax = plt.subplots(figsize=(10, 5))
        pruned_num_parameters = sum(p.numel() for p in pruned_model.parameters())
        ax.bar(['Original', 'Pruned'], [original_num_parameters, pruned_num_parameters])
        ax.set_ylabel('Number of Parameters')
        ax.set_title('Model Size Comparison')
        buf = BytesIO()
        fig.savefig(buf, format='png')
        plt.close(fig)
        buf.seek(0)
        comparison_image = Image.open(buf).convert('RGB')

        return (
            f"Pruned model saved to Hugging Face Hub in repository {repo_id}",
            comparison_image,
            '\n'.join(log_messages),
        )

    except Exception as e:
        error_message = f"Error: {e}"
        log_messages.append(error_message)
        logging.error(error_message)
        return error_message, None, '\n'.join(log_messages)

# Merge-kit pruning function (adjust as needed); currently implemented as random unstructured pruning
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress: gr.Progress) -> PreTrainedModel:
    """Prunes a model toward a target parameter count by randomly zeroing weights in Linear and Conv2d layers.

    Args:
        model (PreTrainedModel): The model to be pruned.
        target_num_parameters (int): The target number of parameters after pruning.

    Returns:
        PreTrainedModel: The pruned model.
""" total_params = sum(p.numel() for p in model.parameters()) amount = 1 - (target_num_parameters / total_params) for name, module in tqdm(model.named_modules(), desc='Pruning', file=sys.stdout): if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)): prune.random_unstructured(module, name='weight', amount=amount) progress(percent_complete=50) for name, module in model.named_modules(): if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)): prune.remove(module, name='weight') progress(percent_complete=100) return model # Function to create a Gradio interface def create_interface(): with gr.Blocks() as demo: gr.Markdown("## Create a Smaller LLM") llm_model_name = gr.Textbox(label="Choose a Large Language Model", placeholder="Enter the model name", interactive=True) target_size = gr.Slider(label="Target Model Size (%)", minimum=1, maximum=100, step=1, value=50, interactive=True) hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password") repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True) pruning_status = gr.Textbox(label="Pruning Status", interactive=False) prune_button = gr.Button("Prune Model") visualization = gr.Image(label="Model Size Comparison", interactive=False) logs_button = gr.Button("Show Logs") logs_output = gr.Textbox(label="Logs", interactive=False) progress_bar = gr.Progress() def show_logs(): with open('pruning.log', 'r') as log_file: logs = log_file.read() return logs logs_button.click(fn=show_logs, outputs=logs_output) def prune_model_with_progress(llm_model_name, target_size, hf_write_token, repo_name): return prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress_bar) prune_button.click(fn=prune_model_with_progress, inputs=[llm_model_name, target_size, hf_write_token, repo_name], outputs=[pruning_status, visualization, logs_output]) text_input = gr.Textbox(label="Input Text") text_output = gr.Textbox(label="Generated Text") generate_button = gr.Button("Generate Text") def generate_text(text, repo_name, hf_write_token): try: tokenizer = AutoTokenizer.from_pretrained(repo_name, use_auth_token=hf_write_token) model = AutoModelForCausalLM.from_pretrained(repo_name, use_auth_token=hf_write_token) generator = pipeline('text-generation', model=model, tokenizer=tokenizer) generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]['generated_text'] return generated_text except Exception as e: return f"Error: {e}" generate_button.click(fn=generate_text, inputs=[text_input, repo_name, hf_write_token], outputs=text_output) return demo # Create and launch the Gradio interface demo = create_interface() demo.launch()