akhaliq and radames committed
Commit d14c800
1 parent: e9f9901

fix queue (#6)


- fix queue (9431ee6c6359396371b5551bbcb2e678cfa4b060)


Co-authored-by: Radamés Ajna <[email protected]>
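In outline, the diff below keeps the conversation in a `gr.State` instead of reading it back out of the Chatbot component, runs the lightweight `user` step with `queue=False` so the textbox clears immediately, and queues only the slow `bot` step that calls the model. A minimal sketch of that wiring, assuming the Gradio 3.x API this Space uses; the echo reply is a hypothetical stand-in for the real `generate` call:

```python
import gradio as gr

def user(user_message, history):
    # Cheap step: append the new turn and clear the textbox right away.
    history = history + [[user_message, ""]]
    return "", history, history  # textbox, visible chat, server-side state

def bot(history):
    # Expensive step (stand-in for model generation); this one is queued.
    history[-1][1] = "echo: " + history[-1][0]
    return history, history

with gr.Blocks() as demo:
    history = gr.State([])   # chat history lives here, per session
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    msg.submit(user, [msg, history], [msg, chatbot, history], queue=False).then(
        bot, [chatbot], [chatbot, history], queue=True)
    demo.queue(concurrency_count=5)
demo.launch()
```

Returning the history twice from each callback keeps the visible Chatbot and the `gr.State` copy in sync, which is also what lets the Clear button in the diff reset both in a single handler.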

Files changed (1)
app.py +37 -24
app.py CHANGED
@@ -4,11 +4,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, Stopping
 import time
 import numpy as np
 from torch.nn import functional as F
-import os
-auth_key = os.environ["HF_ACCESS_TOKEN"]
+import os
+# auth_key = os.environ["HF_ACCESS_TOKEN"]
 print(f"Starting to load the model to memory")
-m = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-tuned-alpha-7b",use_auth_token=auth_key, torch_dtype=torch.float16).cuda()
-tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b",use_auth_token=auth_key)
+m = AutoModelForCausalLM.from_pretrained(
+    "stabilityai/stablelm-tuned-alpha-7b", torch_dtype=torch.float16).cuda()
+tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b")
 generator = pipeline('text-generation', model=m, tokenizer=tok, device=0)
 print(f"Sucessfully loaded the model to the memory")
 
@@ -30,8 +31,10 @@ class StopOnTokens(StoppingCriteria):
 
 def contrastive_generate(text, bad_text):
     with torch.no_grad():
-        tokens = tok(text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
-        bad_tokens = tok(bad_text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
+        tokens = tok(text, return_tensors="pt")[
+            'input_ids'].cuda()[:, :4096-1024]
+        bad_tokens = tok(bad_text, return_tensors="pt")[
+            'input_ids'].cuda()[:, :4096-1024]
         history = None
         bad_history = None
         curr_output = list()
@@ -39,7 +42,8 @@ def contrastive_generate(text, bad_text):
             out = m(tokens, past_key_values=history, use_cache=True)
             logits = out.logits
             history = out.past_key_values
-            bad_out = m(bad_tokens, past_key_values=bad_history, use_cache=True)
+            bad_out = m(bad_tokens, past_key_values=bad_history,
+                        use_cache=True)
             bad_logits = bad_out.logits
             bad_history = bad_out.past_key_values
             probs = F.softmax(logits.float(), dim=-1)[0][-1].cpu()
@@ -60,39 +64,48 @@ def contrastive_generate(text, bad_text):
                 tokens.device)
         return tok.decode(curr_output)
 
+
 def generate(text, bad_text=None):
     stop = StopOnTokens()
-    result = generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True, temperature=1.0, top_p=0.95, top_k=1000, stopping_criteria=StoppingCriteriaList([stop]))
+    result = generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True,
+                       temperature=1.0, top_p=0.95, top_k=1000, stopping_criteria=StoppingCriteriaList([stop]))
     return result[0]["generated_text"].replace(text, "")
 
 
 def user(user_message, history):
-    return "", history + [[user_message, ""]]
+    history = history + [[user_message, ""]]
+    return "", history, history
 
 
 def bot(history, curr_system_message):
-    messages = curr_system_message + "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]]) for item in history])
+    messages = curr_system_message + \
+        "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]])
+                 for item in history])
    output = generate(messages)
     history[-1][1] = output
     time.sleep(1)
-    return history
-
-
+    return history, history
 
 
 with gr.Blocks() as demo:
-    num = gr.State(value=0)
+    history = gr.State([])
     gr.Markdown("## StableLM-Tuned-Alpha-7b Chat")
     gr.HTML('''<center><a href="https://huggingface.co/spaces/stabilityai/stablelm-tuned-alpha-chat?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
-    chatbot = gr.Chatbot([])
-    clear = gr.Button("Clear Chat History")
-    system_msg = gr.Textbox(start_message, label="System Message", interactive=False,visible=False)
-    #system_msg = start_message
-    msg = gr.Textbox(label="Chat Message Box")
+    chatbot = gr.Chatbot().style(height=500)
+    with gr.Row():
+        with gr.Column(scale=0.70):
+            msg = gr.Textbox(label="", placeholder="Chat Message Box")
+        with gr.Column(scale=0.30, min_width=0):
+            with gr.Row():
+                submit = gr.Button("Submit")
+                clear = gr.Button("Clear")
+    system_msg = gr.Textbox(
+        start_message, label="System Message", interactive=False, visible=False)
 
-    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=True).then(
-        bot, [chatbot, system_msg], chatbot
-    )
-    clear.click(lambda: None, None, chatbot, queue=True)
+    msg.submit(fn=user, inputs=[msg, history], outputs=[msg, chatbot, history], queue=False).then(
+        fn=bot, inputs=[chatbot, system_msg], outputs=[chatbot, history], queue=True)
+    submit.click(fn=user, inputs=[msg, history], outputs=[msg, chatbot, history], queue=False).then(
+        fn=bot, inputs=[chatbot, system_msg], outputs=[chatbot, history], queue=True)
+    clear.click(lambda: [None, []], None, [chatbot, history], queue=False)
     demo.queue(concurrency_count=5)
-demo.launch()
+demo.launch()
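For reference, the prompt that `bot` assembles before calling `generate` flattens those history pairs into StableLM's chat format. A standalone sketch of just that formatting step (pure Python; `format_prompt` and the example system message are illustrative, not identifiers from the Space):

```python
def format_prompt(system_message, history):
    # history holds [user_text, assistant_text] pairs, the same structure
    # kept in the gr.State above; the last assistant slot is still ""
    # when bot runs, so the model continues from the trailing tag.
    return system_message + "".join(
        "<|USER|>" + user_text + "<|ASSISTANT|>" + assistant_text
        for user_text, assistant_text in history
    )

print(format_prompt("<|SYSTEM|>You are a helpful assistant.",
                    [["Hi!", "Hello."], ["What is StableLM?", ""]]))
# -> <|SYSTEM|>...<|USER|>Hi!<|ASSISTANT|>Hello.<|USER|>What is StableLM?<|ASSISTANT|>
```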