import gradio as gr
import torch
import matplotlib
matplotlib.use("Agg")  # render figures off-screen (no display in a server environment)
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
from huggingface_hub import list_models
from torch.nn.utils import prune
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from transformers.modeling_utils import PreTrainedModel

# Function to fetch open-weight LLM models from the Hugging Face Hub
def fetch_open_weight_models():
    # list_models returns ModelInfo objects; "open-weight" is used here as a Hub tag filter
    models = list_models(filter="open-weight", sort="downloads", limit=12)
    return [model.id for model in models]
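
# Note: "open-weight" above is assumed to be a Hub tag; if the query returns nothing, a
# task-based lookup is a reasonable fallback (a minimal sketch, not wired into the UI):
def fetch_seq2seq_models(limit: int = 12):
    """List popular text2text-generation checkpoints, matching the AutoModelForSeq2SeqLM loader used below."""
    return [model.id for model in list_models(task="text2text-generation", sort="downloads", limit=limit)]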

# Function to prune a model using the "merge-kit" approach
def prune_model(llm_model_name, target_size, output_dir):
    # Load the LLM model and tokenizer
    llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
    llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)

    # Record the original parameter count and derive the target from the slider percentage
    original_num_parameters = llm_model.num_parameters()
    target_num_parameters = int(original_num_parameters * (target_size / 100))

    # Use merge-kit to prune the model (modifies the model in place; see merge_kit_prune below)
    pruned_model = merge_kit_prune(llm_model, target_num_parameters)

    # Save the pruned model and its tokenizer so they can be reloaded from the same path
    pruned_model.save_pretrained(output_dir)
    llm_tokenizer.save_pretrained(output_dir)

    # Create a visualization comparing original vs. remaining (non-zero) parameters
    remaining_parameters = sum(int((p != 0).sum().item()) for p in pruned_model.parameters())
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(["Original", "Pruned (non-zero)"], [original_num_parameters, remaining_parameters])
    ax.set_ylabel("Number of Parameters")
    ax.set_title("Model Size Comparison")
    buf = BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    chart = Image.open(buf)
    return f"Pruned model saved to {output_dir}", chart

# Merge-kit Pruning Function
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
    """Prunes a model in place using a merge-kit style unstructured approach.

    Args:
        model (PreTrainedModel): The model to be pruned.
        target_num_parameters (int): The target number of parameters after pruning.

    Returns:
        PreTrainedModel: The pruned model.
    """
    # Fraction of weights to prune, clamped to the valid [0, 1] range
    amount = 1 - (target_num_parameters / model.num_parameters())
    amount = min(max(amount, 0.0), 1.0)

    # Apply random unstructured pruning to every Linear/Conv2d layer
    for _, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            prune.random_unstructured(module, name="weight", amount=amount)

    # Make the pruning permanent by removing the reparametrization (masks are baked in)
    for _, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            prune.remove(module, name="weight")

    return model
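
# Note: torch.nn.utils.prune masks weights to zero rather than removing them, so the
# parameter count and the size of the saved checkpoint stay roughly unchanged; the
# "pruned" model is sparser, not structurally smaller. A minimal sketch (assumes any
# pruned nn.Module) to verify the achieved sparsity:
def measure_sparsity(model: torch.nn.Module) -> float:
    """Return the fraction of exactly-zero weights across Linear/Conv2d layers."""
    total, zeros = 0, 0
    for module in model.modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            weight = module.weight.detach()
            total += weight.numel()
            zeros += int((weight == 0).sum().item())
    return zeros / total if total else 0.0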

# Function to create a Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Create a Smaller LLM")

        # Fetch open-weight models from Hugging Face
        available_models = gr.Dropdown(
            label="Choose a Large Language Model",
            choices=fetch_open_weight_models(),
            interactive=True,
        )

        # Input for target model size
        target_size = gr.Slider(
            label="Target Model Size (%)",
            minimum=1,
            maximum=100,
            step=1,
            value=50,
            interactive=True,
        )

        # Output for pruning status
        pruning_status = gr.Textbox(label="Pruning Status")

        # Output for saving the model
        save_model_path = gr.Textbox(label="Save Model Path", placeholder="Path to save the pruned model", interactive=True)

        # Button to start pruning
        prune_button = gr.Button("Prune Model")

        # Output for visualization
        visualization = gr.Image(label="Model Size Comparison")

        # Connect components
        prune_button.click(
            fn=prune_model,
            inputs=[available_models, target_size, save_model_path],
            outputs=[pruning_status, visualization],
        )

        # Example usage of the pruned model (optional)
        text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")

        # Generate text button
        generate_button = gr.Button("Generate Text")

        def generate_text(text, model_path):
            # Load the pruned model and tokenizer from the save path used above
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

            # Seq2seq models are served through the text2text-generation pipeline
            generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
            generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]["generated_text"]
            return generated_text

        generate_button.click(fn=generate_text, inputs=[text_input, save_model_path], outputs=text_output)

    return demo

# Create and launch the Gradio interface
demo = create_interface()
demo.launch(share=True)
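
# Note: when run locally (e.g. `python app.py`), share=True requests a temporary public
# gradio.live link in addition to the local server; on Hugging Face Spaces this flag is
# typically ignored.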