import os
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from repeng import ControlVector, ControlModel
import gradio as gr
# Hugging Face Hub authentication
from huggingface_hub import login
# Initialize model and tokenizer
mistral_path = "mistralai/Mistral-7B-Instruct-v0.3"
#mistral_path = r"E:/language_models/models/mistral"
access_token = os.getenv("mistralaccesstoken")
# Only authenticate when a token is configured: login(None) falls back to an
# interactive prompt, which blocks or fails in a headless deployment.
if access_token:
    login(access_token)
tokenizer = AutoTokenizer.from_pretrained(mistral_path)
# NOTE(review): generation below passes pad_token_id=eos_token_id explicitly,
# so this id-0 pad token only matters for padding done by the tokenizer itself.
tokenizer.pad_token_id = 0
model = AutoModelForCausalLM.from_pretrained(
    mistral_path,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    use_safetensors=True,
)
model = model.to("cuda:0" if torch.cuda.is_available() else "cpu")
# Wrap the model so control vectors can be injected into layers -5 .. -17.
model = ControlModel(model, list(range(-5, -18, -1)))
# Generation settings
default_generation_settings = dict(
    # Explicit pad id so generate() does not warn about a missing pad token.
    pad_token_id=tokenizer.eos_token_id,
    # Greedy decoding: the same prompt and controls give the same reply.
    do_sample=False,
    max_new_tokens=384,
    # Mild penalty to discourage the model from looping on phrases.
    repetition_penalty=1.1,
)
# Tags for prompt formatting
user_tag, asst_tag = "[INST]", "[/INST]"
# List available control vectors.
# os.listdir returns entries in arbitrary order. Sorting keeps the UI rows, and
# the checkbox/slider argument order derived from this list, stable between runs.
control_vector_files = sorted(f for f in os.listdir('.') if f.endswith('.gguf'))
if not control_vector_files:
    raise FileNotFoundError("No .gguf control vector files found in the current directory.")
# Function to toggle slider visibility based on checkbox state
def toggle_slider(checked):
    """Return a Gradio update that shows the weight slider iff *checked* is truthy.

    Wired to each control-vector checkbox's ``change`` event so the matching
    slider only appears once that vector is selected.
    """
    visibility_update = gr.update(visible=checked)
    return visibility_update
# Function to generate the model's response
# Matches the "*<title>*\n\n" header that generate_response prepends to each
# displayed assistant reply (the title may be empty, giving "**\n\n").
_DISPLAY_TITLE_RE = re.compile(r"^\*[^*\n]*\*\n\n")


def _strip_display_title(message):
    """Return *message* without its leading ``*<title>*`` display header.

    Replies are stored in the chat history with that header so the user can
    see which control vectors produced them. It is removed before the reply
    is sent back to the model as conversation context.
    """
    return _DISPLAY_TITLE_RE.sub("", message or "", count=1)


def _combine_control_vectors(checkboxes, sliders):
    """Sum every ticked control vector, each scaled by its slider weight.

    Returns ``(vector, title)``. ``vector`` is the combined ControlVector, or
    None when nothing was selected or loadable. ``title`` lists the applied
    vectors as ``"<file>: <weight>;"`` entries.
    """
    combined = None
    title = ""
    for cv_file, checked, weight in zip(control_vector_files, checkboxes, sliders):
        if not checked:
            continue
        try:
            scaled = ControlVector.import_gguf(cv_file) * float(weight)
        except Exception as e:  # best effort: skip an unreadable file, keep chatting
            print(f"Failed to set control vector {cv_file}: {e}")
            continue
        combined = scaled if combined is None else combined + scaled
        title += f"{cv_file}: {weight};"
    return combined, title


def _format_prompt(system_prompt, history, user_message):
    """Render the conversation in Mistral's ``[INST] ... [/INST]`` format.

    Each completed turn becomes ``[INST] user [/INST] reply</s>``. The new
    message is left open (``[INST] user [/INST]``) for the model to answer.
    No BOS is written here; this presumably relies on the tokenizer adding it
    (``add_special_tokens`` defaults to True). Confirm this for the tokenizer
    in use.
    """
    parts = []
    if system_prompt and system_prompt.strip():
        # The system prompt is sent as its own instruction turn.
        parts.append(f"{user_tag} {system_prompt}{asst_tag} ")
    for user_msg, asst_msg in history:
        reply = _strip_display_title(asst_msg)
        parts.append(f"{user_tag} {user_msg} {asst_tag} {reply}{tokenizer.eos_token}")
    parts.append(f"{user_tag} {user_message} {asst_tag}")
    return "".join(parts)


def generate_response(system_prompt, user_message, history, max_new_tokens, repitition_penalty, *args):
    """Generate the assistant's reply to *user_message* and append it to the chat.

    Args:
        system_prompt: Optional instruction text; ignored when blank.
        user_message: The new user message.
        history: Chatbot value as a list of ``(user, assistant)`` pairs, or None.
        max_new_tokens: Generation length limit. Falls back to the default if empty.
        repitition_penalty: Repetition penalty from the UI. Falls back to the
            default if empty.
        *args: The control-vector widget values. First comes one checkbox value
            per entry of ``control_vector_files``, then one slider weight per
            entry, in the same order.

    Returns:
        The updated history list. The new reply is prefixed with an italic
        header naming the applied control vectors and weights.
    """
    history = list(history or [])
    n_vectors = len(control_vector_files)
    if len(args) < 2 * n_vectors:
        # The widget values do not line up with the control vector list;
        # leave the conversation unchanged.
        return history
    checkboxes = args[:n_vectors]
    sliders = args[n_vectors:2 * n_vectors]

    # Start from a clean model, then apply all selected vectors as ONE
    # combined vector. repeng's ControlModel.set_control replaces the current
    # control, so calling it once per vector would keep only the last one.
    model.reset()
    control_vector, assistant_message_title = _combine_control_vectors(checkboxes, sliders)
    if control_vector is not None:
        try:
            model.set_control(control_vector)
        except Exception as e:
            print(f"Failed to set control vectors: {e}")
            assistant_message_title = ""

    formatted_prompt = _format_prompt(system_prompt, history, user_message)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    generation_settings = {
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": default_generation_settings["do_sample"],
        # A cleared gr.Number arrives as None; fall back to the defaults.
        "max_new_tokens": int(max_new_tokens or default_generation_settings["max_new_tokens"]),
        # Use the value passed in from the UI, not the component's initial value.
        "repetition_penalty": float(
            repitition_penalty or default_generation_settings["repetition_penalty"]
        ),
    }
    output_ids = model.generate(**inputs, **generation_settings)
    # For a decoder-only model, generate() returns prompt + continuation, so
    # decode only the newly generated tokens.
    prompt_length = inputs["input_ids"].shape[-1]
    assistant_response = tokenizer.decode(
        output_ids[0, prompt_length:], skip_special_tokens=True
    ).strip()

    history.append((user_message, f"*{assistant_message_title}*\n\n{assistant_response}"))
    return history
def generate_response_with_retry(system_prompt, user_message, history, max_new_tokens, repitition_penalty, *args):
    """Regenerate the most recent turn with the current settings.

    Drops the last ``(user, assistant)`` pair from *history* and calls
    ``generate_response`` again. If the message textbox is blank, the dropped
    turn's user message is reused, so retrying never sends an empty message.
    Arguments and return value are the same as ``generate_response``.
    """
    history = list(history or [])
    if history:
        last_user_message = history[-1][0]
        history = history[:-1]
        if not (user_message or "").strip():
            user_message = last_user_message
    return generate_response(system_prompt, user_message, history, max_new_tokens, repitition_penalty, *args)
# Function to reset the conversation history
def reset_chat():
    """Start a new conversation.

    Returns:
        ``(history, message_text)``: an empty chat history for the Chatbot
        and an empty string for the message Textbox. Returning ``[]`` for the
        Textbox would make it display the literal text "[]".
    """
    return [], ""
# Build the Gradio interface
# Two-column layout: settings and control-vector widgets on the left, the chat
# on the right. All event wiring is registered inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 LLM Brain Control")
    # NOTE(review): "(link)" looks like a placeholder for a demo URL.
    gr.Markdown("Usage demo: (link)")
    with gr.Row():
        # Left Column: Settings and Control Vectors
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")
            # System Prompt Input (first entry of inputs_list below)
            system_prompt = gr.Textbox(
                label="System Prompt",
                lines=2,
                value="Respond to the user concisely"
            )
            gr.Markdown("### ⚡ Control Vectors")
            gr.Markdown("Select how you want to control the LLM. Start with +/- 1.0. Stronger values may overload it.")
            # Create checkboxes and sliders for each control vector. Both lists
            # follow control_vector_files order; the chat handlers receive all
            # checkbox values first, then all slider values, in that order.
            control_checks = []
            control_sliders = []
            for cv_file in control_vector_files:
                with gr.Row():
                    # Checkbox to select the control vector
                    checkbox = gr.Checkbox(label=cv_file, value=False)
                    control_checks.append(checkbox)
                    # Slider to adjust the control vector's weight; starts hidden
                    # and is shown by toggle_slider when the checkbox is ticked.
                    slider = gr.Slider(
                        minimum=-2.5,
                        maximum=2.5,
                        value=0.0,
                        step=0.1,
                        label=f"{cv_file} Weight",
                        visible=False
                    )
                    control_sliders.append(slider)
                # Link the checkbox to toggle slider visibility. The components are
                # passed as arguments, so each pair is bound when registered.
                checkbox.change(
                    toggle_slider,
                    inputs=checkbox,
                    outputs=slider
                )
            # Advanced Settings Section (collapsed by default)
            with gr.Accordion("🔧 Advanced Settings", open=False):
                with gr.Row():
                    # Initial value taken from default_generation_settings.
                    max_new_tokens = gr.Number(
                        label="Max Response Length (in tokens)",
                        value=default_generation_settings["max_new_tokens"],
                        precision=0,
                        step=10,
                    )
                    # NOTE: this module-level name is the gr.Number component, not
                    # a float; handlers get its current value as a positional arg.
                    repetition_penalty = gr.Number(
                        label="Repetition Penalty",
                        value=default_generation_settings["repetition_penalty"],
                        precision=2,
                        step=0.1
                    )
        # Right Column: Chat Interface
        with gr.Column(scale=2):
            gr.Markdown("### 🗨️ Conversation")
            # Chatbot to display conversation; its value is a list of
            # (user, assistant) message pairs ('tuples' format).
            chatbot = gr.Chatbot(label="Conversation", type='tuples')
            # User Message Input
            user_input = gr.Textbox(
                label="Your Message (Shift+Enter submits)",
                lines=2,
                placeholder="I was out partying too late last night, and I'm going to be late for work. What should I tell my boss?"
            )
            with gr.Row():
                # Submit, Retry and New Chat buttons
                submit_button = gr.Button("💬 Submit")
                retry_button = gr.Button("🔃 Retry last turn")
                new_chat_button = gr.Button("🌟 New Chat")
    # Positional order must match the handlers' signature:
    # (system_prompt, user_message, history, max_new_tokens, repetition penalty,
    #  *checkbox values, *slider values).
    inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty] + control_checks + control_sliders
    # Define button actions; each handler's return value replaces the chat history.
    submit_button.click(
        generate_response,
        inputs=inputs_list,
        outputs=[chatbot]
    )
    user_input.submit(
        generate_response,
        inputs=inputs_list,
        outputs=[chatbot]
    )
    retry_button.click(
        generate_response_with_retry,
        inputs=inputs_list,
        outputs=[chatbot]
    )
    # reset_chat's return values map, in order, to the chat history and the
    # message textbox.
    new_chat_button.click(
        reset_chat,
        inputs=[],
        outputs=[chatbot, user_input]
    )
# Launch the Gradio app
if __name__ == "__main__":
    # Serve the Gradio app only when run as a script, not when imported.
    demo.launch()