import os import re import torch from transformers import AutoModelForCausalLM, AutoTokenizer from repeng import ControlVector, ControlModel import gradio as gr # Initialize model and tokenizer from huggingface_hub import login # Initialize model and tokenizer mistral_path = "mistralai/Mistral-7B-Instruct-v0.3" #mistral_path = r"E:/language_models/models/mistral" access_token = os.getenv("mistralaccesstoken") login(access_token) tokenizer = AutoTokenizer.from_pretrained(mistral_path) tokenizer.pad_token_id = 0 model = AutoModelForCausalLM.from_pretrained( mistral_path, torch_dtype=torch.float16, trust_remote_code=True, use_safetensors=True ) model = model.to("cuda:0" if torch.cuda.is_available() else "cpu") model = ControlModel(model, list(range(-5, -18, -1))) # Generation settings default_generation_settings = { "pad_token_id": tokenizer.eos_token_id, # Silence warning "do_sample": False, # Deterministic output "max_new_tokens": 384, "repetition_penalty": 1.1, # Reduce repetition } # Tags for prompt formatting user_tag, asst_tag = "[INST]", "[/INST]" # List available control vectors control_vector_files = [f for f in os.listdir('.') if f.endswith('.gguf')] if not control_vector_files: raise FileNotFoundError("No .gguf control vector files found in the current directory.") # Function to toggle slider visibility based on checkbox state def toggle_slider(checked): return gr.update(visible=checked) # Function to generate the model's response def generate_response(system_prompt, user_message, history, max_new_tokens, repitition_penalty, *args): checkboxes = [] sliders = [] #inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty] + control_checks + control_sliders # Separate checkboxes and sliders based on type # The first x in args are the checkbox names (the file names) # The second x in args are the slider values for i in range(len(control_vector_files)): checkboxes.append(args[i]) sliders.append(args[len(control_vector_files) + i]) if len(checkboxes) != len(control_vector_files) or len(sliders) != len(control_vector_files): return history if history else [], history if history else [] # Reset any previous control vectors model.reset() # Apply selected control vectors with their corresponding weights assistant_message_title = "" for i in range(len(control_vector_files)): if checkboxes[i]: cv_file = control_vector_files[i] weight = sliders[i] try: control_vector = ControlVector.import_gguf(cv_file) model.set_control(control_vector, weight) assistant_message_title += f"{cv_file}: {weight};" except Exception as e: print(f"Failed to set control vector {cv_file}: {e}") formatted_prompt = "" # [INST] user message[/INST] assistant message[INST] new user message[/INST] # Mistral expects the history to be wrapped in history if len(history) > 0: formatted_prompt += "" # Append the system prompt if provided if system_prompt.strip(): formatted_prompt += f"{user_tag} {system_prompt}{asst_tag} " # Construct the formatted prompt based on history if len(history) > 0: for turn in history: user_msg, asst_msg = turn formatted_prompt += f"{user_tag} {user_msg} {asst_tag} {asst_msg}" if len(history) > 0: formatted_prompt += "" # Append the new user message formatted_prompt += f"{user_tag} {user_message} {asst_tag}" # Tokenize the input input_ids = tokenizer(formatted_prompt, return_tensors="pt").to(model.device) generation_settings = { "pad_token_id": tokenizer.eos_token_id, "do_sample": default_generation_settings["do_sample"], "max_new_tokens": int(max_new_tokens), "repetition_penalty": repetition_penalty.value, } # Generate the response output_ids = model.generate(**input_ids, **generation_settings) response = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=False) def get_assistant_response(input_string): # Use regex to find the text between the final [/INST] tag and pattern = r'\[/INST\](?!.*\[/INST\])\s*(.*?)(?:|$)' match = re.search(pattern, input_string, re.DOTALL) if match: return match.group(1).strip() return None assistant_response = get_assistant_response(response) # Update conversation history assistant_response = get_assistant_response(response) assistant_response_display = f"*{assistant_message_title}*\n\n{assistant_response}" # Update conversation history history.append((user_message, assistant_response_display)) return history def generate_response_with_retry(system_prompt, user_message, history, max_new_tokens, repitition_penalty, *args): # Remove last user input and assistant response from history, then call generate_response() if history: history = history[0:-1] return generate_response(system_prompt, user_message, history, max_new_tokens, repitition_penalty, *args) # Function to reset the conversation history def reset_chat(): # returns a blank user input text and a blank conversation history return [], [] # Build the Gradio interface with gr.Blocks() as demo: gr.Markdown("# 🧠 LLM Brain Control") gr.Markdown("Usage demo: (link)") with gr.Row(): # Left Column: Settings and Control Vectors with gr.Column(scale=1): gr.Markdown("### ⚙️ Settings") # System Prompt Input system_prompt = gr.Textbox( label="System Prompt", lines=2, value="Respond to the user concisely" ) gr.Markdown("### ⚡ Control Vectors") gr.Markdown("Select how you want to control the LLM. Start with +/- 1.0. Stronger values may overload it.") # Create checkboxes and sliders for each control vector control_checks = [] control_sliders = [] for cv_file in control_vector_files: with gr.Row(): # Checkbox to select the control vector checkbox = gr.Checkbox(label=cv_file, value=False) control_checks.append(checkbox) # Slider to adjust the control vector's weight slider = gr.Slider( minimum=-2.5, maximum=2.5, value=0.0, step=0.1, label=f"{cv_file} Weight", visible=False ) control_sliders.append(slider) # Link the checkbox to toggle slider visibility checkbox.change( toggle_slider, inputs=checkbox, outputs=slider ) # Advanced Settings Section (collapsed by default) with gr.Accordion("🔧 Advanced Settings", open=False): with gr.Row(): max_new_tokens = gr.Number( label="Max Response Length (in tokens)", value=default_generation_settings["max_new_tokens"], precision=0, step=10, ) repetition_penalty = gr.Number( label="Repetition Penalty", value=default_generation_settings["repetition_penalty"], precision=2, step=0.1 ) # Right Column: Chat Interface with gr.Column(scale=2): gr.Markdown("### 🗨️ Conversation") # Chatbot to display conversation chatbot = gr.Chatbot(label="Conversation", type='tuples') # User Message Input user_input = gr.Textbox( label="Your Message (Shift+Enter submits)", lines=2, placeholder="I was out partying too late last night, and I'm going to be late for work. What should I tell my boss?" ) with gr.Row(): # Submit and New Chat buttons submit_button = gr.Button("💬 Submit") retry_button = gr.Button("🔃 Retry last turn") new_chat_button = gr.Button("🌟 New Chat") inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty] + control_checks + control_sliders # Define button actions submit_button.click( generate_response, inputs=inputs_list, outputs=[chatbot] ) user_input.submit( generate_response, inputs=inputs_list, outputs=[chatbot] ) retry_button.click( generate_response_with_retry, inputs=inputs_list, outputs=[chatbot] ) new_chat_button.click( reset_chat, inputs=[], outputs=[chatbot, user_input] ) # Launch the Gradio app if __name__ == "__main__": demo.launch()