import os
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from repeng import ControlVector, ControlModel
import gradio as gr
# Hugging Face Hub authentication helper
from huggingface_hub import login
# Initialize model and tokenizer
mistral_path = "mistralai/Mistral-7B-Instruct-v0.3"
#mistral_path = r"E:/language_models/models/mistral"
access_token = os.getenv("mistralaccesstoken")
# Only authenticate when a token is actually configured; calling login() with
# None would not use the environment and is unsuitable for an unattended app.
if access_token:
    login(access_token)
tokenizer = AutoTokenizer.from_pretrained(mistral_path)
tokenizer.pad_token_id = 0
model = AutoModelForCausalLM.from_pretrained(
    mistral_path,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    use_safetensors=True,
)
model = model.to("cuda:0" if torch.cuda.is_available() else "cpu")
# Wrap the model for control-vector steering; layer ids -5 through -17 are
# handed to ControlModel.
model = ControlModel(model, list(range(-5, -18, -1)))

# Generation settings
default_generation_settings = {
    #"pad_token_id": tokenizer.eos_token_id, # Silence warning
    "do_sample": False,  # Deterministic output
    "max_new_tokens": 384,
    "repetition_penalty": 1.1,  # Reduce repetition
}

# Tags for prompt formatting
user_tag, asst_tag = "[INST]", "[/INST]"

# List available control vectors.
# os.listdir() returns entries in arbitrary order, but the persona presets
# address vectors by fixed position, so sort to get a stable, reproducible
# ordering on every platform.
control_vector_files = sorted(f for f in os.listdir('.') if f.endswith('.gguf'))
if not control_vector_files:
    raise FileNotFoundError("No .gguf control vector files found in the current directory.")
def toggle_slider(checked):
    """Build a Gradio update whose visibility mirrors the checkbox state.

    Args:
        checked: Current value of the checkbox; passed straight through as
            the ``visible`` flag.

    Returns:
        The result of ``gr.update(visible=checked)``.
    """
    slider_update = gr.update(visible=checked)
    return slider_update
def _load_selected_control_vectors(checkboxes, sliders):
    """Load every ticked control vector, scaled by its slider weight.

    Returns a ``(vectors, title)`` pair: the list of weighted control vectors
    and a ``"name: weight;"`` summary string for the ones that loaded.
    """
    vectors = []
    title_parts = []
    for cv_file, enabled, weight in zip(control_vector_files, checkboxes, sliders):
        if not enabled:
            continue
        try:
            # Set the control vector's weight (and sign) by multiplying by its slider value
            vectors.append(ControlVector.import_gguf(cv_file) * weight)
        except Exception as e:
            # Best effort: a vector that fails to load is reported and skipped.
            print(f"Failed to set control vector {cv_file}: {e}")
            continue
        title_parts.append(f"{cv_file.split('.')[0]}: {weight};")
    return vectors, "".join(title_parts)


def _build_prompt(system_prompt, history, user_message):
    """Format the system prompt, prior turns and new message with the INST tags."""
    # [INST] user message[/INST] assistant message[INST] new user message[/INST]
    # NOTE(review): the original code appended empty strings around the
    # history (apparently placeholders for wrapping tokens); they were no-ops
    # and are omitted here. Confirm whether explicit tokens were intended.
    parts = []
    if system_prompt.strip():
        parts.append(f"{user_tag} {system_prompt}{asst_tag} ")
    # TODO move back to ChatMessage type instead of Tuple, because the message title gets into the history
    for user_msg, asst_msg in history:
        parts.append(f"{user_tag} {user_msg} {asst_tag} {asst_msg}")
    parts.append(f"{user_tag} {user_message} {asst_tag}")
    return "".join(parts)


def _extract_assistant_response(decoded_text):
    """Return the text after the final [/INST] tag, up to ``</s>`` or the end.

    Returns an empty string when no [/INST] tag is present.
    """
    # The terminator must be a real alternative: with an empty alternative
    # the lazy group always matched the empty string.
    pattern = r'\[/INST\](?!.*\[/INST\])\s*(.*?)(?:</s>|$)'
    match = re.search(pattern, decoded_text, re.DOTALL)
    return match.group(1).strip() if match else ""


def generate_response(system_prompt, user_message, history, max_new_tokens, repitition_penalty, do_sample, *args):
    """Generate the model's reply with the selected control vectors applied.

    Args:
        system_prompt: Optional system prompt; ignored when blank.
        user_message: The new user message to answer.
        history: List of ``(user_message, assistant_message)`` tuples. It is
            appended to in place and also returned.
        max_new_tokens: Maximum number of tokens to generate (cast to int).
        repitition_penalty: Repetition penalty passed to ``model.generate``.
            The misspelled name is kept so existing callers keep working.
        do_sample: Whether to sample instead of decoding deterministically.
        *args: One checkbox value per entry of ``control_vector_files``,
            followed by one slider value per entry, in the same order.

    Returns:
        The updated history, with the new turn appended.
    """
    # The first x in args are the checkbox values, the second x the slider values.
    vector_count = len(control_vector_files)
    checkboxes = args[:vector_count]
    sliders = args[vector_count:2 * vector_count]
    control_vectors, assistant_message_title = _load_selected_control_vectors(checkboxes, sliders)

    # The control model takes a sum of positive and negative control vectors.
    # Always reset first so vectors from a previous call do not linger.
    model.reset()
    if control_vectors:
        combined_vector = control_vectors[0]
        for extra_vector in control_vectors[1:]:
            combined_vector += extra_vector
        model.set_control(combined_vector)

    formatted_prompt = _build_prompt(system_prompt, history, user_message)

    # Tokenize the input
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    generation_settings = {
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": do_sample,
        "max_new_tokens": int(max_new_tokens),
        # Use the function's own (misspelled) parameter; the previous code
        # referenced an undefined name here.
        "repetition_penalty": float(repitition_penalty),
    }

    # Generate the response; special tokens are kept so the end-of-sequence
    # marker can delimit the assistant text.
    output_ids = model.generate(**input_ids, **generation_settings)
    response = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=False)
    assistant_response = _extract_assistant_response(response)

    # Prefix the reply with the active control vectors, if any; an empty
    # title would otherwise render as a stray "**".
    if assistant_message_title:
        assistant_response_display = f"*{assistant_message_title}*\n\n{assistant_response}"
    else:
        assistant_response_display = assistant_response

    # Update conversation history
    history.append((user_message, assistant_response_display))
    return history
def generate_response_with_retry(system_prompt, user_message, history, max_new_tokens, repitition_penalty, do_sample, *args):
    """Regenerate the latest reply.

    Drops the last (user, assistant) turn from a non-empty history and then
    delegates to generate_response() with all other arguments unchanged.
    """
    trimmed_history = history[:-1] if history else history
    return generate_response(
        system_prompt,
        user_message,
        trimmed_history,
        max_new_tokens,
        repitition_penalty,
        do_sample,
        *args,
    )
def reset_chat():
    """Reset the conversation by returning a pair of fresh, empty lists."""
    first_blank, second_blank = [], []
    return first_blank, second_blank
# Persona presets. Each preset is a mapping of checkbox/slider index -> slider
# weight; every index not listed is unchecked with a weight of 0.0.
def _apply_preset(weights_by_index, args):
    """Build the checkbox and slider values for one persona preset.

    Args:
        weights_by_index: Mapping of control index to the slider weight that
            should be enabled for this persona.
        args: The current checkbox values followed by the current slider
            values; only its length is used, to determine how many controls
            exist.

    Returns:
        A list holding the new checkbox values followed by the new slider
        values. Preset indices beyond the number of controls are ignored.
    """
    count_checkboxes = len(args) // 2
    new_checkbox_values = [i in weights_by_index for i in range(count_checkboxes)]
    new_slider_values = [weights_by_index.get(i, 0.0) for i in range(count_checkboxes)]
    return new_checkbox_values + new_slider_values


def set_preset_helpful(*args):
    """Return checkbox + slider values for the "helpful" persona."""
    return _apply_preset({4: 1.0, 7: 1.0}, args)


def set_preset_conspiracist(*args):
    """Return checkbox + slider values for the "conspiracist" persona."""
    return _apply_preset({2: 1.5, 3: 1.0, 6: -0.5, 10: -1.0}, args)


def set_preset_stoner(*args):
    """Return checkbox + slider values for the "stoner" persona."""
    return _apply_preset({0: 0.5, 8: -0.5, 9: 0.6}, args)


def set_preset_facts(*args):
    """Return checkbox + slider values for the "facts" persona."""
    return _apply_preset({1: 0.5, 5: -0.5, 6: -0.5, 10: 0.5}, args)
# Custom CSS for hover tooltips: a `.tooltip` container reveals its nested
# `.tooltiptext` bubble (with a small arrow) on hover.
# The final rule is now closed with "}" so the stylesheet is well-formed.
tooltip_css = """
/* Tooltip container */
.tooltip {
position: relative;
display: inline-block;
cursor: help;
}
/* Tooltip text */
.tooltip .tooltiptext {
visibility: hidden;
width: 200px;
background-color: #1f2937;
color: #f3f4f6;
text-align: left;
border-radius: 6px;
padding: 8px;
position: absolute;
z-index: 1;
bottom: 125%; /* Position above the element */
left: 50%;
margin-left: -100px;
opacity: 0;
transition: opacity 0.3s;
}
/* Tooltip arrow */
.tooltip .tooltiptext::after {
content: "";
position: absolute;
top: 100%; /* At the bottom of tooltip */
left: 50%;
margin-left: -5px;
border-width: 5px;
border-style: solid;
border-color: #1f2937 transparent transparent transparent;
}
/* Show the tooltip text when hovering */
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
}
"""
# Start from the hub theme, then apply overrides; none are active right now,
# the background-image overrides below are kept only as commented examples.
dark_theme = gr.Theme.from_hub("ParityError/Anime")
dark_theme = dark_theme.set(
    # body_background_fill= "url(https://image uri) #000000 no-repeat right bottom / auto 100svh padding-box fixed;",
    # body_background_fill_dark= "url(https://image uri) #000000 no-repeat right bottom / auto 100svh padding-box fixed;",
)
with gr.Blocks(
theme=dark_theme,
css=tooltip_css,
) as app:
# Header
gr.Markdown("# 🧠 LLM Brain Control")
gr.Markdown("Usage demo: [link](https://example.com)")
with gr.Row():
# Left Column: Control Vectors and advanced settings
with gr.Column(scale=1):
gr.Markdown("### ⚡ Control Vectors")
control_vector_label = gr.HTML("""