Spaces:
Sleeping
Sleeping
abrakjamson
commited on
Commit
•
85e58bb
1
Parent(s):
4da1fb0
Disable input while generating
Browse files
app.py
CHANGED
@@ -24,10 +24,12 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
24 |
trust_remote_code=True,
|
25 |
use_safetensors=True
|
26 |
)
|
27 |
-
|
28 |
-
print(f"Is CUDA available: {
|
29 |
-
|
|
|
30 |
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
|
|
31 |
|
32 |
model = ControlModel(model, list(range(-5, -18, -1)))
|
33 |
|
@@ -87,7 +89,8 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
|
|
87 |
Returns a list of tuples, the user message and the assistant response,
|
88 |
which Gradio uses to update the chatbot history
|
89 |
"""
|
90 |
-
|
|
|
91 |
# Separate checkboxes and sliders based on type
|
92 |
# The first x in args are the checkbox names (the file names)
|
93 |
# The second x in args are the slider values
|
@@ -139,7 +142,10 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
|
|
139 |
"repetition_penalty": repetition_penalty.value,
|
140 |
}
|
141 |
|
142 |
-
|
|
|
|
|
|
|
143 |
|
144 |
generate_kwargs = dict(
|
145 |
input_ids,
|
@@ -155,6 +161,9 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
|
|
155 |
|
156 |
# Display the response as it streams in, prepending the control vector info
|
157 |
partial_message = ""
|
|
|
|
|
|
|
158 |
for new_token in _streamer:
|
159 |
if new_token != '<' and new_token != '</s>': # seems to hit EOS correctly without this needed
|
160 |
partial_message += new_token
|
@@ -181,14 +190,17 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
|
|
181 |
|
182 |
# Update conversation history
|
183 |
history.append((user_message, assistant_response_display))
|
184 |
-
|
185 |
|
186 |
def generate_response_with_retry(system_prompt, user_message, history, max_new_tokens, repitition_penalty, do_sample, *args):
|
187 |
# Remove last user input and assistant response from history, then call generate_response()
|
|
|
|
|
188 |
if history:
|
189 |
history = history[0:-1]
|
190 |
-
|
191 |
-
|
|
|
192 |
|
193 |
# Function to reset the conversation history
|
194 |
def reset_chat():
|
@@ -281,7 +293,7 @@ def set_preset_stoner(*args):
|
|
281 |
for check in model_names_and_indexes:
|
282 |
if check == "Angry":
|
283 |
new_checkbox_values.append(True)
|
284 |
-
new_slider_values.append(0.
|
285 |
elif check == "Right-leaning":
|
286 |
new_checkbox_values.append(True)
|
287 |
new_slider_values.append(-0.5)
|
@@ -323,6 +335,15 @@ def set_preset_facts(*args):
|
|
323 |
|
324 |
return new_checkbox_values + new_slider_values
|
325 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
tooltip_css = """
|
327 |
/* Tooltip container */
|
328 |
.tooltip {
|
@@ -560,10 +581,22 @@ with gr.Blocks(
|
|
560 |
inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty, do_sample] + control_checks + control_sliders
|
561 |
|
562 |
# Define button actions
|
|
|
|
|
|
|
|
|
|
|
|
|
563 |
submit_button.click(
|
564 |
generate_response,
|
565 |
inputs=inputs_list,
|
566 |
outputs=[chatbot]
|
|
|
|
|
|
|
|
|
|
|
|
|
567 |
)
|
568 |
|
569 |
user_input.submit(
|
@@ -575,7 +608,11 @@ with gr.Blocks(
|
|
575 |
retry_button.click(
|
576 |
generate_response_with_retry,
|
577 |
inputs=inputs_list,
|
578 |
-
outputs=[chatbot]
|
|
|
|
|
|
|
|
|
579 |
)
|
580 |
|
581 |
new_chat_button.click(
|
|
|
24 |
trust_remote_code=True,
|
25 |
use_safetensors=True
|
26 |
)
|
27 |
+
cuda = torch.cuda.is_available()
|
28 |
+
print(f"Is CUDA available: {cuda}")
|
29 |
+
model = model.to("cuda:0" if cuda else "cpu")
|
30 |
+
if cuda:
|
31 |
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
32 |
+
|
33 |
|
34 |
model = ControlModel(model, list(range(-5, -18, -1)))
|
35 |
|
|
|
89 |
Returns a list of tuples, the user message and the assistant response,
|
90 |
which Gradio uses to update the chatbot history
|
91 |
"""
|
92 |
+
global previous_turn
|
93 |
+
previous_turn = user_message
|
94 |
# Separate checkboxes and sliders based on type
|
95 |
# The first x in args are the checkbox names (the file names)
|
96 |
# The second x in args are the slider values
|
|
|
142 |
"repetition_penalty": repetition_penalty.value,
|
143 |
}
|
144 |
|
145 |
+
timeout = 120.0
|
146 |
+
if cuda:
|
147 |
+
timeout = 10.0
|
148 |
+
_streamer = TextIteratorStreamer(tokenizer, timeout=timeout, skip_prompt=True, skip_special_tokens=False,)
|
149 |
|
150 |
generate_kwargs = dict(
|
151 |
input_ids,
|
|
|
161 |
|
162 |
# Display the response as it streams in, prepending the control vector info
|
163 |
partial_message = ""
|
164 |
+
#show the control vector info while we wait for the first token
|
165 |
+
temp_output = "*" + assistant_message_title + "*" + "\n\n*Please wait*..." + partial_message
|
166 |
+
yield history + [(user_message, temp_output)]
|
167 |
for new_token in _streamer:
|
168 |
if new_token != '<' and new_token != '</s>': # seems to hit EOS correctly without this needed
|
169 |
partial_message += new_token
|
|
|
190 |
|
191 |
# Update conversation history
|
192 |
history.append((user_message, assistant_response_display))
|
193 |
+
return history
|
194 |
|
195 |
def generate_response_with_retry(system_prompt, user_message, history, max_new_tokens, repitition_penalty, do_sample, *args):
|
196 |
# Remove last user input and assistant response from history, then call generate_response()
|
197 |
+
global previous_turn
|
198 |
+
previous_ueser_message = previous_turn
|
199 |
if history:
|
200 |
history = history[0:-1]
|
201 |
+
# Using the previous turn's text, even though it isn't in the textbox anymore
|
202 |
+
for output in generate_response(system_prompt, previous_ueser_message, history, max_new_tokens, repetition_penalty, do_sample, *args):
|
203 |
+
yield [output, previous_ueser_message]
|
204 |
|
205 |
# Function to reset the conversation history
|
206 |
def reset_chat():
|
|
|
293 |
for check in model_names_and_indexes:
|
294 |
if check == "Angry":
|
295 |
new_checkbox_values.append(True)
|
296 |
+
new_slider_values.append(0.4)
|
297 |
elif check == "Right-leaning":
|
298 |
new_checkbox_values.append(True)
|
299 |
new_slider_values.append(-0.5)
|
|
|
335 |
|
336 |
return new_checkbox_values + new_slider_values
|
337 |
|
338 |
+
def disable_controls():
|
339 |
+
return gr.update(interactive= False, value= "⌛ Processing"), gr.update(interactive=False)
|
340 |
+
|
341 |
+
def enable_controls():
|
342 |
+
return gr.update(interactive= True, value= "💬 Submit"), gr.update(interactive= True)
|
343 |
+
|
344 |
+
def clear_input(input_textbox):
|
345 |
+
return ""
|
346 |
+
|
347 |
tooltip_css = """
|
348 |
/* Tooltip container */
|
349 |
.tooltip {
|
|
|
581 |
inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty, do_sample] + control_checks + control_sliders
|
582 |
|
583 |
# Define button actions
|
584 |
+
# Disable the submit button while processing
|
585 |
+
submit_button.click(
|
586 |
+
disable_controls,
|
587 |
+
inputs= None,
|
588 |
+
outputs= [submit_button, user_input]
|
589 |
+
)
|
590 |
submit_button.click(
|
591 |
generate_response,
|
592 |
inputs=inputs_list,
|
593 |
outputs=[chatbot]
|
594 |
+
).then(
|
595 |
+
clear_input,
|
596 |
+
inputs= user_input,
|
597 |
+
outputs= user_input
|
598 |
+
).then(
|
599 |
+
enable_controls, inputs=None, outputs=[submit_button, user_input]
|
600 |
)
|
601 |
|
602 |
user_input.submit(
|
|
|
608 |
retry_button.click(
|
609 |
generate_response_with_retry,
|
610 |
inputs=inputs_list,
|
611 |
+
outputs=[chatbot, user_input]
|
612 |
+
).then(
|
613 |
+
clear_input,
|
614 |
+
inputs= user_input,
|
615 |
+
outputs= user_input
|
616 |
)
|
617 |
|
618 |
new_chat_button.click(
|