multimodalart committed
Commit 015885c
1 Parent(s): 7fdd6d6

Update app.py

Files changed (1)
  1. app.py +30 -51
app.py CHANGED
@@ -1,11 +1,12 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
 import time
 import numpy as np
 from torch.nn import functional as F
 import os
-# auth_key = os.environ["HF_ACCESS_TOKEN"]
+from threading import Thread
+
 print(f"Starting to load the model to memory")
 m = AutoModelForCausalLM.from_pretrained(
     "stabilityai/stablelm-tuned-alpha-7b", torch_dtype=torch.float16).cuda()
@@ -28,62 +29,40 @@ class StopOnTokens(StoppingCriteria):
                 return True
         return False
 
-
-def contrastive_generate(text, bad_text):
-    with torch.no_grad():
-        tokens = tok(text, return_tensors="pt")[
-            'input_ids'].cuda()[:, :4096-1024]
-        bad_tokens = tok(bad_text, return_tensors="pt")[
-            'input_ids'].cuda()[:, :4096-1024]
-        history = None
-        bad_history = None
-        curr_output = list()
-        for i in range(1024):
-            out = m(tokens, past_key_values=history, use_cache=True)
-            logits = out.logits
-            history = out.past_key_values
-            bad_out = m(bad_tokens, past_key_values=bad_history,
-                        use_cache=True)
-            bad_logits = bad_out.logits
-            bad_history = bad_out.past_key_values
-            probs = F.softmax(logits.float(), dim=-1)[0][-1].cpu()
-            bad_probs = F.softmax(bad_logits.float(), dim=-1)[0][-1].cpu()
-            logits = torch.log(probs)
-            bad_logits = torch.log(bad_probs)
-            logits[probs > 0.1] = logits[probs > 0.1] - bad_logits[probs > 0.1]
-            probs = F.softmax(logits)
-            out = int(torch.multinomial(probs, 1))
-            if out in [50278, 50279, 50277, 1, 0]:
-                break
-            else:
-                curr_output.append(out)
-            out = np.array([out])
-            tokens = torch.from_numpy(np.array([out])).to(
-                tokens.device)
-            bad_tokens = torch.from_numpy(np.array([out])).to(
-                tokens.device)
-    return tok.decode(curr_output)
-
-
-def generate(text, bad_text=None):
-    stop = StopOnTokens()
-    result = generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True,
-                       temperature=1.0, top_p=0.95, top_k=1000, stopping_criteria=StoppingCriteriaList([stop]))
-    return result[0]["generated_text"].replace(text, "")
-
-
 def user(user_message, history):
     history = history + [[user_message, ""]]
     return "", history, history
 
 
 def bot(history, curr_system_message):
+    stop = StopOnTokens()
     messages = curr_system_message + \
         "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]])
                 for item in history])
-    output = generate(messages)
-    history[-1][1] = output
-    time.sleep(1)
+
+    #model_inputs = tok([messages], return_tensors="pt")['input_ids'].cuda()[:, :4096-1024]
+    model_inputs = tok([messages], return_tensors="pt").to("cuda")
+    streamer = TextIteratorStreamer(tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=1024,
+        do_sample=True,
+        top_p=0.95,
+        top_k=1000,
+        temperature=1.0,
+        num_beams=1,
+        stopping_criteria=StoppingCriteriaList([stop])
+    )
+    t = Thread(target=m.generate, kwargs=generate_kwargs)
+    t.start()
+
+    print(history)
+    for new_text in streamer:
+        print(new_text)
+        history[-1][1] += new_text
+        yield history, history
+
     return history, history
 
 
@@ -107,5 +86,5 @@ with gr.Blocks() as demo:
     submit.click(fn=user, inputs=[msg, history], outputs=[msg, chatbot, history], queue=False).then(
         fn=bot, inputs=[chatbot, system_msg], outputs=[chatbot, history], queue=True)
    clear.click(lambda: [None, []], None, [chatbot, history], queue=False)
-    demo.queue(concurrency_count=5)
-    demo.launch()
+    demo.queue(concurrency_count=2)
+    demo.launch()
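The StopOnTokens criterion passed to generate() above is only partially visible in this diff (the hunk header plus two return lines). For context, here is a minimal sketch of what such a StoppingCriteria subclass typically looks like; the class body is not part of this commit, and the stop-token IDs are assumed from the list used in the removed contrastive_generate function.

```python
import torch
from transformers import StoppingCriteria

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Assumed StableLM-Tuned-Alpha end-of-turn / special token IDs,
        # matching the IDs checked in the removed contrastive_generate.
        stop_ids = [50278, 50279, 50277, 1, 0]
        for stop_id in stop_ids:
            # Stop as soon as the last generated token is one of the stop IDs.
            if input_ids[0][-1] == stop_id:
                return True
        return False
```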
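Taken out of the Gradio app, the TextIteratorStreamer pattern that the new bot() relies on can be exercised on its own. The sketch below is illustrative rather than part of the commit: it reuses the model and tokenizer names from the file, assumes a CUDA device, and the prompt string is made up.

```python
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b")
m = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-tuned-alpha-7b", torch_dtype=torch.float16).cuda()

prompt = "<|USER|>Hello!<|ASSISTANT|>"  # made-up prompt, same chat markup as the app
model_inputs = tok([prompt], return_tensors="pt").to("cuda")

# skip_prompt drops the echoed prompt; skip_special_tokens hides the chat markers.
streamer = TextIteratorStreamer(tok, timeout=10., skip_prompt=True, skip_special_tokens=True)

# generate() blocks until the whole sequence is finished, so it runs in a background
# thread while the main thread consumes decoded text chunks from the streamer.
Thread(target=m.generate, kwargs=dict(
    model_inputs, streamer=streamer, max_new_tokens=1024,
    do_sample=True, top_p=0.95, top_k=1000, temperature=1.0)).start()

reply = ""
for new_text in streamer:
    reply += new_text
    print(new_text, end="", flush=True)
```

Inside the app, the same loop appends each chunk to history[-1][1] and yields, which is what lets the Gradio Chatbot update incrementally instead of waiting for the full completion as the removed generate() path (plus time.sleep(1)) did.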