abrakjamson committed on
Commit 129904f
1 Parent(s): f631e46

Updating interface, fixing checkbox bugs

Files changed (1):
  1. app.py +63 -48
app.py CHANGED
@@ -3,6 +3,8 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from repeng import ControlVector, ControlModel
 import gradio as gr
+
+# Initialize model and tokenizer
 from huggingface_hub import login
 
 # Initialize model and tokenizer
@@ -12,7 +14,6 @@ mistral_path = "mistralai/Mistral-7B-Instruct-v0.3"
 access_token = os.getenv("mistralaccesstoken")
 login(access_token)
 
-#tokenizer = AutoTokenizer.from_pretrained(mistral_path)
 tokenizer = AutoTokenizer.from_pretrained(mistral_path)
 tokenizer.pad_token_id = 0
 
@@ -47,69 +48,83 @@ def toggle_slider(checked):
     return gr.update(visible=checked)
 
 # Function to generate the model's response
-def generate_response(system_prompt, user_message, *args, history=None):
-    # args contains alternating checkbox and slider values
-    print("generating response for user query {user_message}")
-    num_controls = len(control_vector_files)
-    checkboxes = args[0::2]  # Extract every first item in each pair
-    sliders = args[1::2]  # Extract every second item in each pair
-
+def generate_response(system_prompt, user_message, *args, history):
+    # Separate checkboxes and sliders based on type
+    print(f"Generating response to {user_message}")
+    checkboxes = [item for item in args if isinstance(item, bool)]
+    sliders = [item for item in args if isinstance(item, (int, float))]
+
+    if len(checkboxes) != len(control_vector_files) or len(sliders) != len(control_vector_files):
+        return history  # Return current history if there's a mismatch
+
     # Reset any previous control vectors
     model.reset()
-
-    print("applying weights")
+
     # Apply selected control vectors with their corresponding weights
-    for i in range(num_controls):
+    for i in range(len(control_vector_files)):
         if checkboxes[i]:
-            print(f"checkbox: {i} True for {cv_file}, weight: {weight}")
             cv_file = control_vector_files[i]
             weight = sliders[i]
             try:
+                print(f"Setting {cv_file} to {weight}")
                 control_vector = ControlVector.import_gguf(cv_file)
                 model.set_control(control_vector, weight)
-                print("control vector set for {cv_file}")
             except Exception as e:
                 print(f"Failed to set control vector {cv_file}: {e}")
-
-    # Format the prompt
+
+    # Initialize history if None
+    history = history or []
+
+    # Construct the formatted prompt based on history
+    formatted_prompt = ""
+    for turn in history:
+        user_msg, asst_msg = turn
+        formatted_prompt += f"{user_tag} {user_msg} {asst_tag} {asst_msg} </s>"
+
+    # Append the system prompt if provided
     if system_prompt.strip():
-        formatted_prompt = f"{system_prompt}\n{user_tag}{user_message}{asst_tag}"
-    else:
-        formatted_prompt = f"{user_tag}{user_message}{asst_tag}"
-
+        formatted_prompt += f"[INST] {system_prompt}"
+
+    # Append the new user message
+    formatted_prompt += f"\n{user_tag} {user_message} {asst_tag}"
+
     # Tokenize the input
     input_ids = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
-
+
     # Generate the response
     output_ids = model.generate(**input_ids, **generation_settings)
     response = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=True)
-
+
+    # Clean up the response by removing any trailing tags
+    if "</s>" in response:
+        response = response.split("</s>")[0].strip()
+
     # Update conversation history
-    history = history or []
     history.append((user_message, response))
     return history
 
 # Function to reset the conversation history
 def reset_chat():
-    return []
+    return [], []
 
 # Build the Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# 🧠 Language Model Interface")
-
+    gr.Markdown("# 🧠 Mistral v3 Language Model Interface")
+
     with gr.Row():
+        # Left Column: Settings and Control Vectors
        with gr.Column(scale=1):
+            gr.Markdown("### ⚙️ Settings")
+
             # System Prompt Input
             system_prompt = gr.Textbox(
                 label="System Prompt",
                 lines=2,
                 placeholder="Enter system-level instructions here..."
             )
-
-
-
+
             gr.Markdown("### 📊 Control Vectors")
-
+
             # Create checkboxes and sliders for each control vector
             control_checks = []
             control_sliders = []
@@ -118,7 +133,7 @@ with gr.Blocks() as demo:
                 # Checkbox to select the control vector
                 checkbox = gr.Checkbox(label=cv_file, value=False)
                 control_checks.append(checkbox)
-
+
                 # Slider to adjust the control vector's weight
                 slider = gr.Slider(
                     minimum=-2.5,
@@ -129,49 +144,49 @@ with gr.Blocks() as demo:
                     visible=False
                 )
                 control_sliders.append(slider)
-
+
                 # Link the checkbox to toggle slider visibility
                 checkbox.change(
                     toggle_slider,
                     inputs=checkbox,
                     outputs=slider
                 )
-
-            with gr.Row():
-                # Submit and New Chat buttons
-                submit_button = gr.Button("💬 Submit")
-                new_chat_button = gr.Button("🆕 New Chat")
-
+
+        # Right Column: Chat Interface
         with gr.Column(scale=2):
+            gr.Markdown("### 🗨️ Conversation")
+
             # Chatbot to display conversation
-            chatbot = gr.Chatbot(label="🗨️ Conversation")
+            chatbot = gr.Chatbot(label="Conversation")
 
             # User Message Input
             user_input = gr.Textbox(
-                label="User Message",
+                label="Your Message",
                 lines=2,
                 placeholder="Type your message here..."
             )
-
+
+            with gr.Row():
+                # Submit and New Chat buttons
+                submit_button = gr.Button("💬 Submit")
+                new_chat_button = gr.Button("🆕 New Chat")
+
     # State to keep track of conversation history
     state = gr.State([])
-
+
     # Define button actions
     submit_button.click(
         generate_response,
         inputs=[system_prompt, user_input] + control_checks + control_sliders + [state],
-        outputs=[chatbot]
+        outputs=[chatbot, state]
     )
-
+
     new_chat_button.click(
         reset_chat,
         inputs=[],
-        outputs=[chatbot]
+        outputs=[chatbot, state]
    )
 
 # Launch the Gradio app
 if __name__ == "__main__":
-    demo.launch()
-    # control_checks = []
-    # control_checks.append()
-    # generate_response("helpful assistant", "help me come up with a lie to my boss about why I'm late", ((True), (-1.0)), None)
+    demo.launch()
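Note on the checkbox fix: the click handler passes its inputs positionally as [system_prompt, user_input] + control_checks + control_sliders + [state], so the extra values reach generate_response as all checkbox states first, then all slider weights (plus the trailing state), not as (checkbox, slider) pairs. The old args[0::2] / args[1::2] slicing therefore mixed the two groups, which is what this commit addresses. The standalone sketch below illustrates that ordering by splitting on position; the helper name and file names are hypothetical and it is not wired to Gradio or to app.py.

# Standalone sketch of the argument order produced by the flattened inputs list.
# split_controls and the file names are illustrative only, not part of app.py.
control_vector_files = ["calm.gguf", "formal.gguf"]  # hypothetical names

def split_controls(args):
    # The first len(control_vector_files) values come from the checkboxes,
    # the next len(control_vector_files) values come from the sliders;
    # any trailing state value is ignored here.
    n = len(control_vector_files)
    checkboxes = list(args[:n])
    sliders = list(args[n:2 * n])
    return checkboxes, sliders

# Both checkbox values arrive before both slider weights:
checks, weights = split_controls((True, False, 1.5, -0.5))
assert checks == [True, False]
assert weights == [1.5, -0.5]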