import os
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from repeng import ControlVector, ControlModel
import gradio as gr
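
# Assumed dependencies (not pinned by this script):
#   pip install torch transformers huggingface_hub gradio repeng
# repeng (https://github.com/vgel/repeng) provides ControlVector / ControlModel.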

# Log in to Hugging Face (Mistral-7B-Instruct is a gated model)
from huggingface_hub import login

# Initialize model and tokenizer
mistral_path = "mistralai/Mistral-7B-Instruct-v0.3"
#mistral_path = r"E:/language_models/models/mistral"

# Token is read from the environment; skip login if it is not set
# (e.g. when credentials are already cached locally)
access_token = os.getenv("mistralaccesstoken")
if access_token:
    login(access_token)

tokenizer = AutoTokenizer.from_pretrained(mistral_path)
tokenizer.pad_token_id = 0

model = AutoModelForCausalLM.from_pretrained(
    mistral_path,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    use_safetensors=True
)
model = model.to("cuda:0" if torch.cuda.is_available() else "cpu")
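# Wrap the base model so control vectors can be injected into hidden states at
# layers -5 through -17 (middle-to-late blocks); this range follows repeng's
# examples and is a tuning choice, not a requirement.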
model = ControlModel(model, list(range(-5, -18, -1)))

# Generation settings
default_generation_settings = {
    "pad_token_id": tokenizer.eos_token_id,  # Silence warning
    "do_sample": False,                      # Deterministic output
    "max_new_tokens": 384,
    "repetition_penalty": 1.1,              # Reduce repetition
}

# Tags for prompt formatting
user_tag, asst_tag = "[INST]", "[/INST]"

# List available control vectors
control_vector_files = [f for f in os.listdir('.') if f.endswith('.gguf')]

if not control_vector_files:
    raise FileNotFoundError("No .gguf control vector files found in the current directory.")
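
# These .gguf files are control vectors exported with repeng
# (e.g. vector.export_gguf("happy.gguf") after training); filenames double as UI labels.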

# Function to toggle slider visibility based on checkbox state
def toggle_slider(checked):
    return gr.update(visible=checked)

# Function to generate the model's response
def generate_response(system_prompt, user_message, history, max_new_tokens, repetition_penalty, *args):
    checkboxes = []
    sliders = []

    # Split *args: the first len(control_vector_files) values are the checkbox
    # states; the remaining values are the corresponding slider weights
    for i in range(len(control_vector_files)):
        checkboxes.append(args[i])
        sliders.append(args[len(control_vector_files) + i])

    if len(checkboxes) != len(control_vector_files) or len(sliders) != len(control_vector_files):
        return history if history else []

    # Reset any previous control vectors
    model.reset()

    # Apply selected control vectors with their corresponding weights
    assistant_message_title = ""
    for i in range(len(control_vector_files)):
        if checkboxes[i]:
            cv_file = control_vector_files[i]
            weight = sliders[i]
            try:
                control_vector = ControlVector.import_gguf(cv_file)
                model.set_control(control_vector, weight)
                assistant_message_title += f"{cv_file}: {weight};"
            except Exception as e:
                print(f"Failed to set control vector {cv_file}: {e}")

    formatted_prompt = ""

    # <s>[INST] user message[/INST] assistant message</s>[INST] new user message[/INST]
    # Mistral expects the history to be wrapped in <s>history</s>
    if len(history) > 0:
        formatted_prompt += "<s>"

    # Append the system prompt if provided
    if system_prompt.strip():
        formatted_prompt += f"{user_tag} {system_prompt}{asst_tag} "

    # Construct the formatted prompt based on history
    if len(history) > 0:
        for turn in history:
            user_msg, asst_msg = turn
            formatted_prompt += f"{user_tag} {user_msg} {asst_tag} {asst_msg}"

    
    if len(history) > 0:
        formatted_prompt += "</s>"

    # Append the new user message
    formatted_prompt += f"{user_tag} {user_message} {asst_tag}"

    # Tokenize the input
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    generation_settings = {
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": default_generation_settings["do_sample"],
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": float(repetition_penalty),  # Gradio passes the Number's value, not the component
    }

    # Generate the response
    output_ids = model.generate(**input_ids, **generation_settings)
    response = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=False)

    def get_assistant_response(input_string):
        # Use regex to find the text between the final [/INST] tag and </s>
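        # (?!.*\[/INST\]) is a negative lookahead: match only the last [/INST]
        # in the string; re.DOTALL lets '.' span newlines in multi-line replies.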
        pattern = r'\[/INST\](?!.*\[/INST\])\s*(.*?)(?:</s>|$)'
        match = re.search(pattern, input_string, re.DOTALL)
        if match:
            return match.group(1).strip()
        return None
    
    assistant_response = get_assistant_response(response)

    # Prefix the displayed reply with the active control vectors and weights
    assistant_response_display = f"*{assistant_message_title}*\n\n{assistant_response}"

    # Update conversation history
    history.append((user_message, assistant_response_display))
    return history

def generate_response_with_retry(system_prompt, user_message, history, max_new_tokens, repetition_penalty, *args):
    # Regenerate the last turn: drop it from history and resend its user message
    if history:
        user_message = history[-1][0]
        history = history[:-1]
    return generate_response(system_prompt, user_message, history, max_new_tokens, repetition_penalty, *args)

# Function to reset the conversation history
def reset_chat():
    # Return a blank conversation history and clear the user input box
    return [], ""

# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 LLM Brain Control")
    gr.Markdown("Usage demo: (link)")

    with gr.Row():
        # Left Column: Settings and Control Vectors
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")

            # System Prompt Input
            system_prompt = gr.Textbox(
                label="System Prompt",
                lines=2,
                value="Respond to the user concisely"
            )

            gr.Markdown("### ⚡ Control Vectors")
            gr.Markdown("Select how you want to control the LLM. Start with +/- 1.0. Stronger values may overload it.")

            # Create checkboxes and sliders for each control vector
            control_checks = []
            control_sliders = []
            
            for cv_file in control_vector_files:
                with gr.Row():
                    # Checkbox to select the control vector
                    checkbox = gr.Checkbox(label=cv_file, value=False)
                    control_checks.append(checkbox)

                    # Slider to adjust the control vector's weight
                    slider = gr.Slider(
                        minimum=-2.5,
                        maximum=2.5,
                        value=0.0,
                        step=0.1,
                        label=f"{cv_file} Weight",
                        visible=False
                    )
                    control_sliders.append(slider)

                    # Link the checkbox to toggle slider visibility
                    checkbox.change(
                        toggle_slider,
                        inputs=checkbox,
                        outputs=slider
                    )

            # Advanced Settings Section (collapsed by default)
            with gr.Accordion("🔧 Advanced Settings", open=False):
                with gr.Row():
                    max_new_tokens = gr.Number(
                        label="Max Response Length (in tokens)",
                        value=default_generation_settings["max_new_tokens"],
                        precision=0,
                        step=10,
                    )
                    repetition_penalty = gr.Number(
                        label="Repetition Penalty",
                        value=default_generation_settings["repetition_penalty"],
                        precision=2,
                        step=0.1
                    )

        # Right Column: Chat Interface
        with gr.Column(scale=2):
            gr.Markdown("### 🗨️ Conversation")

            # Chatbot to display conversation
            chatbot = gr.Chatbot(label="Conversation", type='tuples')

            # User Message Input
            user_input = gr.Textbox(
                label="Your Message (Shift+Enter submits)",
                lines=2,
                placeholder="I was out partying too late last night, and I'm going to be late for work. What should I tell my boss?"
            )

            with gr.Row():
                # Submit and New Chat buttons
                submit_button = gr.Button("💬 Submit")
                retry_button = gr.Button("🔃 Retry last turn")
                new_chat_button = gr.Button("🌟 New Chat")
            

    inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty] + control_checks + control_sliders
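    # Ordering matters: generate_response unpacks *args positionally as
    # [all checkbox states..., then all slider weights...]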

    # Define button actions
    submit_button.click(
        generate_response,
        inputs=inputs_list,
        outputs=[chatbot]
    )

    user_input.submit(
        generate_response,
        inputs=inputs_list,
        outputs=[chatbot]
    )

    retry_button.click(
        generate_response_with_retry,
        inputs=inputs_list,
        outputs=[chatbot]
    )
    
    new_chat_button.click(
        reset_chat,
        inputs=[],
        outputs=[chatbot, user_input]
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()