import os
import re
import threading

import gradio as gr
import torch
from huggingface_hub import login
from repeng import ControlModel, ControlVector
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Initialize model and tokenizer
mistral_path = "mistralai/Mistral-7B-Instruct-v0.3"
# mistral_path = r"E:/language_models/models/mistral"  # alternative: local copy

# The Hub token is read from the "mistralaccesstoken" environment variable.
# huggingface_hub.login() falls back to an interactive prompt when given no
# token, which cannot be answered in a headless deployment, so only log in
# when a token is actually set.
access_token = os.getenv("mistralaccesstoken")
if access_token:
    login(access_token)
else:
    print("mistralaccesstoken is not set; skipping Hugging Face login.")

tokenizer = AutoTokenizer.from_pretrained(mistral_path)
tokenizer.pad_token_id = 0
model = AutoModelForCausalLM.from_pretrained(
    mistral_path,
    torch_dtype=torch.float16,
    # NOTE(review): this allows executing code shipped in the model repo;
    # Mistral models load natively in transformers — confirm it is needed.
    trust_remote_code=True,
    use_safetensors=True,
)

_cuda_available = torch.cuda.is_available()
model = model.to("cuda:0" if _cuda_available else "cpu")
print(f"Is CUDA available: {_cuda_available}")
# Only query the device name when CUDA exists: on a CPU-only machine the
# CUDA device queries raise, which would break the CPU fallback above.
if _cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Wrap the model so control vectors can be applied to layers -5 .. -17
# (13 layers, counted from the end of the stack).
model = ControlModel(model, list(range(-5, -18, -1)))
# Baseline generation settings. Note: generate_response() builds its own
# generation kwargs from the UI inputs and does not read this dict.
default_generation_settings = dict(
    do_sample=False,  # no sampling -> deterministic output
    max_new_tokens=384,
    repetition_penalty=1.1,  # discourage repeated text
)

# Mistral instruction tags used to format prompts.
user_tag = "[INST]"
asst_tag = "[/INST]"
# List the available control vectors (*.gguf files in ./control_models).
# os.listdir() returns entries in arbitrary order; sorting makes the list —
# and therefore the positional checkbox/slider mapping built from it —
# deterministic across runs and platforms.
control_vector_files = sorted(
    name for name in os.listdir("control_models") if name.endswith(".gguf")
)
if not control_vector_files:
    raise FileNotFoundError("No .gguf control vector files found in the control_models directory.")
def toggle_slider(checked):
    """Return a Gradio update that sets a component's visibility to ``checked``.

    Intended to show/hide a slider when its checkbox changes.
    """
    visibility_update = gr.update(visible=checked)
    return visibility_update
def construct_prompt(history, system_prompt, user_message, tags=None):
    """Convert the chatbot history into the prompt string Mistral expects.

    The result looks like::

        [INST] system[/INST] [INST] user1 [/INST] reply1[INST] new message [/INST]

    Args:
        history: iterable of (user_message, assistant_message) pairs, as shown
            in the chatbot; may be empty or None. Assistant messages produced
            by generate_response() start with a title line (the italic
            control-vector summary), which is dropped here so it is not fed
            back to the model.
        system_prompt: optional system prompt; skipped when None or blank.
        user_message: the new user message.
        tags: optional (open_tag, close_tag) pair; defaults to the module's
            (user_tag, asst_tag).

    Returns:
        The formatted prompt string.
    """
    if tags is None:
        open_tag, close_tag = user_tag, asst_tag
    else:
        open_tag, close_tag = tags

    parts = []
    if system_prompt and system_prompt.strip():
        parts.append(f"{open_tag} {system_prompt}{close_tag} ")

    for user_msg, asst_msg in history or ():
        # Drop the first line (the control-vector title) and re-join the rest
        # as text. (Previously the list of lines itself was interpolated,
        # inserting a Python list repr into the prompt.)
        reply = "\n".join((asst_msg or "").split("\n")[1:]).strip()
        # NOTE(review): Mistral's reference chat template ends each assistant
        # turn with the EOS token; it is not appended here.
        parts.append(f"{open_tag} {user_msg} {close_tag} {reply}")

    parts.append(f"{open_tag} {user_message} {close_tag}")
    return "".join(parts)
def _build_control(args):
    """Load the selected control vectors, scale them by their weights and sum them.

    Args:
        args: one checkbox state per entry of ``control_vector_files``,
            followed by one slider weight per entry, in the same order.

    Returns:
        (combined_vector, title): ``combined_vector`` is the sum of every
        selected vector multiplied by its slider weight, or None when nothing
        was selected or loaded; ``title`` is "name: weight;" for each vector
        that loaded successfully.

    Raises:
        ValueError: if fewer than two values per control-vector file are given.
    """
    count = len(control_vector_files)
    if len(args) < 2 * count:
        raise ValueError(f"Expected {2 * count} checkbox/slider values, got {len(args)}.")
    checkboxes = args[:count]
    sliders = args[count:2 * count]

    title = ""
    combined = None
    for cv_file, enabled, weight in zip(control_vector_files, checkboxes, sliders):
        if not enabled:
            continue
        try:
            # Multiplying by the slider value sets the vector's strength (and sign).
            vector = ControlVector.import_gguf(f"control_models/{cv_file}") * weight
        except Exception as e:
            # Best effort: a vector that fails to load is reported and skipped.
            print(f"Failed to set control vector {cv_file}: {e}")
            continue
        combined = vector if combined is None else combined + vector
        title += f"{cv_file.split('.')[0]}: {weight};"
    return combined, title


def generate_response(system_prompt, user_message, history, max_new_tokens, repitition_penalty, do_sample, *args):
    """Apply the selected control vectors and stream the model's reply.

    Args:
        system_prompt: optional system prompt (may be blank).
        user_message: the new user message.
        history: list of (user, assistant) tuples currently in the chatbot.
            It is not modified.
        max_new_tokens: generation length limit (converted with int()).
        repitition_penalty: repetition penalty from the UI slider (converted
            with float()). The misspelled name is kept so existing callers
            keep working.
        do_sample: whether to sample instead of decoding deterministically.
        *args: one checkbox state per file in ``control_vector_files``,
            followed by one slider weight per file, in the same order.

    Yields:
        The chat history with the partial — and finally the complete — reply
        appended. Each reply is prefixed with an italic line listing the
        applied control vectors and their weights.
    """
    history = list(history or [])
    combined_vector, assistant_message_title = _build_control(args)

    # Clear any control left over from a previous request, then apply the new one.
    model.reset()
    if combined_vector is not None:
        try:
            model.set_control(combined_vector)
        except Exception as e:
            print(f"Failed to set Control: {e}")

    formatted_prompt = construct_prompt(history, system_prompt, user_message)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # skip_special_tokens=True keeps the EOS marker out of the streamed text,
    # replacing the old per-chunk filter and trailing-EOS strip (both of which
    # had degenerated into comparisons against "" and misbehaved).
    # Iteration raises if no text arrives within `timeout` seconds.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,  # input_ids and attention_mask from the tokenizer
        streamer=streamer,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=do_sample,
        max_new_tokens=int(max_new_tokens),
        # Use the value the UI passed in, not a component's initial value.
        repetition_penalty=float(repitition_penalty),
    )
    # generate() blocks until done, so run it on a worker thread and consume
    # the streamer here. generate() ends the stream itself when it finishes.
    worker = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    header = f"*{assistant_message_title}*\n\n"
    partial_message = ""
    for new_text in streamer:
        # Chunks may be empty strings (text is released at word boundaries);
        # they must not be treated as end-of-stream.
        partial_message += new_text
        yield history + [(user_message, header + partial_message)]
    worker.join()

    # Final state (also guarantees one yield when nothing was generated).
    yield history + [(user_message, header + partial_message)]


def generate_response_with_retry(system_prompt, user_message, history, max_new_tokens, repitition_penalty, do_sample, *args):
    """Regenerate the last reply.

    Drops the final (user, assistant) pair from ``history`` and streams a new
    response via generate_response() with the same arguments.
    """
    if history:
        history = history[:-1]
    yield from generate_response(
        system_prompt, user_message, history, max_new_tokens, repitition_penalty, do_sample, *args
    )
def reset_chat():
    """Return a blank state: a fresh empty chat history and an empty string."""
    empty_history, empty_text = [], ""
    return empty_history, empty_text
def get_checkboxes():
    """Map each control-vector checkbox label to its position in the UI.

    The list is rebuilt from the live layout so the presets do not need to
    change when a control model is added. It walks the children of
    ``app.children[2].children[0]`` (the checkbox column, per the original
    naming). For every ``gr.Row`` whose ``children[0].children[0]`` has a
    ``label``, it records ``label -> running index``. Rows lacking that
    structure (IndexError/AttributeError) are skipped, so other rows may sit
    in the same column.

    Returns:
        dict of label -> index, in UI order.
    """
    column_children = app.children[2].children[0].children
    labels_to_positions = {}
    position = 0
    for child in column_children:
        if not isinstance(child, gr.Row):
            continue
        try:
            label = child.children[0].children[0].label
        except (IndexError, AttributeError):
            # Not a checkbox row; leave it out.
            continue
        else:
            labels_to_positions[label] = position
            position += 1
    return labels_to_positions
def set_preset_helpful(*args):
# gets the list of all checkboxes and sliders
# sets checkboxes and sliders accordingly to this persona
# args is a list of checkboxes and then slider values
# must return the updated list of checkboxes and sliders
count_checkboxes = int(len(args)/2)
new_checkbox_values = []
new_slider_values = []
model_names_and_indexes = get_checkboxes()
for check in model_names_and_indexes:
if check == "Empathatic":
new_checkbox_values.append(True)
new_slider_values.append(1.0)
elif check == "Optimistic":
new_checkbox_values.append(True)
new_slider_values.append(1.0)
else:
new_checkbox_values.append(False)
new_slider_values.append(0.0)
return new_checkbox_values + new_slider_values
def set_preset_conspiracist(*args):
# gets the list of all checkboxes and sliders
# sets checkboxes and sliders accordingly to this persona
# args is a list of checkboxes and then slider values
# must return the updated list of checkboxes and sliders
new_checkbox_values = []
new_slider_values = []
model_names_and_indexes = get_checkboxes()
for check in model_names_and_indexes:
if check == "Conspiracies":
new_checkbox_values.append(True)
new_slider_values.append(1.5)
elif check == "Creative":
new_checkbox_values.append(True)
new_slider_values.append(1.0)
elif check == "Lazy":
new_checkbox_values.append(True)
new_slider_values.append(-0.5)
elif check == "Truthful":
new_checkbox_values.append(True)
new_slider_values.append(-1.0)
else:
new_checkbox_values.append(False)
new_slider_values.append(0.0)
return new_checkbox_values + new_slider_values
def set_preset_stoner(*args):
# gets the list of all checkboxes and sliders
# sets checkboxes and sliders accordingly to this persona
# args is a list of checkboxes and then slider values
# must return the updated list of checkboxes and sliders
new_checkbox_values = []
new_slider_values = []
model_names_and_indexes = get_checkboxes()
for check in model_names_and_indexes:
if check == "Angry":
new_checkbox_values.append(True)
new_slider_values.append(0.5)
elif check == "Right-leaning":
new_checkbox_values.append(True)
new_slider_values.append(-0.5)
elif check == "Tripping":
new_checkbox_values.append(True)
new_slider_values.append(0.6)
else:
new_checkbox_values.append(False)
new_slider_values.append(0.0)
return new_checkbox_values + new_slider_values
def set_preset_facts(*args):
# gets the list of all checkboxes and sliders
# sets checkboxes and sliders accordingly to this persona
# args is a list of checkboxes and then slider values
# must return the updated list of checkboxes and sliders
new_checkbox_values = []
new_slider_values = []
model_names_and_indexes = get_checkboxes()
for check in model_names_and_indexes:
if check == "Confident":
new_checkbox_values.append(True)
new_slider_values.append(0.5)
elif check == "Joking":
new_checkbox_values.append(True)
new_slider_values.append(-0.5)
elif check == "Lazy":
new_checkbox_values.append(True)
new_slider_values.append(-0.5)
elif check == "Truthful":
new_checkbox_values.append(True)
new_slider_values.append(0.5)
else:
new_checkbox_values.append(False)
new_slider_values.append(0.0)
return new_checkbox_values + new_slider_values
# CSS for hover tooltips: an element with class "tooltip" holding a
# "tooltiptext" child reveals that child above itself while hovered.
# The final rule was previously left unterminated (missing "}").
tooltip_css = """
/* Tooltip container */
.tooltip {
position: relative;
display: inline-block;
cursor: help;
}
/* Tooltip text */
.tooltip .tooltiptext {
visibility: hidden;
width: 200px;
background-color: #1f2937;
color: #f3f4f6;
text-align: left;
border-radius: 6px;
padding: 8px;
position: absolute;
z-index: 1;
bottom: 125%; /* Position above the element */
left: 50%;
margin-left: -100px;
opacity: 0;
transition: opacity 0.3s;
}
/* Tooltip arrow */
.tooltip .tooltiptext::after {
content: "";
position: absolute;
top: 100%; /* At the bottom of tooltip */
left: 50%;
margin-left: -5px;
border-width: 5px;
border-style: solid;
border-color: #1f2937 transparent transparent transparent;
}
/* Show the tooltip text when hovering */
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
}"""
# UI theme, downloaded from the Hugging Face Hub. Fetching it needs network
# access, so fall back to Gradio's default theme instead of failing at
# startup. Theme variables (e.g. body_background_fill,
# body_background_fill_dark) can be overridden via .set(...).
try:
    dark_theme = gr.Theme.from_hub("ParityError/Anime").set()
except Exception as e:
    print(f"Failed to load theme ParityError/Anime, using the default theme: {e}")
    dark_theme = gr.themes.Default()
with gr.Blocks(
theme=dark_theme,
css=tooltip_css,
) as app:
# Header
gr.Markdown("# 🧠 LLM Brain Control")
gr.Markdown("Usage demo: [link](https://example.com)")
with gr.Row():
# Left Column: Control Vectors and advanced settings
with gr.Column(scale=1):
gr.Markdown("### ⚡ Control Vectors")
control_vector_label = gr.HTML("""