import gradio as gr
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
from huggingface_hub import create_repo, HfApi, list_models
from transformers.modeling_utils import PreTrainedModel
import matplotlib.pyplot as plt
from io import BytesIO
import base64
import torch
from torch.nn.utils import prune
import subprocess
from tqdm import tqdm
import logging
import sys

# Setup logging (also written to pruning.log so the "Show Logs" button can read it)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler('pruning.log'), logging.StreamHandler(sys.stdout)],
)

# Ensure sentencepiece is installed (required by some tokenizers)
try:
    import sentencepiece  # noqa: F401
except ImportError:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'sentencepiece'])

# Function to fetch open-weight LLM models from the Hub (currently not wired into the UI)
def fetch_open_weight_models(limit=100):
    # list_models() is a lazy iterator over the entire Hub, so cap it with a limit
    models = list(list_models(limit=limit))
    return models
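# Example (not called by the app): a minimal sketch of turning the listing into plain
# model ids. Assumes the ModelInfo attribute is `id`, which may vary across
# huggingface_hub versions.
def _example_list_model_ids(limit=5):
    return [m.id for m in fetch_open_weight_models(limit=limit)]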

# Function to prune a model using the "merge-kit" approach
def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress=gr.Progress(track_tqdm=True)):
    log_messages = []
    try:
        # Load the LLM model and tokenizer
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            torch_dtype=torch.float16,
        )

        log_messages.append('Model and tokenizer loaded successfully.')
        logging.info('Model and tokenizer loaded successfully.')

        # Derive the pruning target from the model's actual parameter count
        # (AutoConfig does not expose a total parameter count)
        original_num_parameters = sum(p.numel() for p in llm_model.parameters())
        target_num_parameters = int(original_num_parameters * (target_size / 100))

        # Prune the model
        pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)

        log_messages.append('Model pruned successfully.')
        logging.info('Model pruned successfully.')

        # Save the pruned model to the Hub under the token owner's namespace
        api = HfApi()
        username = api.whoami(token=hf_write_token)['name']
        repo_id = f"{username}/{repo_name}"
        create_repo(repo_id, token=hf_write_token, private=False, exist_ok=True)
        pruned_model.push_to_hub(repo_id, token=hf_write_token)
        llm_tokenizer.push_to_hub(repo_id, token=hf_write_token)

        log_messages.append(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")
        logging.info(f"Pruned model saved to Hugging Face Hub in repository {repo_id}")

        # Create a bar chart comparing parameter counts before and after pruning.
        # Unstructured pruning zeroes weights rather than deleting them, so the
        # "Pruned" bar counts the remaining non-zero parameters.
        pruned_num_parameters = int(sum((p != 0).sum().item() for p in pruned_model.parameters()))
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.bar(['Original', 'Pruned'], [original_num_parameters, pruned_num_parameters])
        ax.set_ylabel('Number of Parameters')
        ax.set_title('Model Size Comparison')
        buf = BytesIO()
        fig.savefig(buf, format='png')
        plt.close(fig)
        buf.seek(0)
        image_base64 = base64.b64encode(buf.read()).decode('utf-8')

        return f"Pruned model saved to Hugging Face Hub in repository {repo_id}", f"data:image/png;base64,{image_base64}", '\n'.join(log_messages)

    except Exception as e:
        error_message = f"Error: {e}"
        log_messages.append(error_message)
        logging.error(error_message)
        return error_message, None, '\n'.join(log_messages)
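
# Example (never invoked by the app): a rough sketch of driving prune_model from a
# plain script. The model id, token placeholder, and repo name are all hypothetical.
def _example_prune_model():
    status, chart, logs = prune_model(
        llm_model_name='sshleifer/tiny-gpt2',  # hypothetical small model for a dry run
        target_size=50,                        # keep roughly 50% of the weights
        hf_write_token='hf_xxx',               # placeholder; a real write token is required
        repo_name='tiny-gpt2-pruned-50',       # hypothetical destination repo name
    )
    return status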

# Pruning function (named "merge-kit" here, but it applies random unstructured
# pruning via torch.nn.utils.prune rather than an actual mergekit merge)
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress: gr.Progress) -> PreTrainedModel:
    """Prunes a model down to approximately the target parameter count.

    Args:
        model (PreTrainedModel): The model to be pruned.
        target_num_parameters (int): The target number of parameters after pruning.

    Returns:
        PreTrainedModel: The pruned model (weights are zeroed in place, not removed).
    """
    total_params = sum(p.numel() for p in model.parameters())
    # Fraction of weights to zero out; clamp to [0, 1] in case the target exceeds the total
    amount = min(max(1 - (target_num_parameters / total_params), 0.0), 1.0)

    # First pass: attach random unstructured pruning masks to linear and conv layers
    prunable = [m for _, m in model.named_modules() if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d))]
    for module in tqdm(prunable, desc='Pruning', file=sys.stdout):
        prune.random_unstructured(module, name='weight', amount=amount)
    progress(0.5, desc='Pruning masks applied')

    # Second pass: make the pruning permanent by removing the reparameterization
    for module in prunable:
        prune.remove(module, name='weight')
    progress(1.0, desc='Pruning complete')
    
    return model
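
# Example (not called anywhere in the app): a minimal sketch of invoking
# merge_kit_prune directly on a small model. The model id and the 50% target are
# assumptions for illustration, not part of the app's workflow.
def _example_merge_kit_prune():
    small_model = AutoModelForCausalLM.from_pretrained('sshleifer/tiny-gpt2')  # hypothetical test model
    target = sum(p.numel() for p in small_model.parameters()) // 2  # keep roughly half the weights
    return merge_kit_prune(small_model, target, gr.Progress())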

# Function to create a Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Create a Smaller LLM")

        llm_model_name = gr.Textbox(label="Choose a Large Language Model", placeholder="Enter the model name", interactive=True)
        target_size = gr.Slider(label="Target Model Size (%)", minimum=1, maximum=100, step=1, value=50, interactive=True)
        hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password")
        repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True)
        pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
        prune_button = gr.Button("Prune Model")
        visualization = gr.Image(label="Model Size Comparison", interactive=False)
        logs_button = gr.Button("Show Logs")
        logs_output = gr.Textbox(label="Logs", interactive=False)

        def show_logs():
            with open('pruning.log', 'r') as log_file:
                logs = log_file.read()
            return logs

        logs_button.click(fn=show_logs, outputs=logs_output)

        # Bind the button directly to prune_model; Gradio injects a progress tracker
        # through the function's default `progress` argument.
        prune_button.click(
            fn=prune_model,
            inputs=[llm_model_name, target_size, hf_write_token, repo_name],
            outputs=[pruning_status, visualization, logs_output],
        )

        text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")
        generate_button = gr.Button("Generate Text")

        def generate_text(text, repo_name, hf_write_token):
            try:
                # Resolve the same namespaced repo id that prune_model pushed to
                username = HfApi().whoami(token=hf_write_token)['name']
                repo_id = f"{username}/{repo_name}"
                tokenizer = AutoTokenizer.from_pretrained(repo_id, token=hf_write_token)
                model = AutoModelForCausalLM.from_pretrained(repo_id, token=hf_write_token)
                generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
                generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]['generated_text']
                return generated_text
            except Exception as e:
                return f"Error: {e}"

        generate_button.click(fn=generate_text, inputs=[text_input, repo_name, hf_write_token], outputs=text_output)

    return demo

# Create and launch the Gradio interface
demo = create_interface()
demo.launch()