abrakjamson committed
Commit 453c7fc
Parent: 9acb8e6

Corrected history and special tokens

Files changed (1): app.py (+45, -29)
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import re
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from repeng import ControlVector, ControlModel
@@ -48,16 +49,18 @@ def toggle_slider(checked):
     return gr.update(visible=checked)
 
 # Function to generate the model's response
-def generate_response(system_prompt, user_message, *args, history=None, max_new_tokens=256, repetition_penalty=1.1):
+def generate_response(system_prompt, user_message, history, max_new_tokens, repetition_penalty, *args):
     checkboxes = []
     sliders = []
 
+    # Argument order: [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty] + control_checks + control_sliders
+
     # Separate checkboxes and sliders based on type
-    for item in args:
-        if type(item) == bool:
-            checkboxes.append(item)
-        elif isinstance(item, (int, float)):
-            sliders.append(item)
+    # The first len(control_vector_files) values in args are the checkbox states,
+    # the next len(control_vector_files) values are the slider weights
+    for i in range(len(control_vector_files)):
+        checkboxes.append(args[i])
+        sliders.append(args[len(control_vector_files) + i])
 
     if len(checkboxes) != len(control_vector_files) or len(sliders) != len(control_vector_files):
-        return history if history else [], history if history else []
+        return history if history else []
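Gradio passes each input component's current value positionally, so after the five named parameters the trailing `*args` holds all checkbox booleans followed by all slider floats. A minimal sketch of the unpacking, using hypothetical control-vector file names:

```python
# Hypothetical file names; the real list is built elsewhere in app.py
control_vector_files = ["honesty.gguf", "humor.gguf", "calm.gguf"]

def unpack(*args):
    n = len(control_vector_files)
    checkboxes = list(args[:n])    # checkbox states (booleans)
    sliders = list(args[n:2 * n])  # slider weights (floats)
    return checkboxes, sliders

# Simulates the flattened values Gradio would pass after the named arguments
print(unpack(True, False, True, 1.5, 0.0, -0.8))
# ([True, False, True], [1.5, 0.0, -0.8])
```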
@@ -76,21 +79,27 @@ def generate_response(system_prompt, user_message, *args, history=None, max_new_tokens=256, repetition_penalty=1.1):
     except Exception as e:
         print(f"Failed to set control vector {cv_file}: {e}")
 
-    # Initialize history if None
-    history = history or []
-
-    # Construct the formatted prompt based on history
     formatted_prompt = ""
-    for turn in history:
-        user_msg, asst_msg = turn
-        formatted_prompt += f"{user_tag} {user_msg} {asst_tag} {asst_msg} </s>"
+
+    # Mistral expects the history to be wrapped in <s> ... </s>
+    if len(history) > 0:
+        formatted_prompt += "<s>"
 
     # Append the system prompt if provided
     if system_prompt.strip():
-        formatted_prompt += f"[INST] {system_prompt}"
+        formatted_prompt += f"[INST] {system_prompt} [/INST] "
+
+    # Construct the formatted prompt based on history
+    if len(history) > 0:
+        for turn in history:
+            user_msg, asst_msg = turn
+            formatted_prompt += f"{user_tag} {user_msg} {asst_tag} {asst_msg}"
+
+    if len(history) > 0:
+        formatted_prompt += "</s>"
 
     # Append the new user message
-    formatted_prompt += f"\n{user_tag} {user_message} {asst_tag}"
+    formatted_prompt += f"{user_tag} {user_message} {asst_tag}"
 
     # Tokenize the input
     input_ids = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
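The formatting above yields a single Mistral-style prompt string. A sanity-check sketch, assuming `user_tag = "[INST]"` and `asst_tag = "[/INST]"` (the actual tag values are defined elsewhere in app.py):

```python
user_tag, asst_tag = "[INST]", "[/INST]"  # assumed values

history = [("Hi there", "Hello! How can I help?")]
system_prompt = "You are a helpful assistant."
user_message = "Tell me a joke."

formatted_prompt = ""
if len(history) > 0:
    formatted_prompt += "<s>"
if system_prompt.strip():
    formatted_prompt += f"[INST] {system_prompt} [/INST] "
for user_msg, asst_msg in history:
    formatted_prompt += f"{user_tag} {user_msg} {asst_tag} {asst_msg}"
if len(history) > 0:
    formatted_prompt += "</s>"
formatted_prompt += f"{user_tag} {user_message} {asst_tag}"

print(formatted_prompt)
# <s>[INST] You are a helpful assistant. [/INST] [INST] Hi there [/INST] Hello! How can I help?</s>[INST] Tell me a joke. [/INST]
```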
@@ -99,23 +108,30 @@ def generate_response(system_prompt, user_message, *args, history=None, max_new_tokens=256, repetition_penalty=1.1):
         "pad_token_id": tokenizer.eos_token_id,
         "do_sample": default_generation_settings["do_sample"],
         "max_new_tokens": int(max_new_tokens),
         "repetition_penalty": repetition_penalty,
     }
 
     # Generate the response
-    output_ids = model.generate(**input_ids, **default_generation_settings)
-    response = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=True)
-
-    # Clean up the response by removing any trailing tags
-    if "</s>" in response:
-        response = response.split("</s>")[0].strip()
+    output_ids = model.generate(**input_ids, **generation_settings)
+    response = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=False)
+
+    def get_assistant_response(input_string):
+        # Use regex to find the text between the final [/INST] tag and </s>
+        pattern = r'\[/INST\](?!.*\[/INST\])\s*(.*?)(?:</s>|$)'
+        match = re.search(pattern, input_string, re.DOTALL)
+        if match:
+            return match.group(1).strip()
+        return None
+
+    assistant_response = get_assistant_response(response)
 
     # Update conversation history
-    history.append((user_message, response))
-    return history, history
+    history.append((user_message, assistant_response))
+    return history
 
 # Function to reset the conversation history
 def reset_chat():
-    return [], []
+    # Return a blank conversation history and a blank user input box
+    return [], ""
 
 # Build the Gradio interface
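The negative lookahead in the new extraction pattern pins the match to the last [/INST] in the decoded output, then captures up to the closing </s> (or end of string). A quick standalone check of the pattern:

```python
import re

# Same pattern as in get_assistant_response above
pattern = r'\[/INST\](?!.*\[/INST\])\s*(.*?)(?:</s>|$)'

decoded = "<s>[INST] Hi [/INST] Hello!</s>[INST] Tell me a joke. [/INST] Why did the chicken cross the road?</s>"
match = re.search(pattern, decoded, re.DOTALL)
print(match.group(1).strip())
# Why did the chicken cross the road?
```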
@@ -198,20 +214,20 @@ with gr.Blocks() as demo:
     submit_button = gr.Button("💬 Submit")
     new_chat_button = gr.Button("🆕 New Chat")
 
-    # State to keep track of conversation history
-    state = gr.State()
+    # The Chatbot component itself now holds the conversation history
+    inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty] + control_checks + control_sliders
 
     # Define button actions
     submit_button.click(
         generate_response,
-        inputs=[system_prompt, user_input] + control_checks + control_sliders + [state],
-        outputs=[chatbot, state]
+        inputs=inputs_list,
+        outputs=[chatbot]
     )
 
     new_chat_button.click(
         reset_chat,
         inputs=[],
-        outputs=[chatbot, state]
+        outputs=[chatbot, user_input]
     )
 
 # Launch the Gradio app
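The commit drops the separate gr.State and uses the Chatbot component itself as the history store: it is passed in via inputs_list and written back via outputs. A minimal, self-contained sketch of the same wiring pattern, assuming Gradio's classic tuple-format Chatbot (component names here are illustrative):

```python
import gradio as gr

def echo(user_message, history):
    # The Chatbot's current value arrives as the `history` argument
    history = history or []
    history.append((user_message, f"You said: {user_message}"))
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    user_input = gr.Textbox()
    send = gr.Button("Send")
    # Passing chatbot as both input and output keeps history in the component
    send.click(echo, inputs=[user_input, chatbot], outputs=[chatbot])

# demo.launch()
```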
 