fix queue (#6)
fix queue (9431ee6c6359396371b5551bbcb2e678cfa4b060)
Co-authored-by: Radamés Ajna <[email protected]>
app.py
CHANGED
@@ -4,11 +4,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, Stopping
 import time
 import numpy as np
 from torch.nn import functional as F
-import os
-auth_key = os.environ["HF_ACCESS_TOKEN"]
+import os
+# auth_key = os.environ["HF_ACCESS_TOKEN"]
 print(f"Starting to load the model to memory")
-m = AutoModelForCausalLM.from_pretrained(
-
+m = AutoModelForCausalLM.from_pretrained(
+    "stabilityai/stablelm-tuned-alpha-7b", torch_dtype=torch.float16).cuda()
+tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b")
 generator = pipeline('text-generation', model=m, tokenizer=tok, device=0)
 print(f"Sucessfully loaded the model to the memory")

@@ -30,8 +31,10 @@ class StopOnTokens(StoppingCriteria):

 def contrastive_generate(text, bad_text):
     with torch.no_grad():
-        tokens = tok(text, return_tensors="pt")[
-
+        tokens = tok(text, return_tensors="pt")[
+            'input_ids'].cuda()[:, :4096-1024]
+        bad_tokens = tok(bad_text, return_tensors="pt")[
+            'input_ids'].cuda()[:, :4096-1024]
         history = None
         bad_history = None
         curr_output = list()
@@ -39,7 +42,8 @@ def contrastive_generate(text, bad_text):
             out = m(tokens, past_key_values=history, use_cache=True)
             logits = out.logits
             history = out.past_key_values
-            bad_out = m(bad_tokens, past_key_values=bad_history,
+            bad_out = m(bad_tokens, past_key_values=bad_history,
+                        use_cache=True)
             bad_logits = bad_out.logits
             bad_history = bad_out.past_key_values
             probs = F.softmax(logits.float(), dim=-1)[0][-1].cpu()
@@ -60,39 +64,48 @@ def contrastive_generate(text, bad_text):
                 tokens.device)
         return tok.decode(curr_output)

+
 def generate(text, bad_text=None):
     stop = StopOnTokens()
-    result = generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True,
+    result = generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True,
+                       temperature=1.0, top_p=0.95, top_k=1000, stopping_criteria=StoppingCriteriaList([stop]))
     return result[0]["generated_text"].replace(text, "")


 def user(user_message, history):
-
+    history = history + [[user_message, ""]]
+    return "", history, history


 def bot(history, curr_system_message):
-    messages = curr_system_message +
+    messages = curr_system_message + \
+        "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]])
+                 for item in history])
     output = generate(messages)
     history[-1][1] = output
     time.sleep(1)
-    return history
-
-
+    return history, history


 with gr.Blocks() as demo:
-
+    history = gr.State([])
     gr.Markdown("## StableLM-Tuned-Alpha-7b Chat")
     gr.HTML('''<center><a href="https://huggingface.co/spaces/stabilityai/stablelm-tuned-alpha-chat?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
-    chatbot = gr.Chatbot(
-
-
-
-
+    chatbot = gr.Chatbot().style(height=500)
+    with gr.Row():
+        with gr.Column(scale=0.70):
+            msg = gr.Textbox(label="", placeholder="Chat Message Box")
+        with gr.Column(scale=0.30, min_width=0):
+            with gr.Row():
+                submit = gr.Button("Submit")
+                clear = gr.Button("Clear")
+    system_msg = gr.Textbox(
+        start_message, label="System Message", interactive=False, visible=False)

-    msg.submit(user, [msg,
-        bot, [chatbot, system_msg], chatbot
-    )
-
+    msg.submit(fn=user, inputs=[msg, history], outputs=[msg, chatbot, history], queue=False).then(
+        fn=bot, inputs=[chatbot, system_msg], outputs=[chatbot, history], queue=True)
+    submit.click(fn=user, inputs=[msg, history], outputs=[msg, chatbot, history], queue=False).then(
+        fn=bot, inputs=[chatbot, system_msg], outputs=[chatbot, history], queue=True)
+    clear.click(lambda: [None, []], None, [chatbot, history], queue=False)
     demo.queue(concurrency_count=5)
-    demo.launch()
+    demo.launch()
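For context, the new event wiring splits each submission into a fast, unqueued user step that appends the turn to the chat history and a queued bot step that runs generation via .then(). Below is a minimal, self-contained sketch of that pattern, assuming the same Gradio 3.x API the new code uses (gr.State, .then(), queue(concurrency_count=...)); the echo reply is an illustrative stand-in for the real generate() call, not the Space's actual model code.

import time

import gradio as gr


def user(user_message, history):
    # Runs with queue=False: append the new turn immediately so the UI stays responsive.
    history = history + [[user_message, ""]]
    return "", history, history


def bot(history):
    # Runs on the shared queue: stand-in for the slow model call.
    time.sleep(1)
    history[-1][1] = "Echo: " + history[-1][0]
    return history, history


with gr.Blocks() as demo:
    history = gr.State([])
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Chat Message Box")
    clear = gr.Button("Clear")

    # Fast UI update first (queue=False), then the slow step on the queue (queue=True).
    msg.submit(fn=user, inputs=[msg, history], outputs=[msg, chatbot, history],
               queue=False).then(fn=bot, inputs=[chatbot],
                                 outputs=[chatbot, history], queue=True)
    clear.click(lambda: [None, []], None, [chatbot, history], queue=False)

demo.queue(concurrency_count=5)
demo.launch()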