jon-tow committed
Commit 1d56001
1 Parent(s): 75cc022

fix: add system prompt
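
This change tokenizes a fixed system prompt once at start-up and prepends its tokens to the (left-truncated) chat history on every generation, so the persona line always survives context truncation. As a rough illustration of the resulting prompt layout (the per-turn template body is not visible in this diff; the "### Human:" turn marker below is an assumption based on the "###" separator handling in bot()):

### Assistant: I am StableVicuna, a large language model created by Stability AI. I am here to chat!

### Human: <latest user message>
### Assistant: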

Files changed (1)
  1. app.py +54 -41
app.py CHANGED
@@ -1,25 +1,30 @@
 import os
+import gc
 from string import Template
 from threading import Thread
 
 import torch
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from transformers import AutoTokenizer, AutoModelForCausalLM, BatchEncoding, TextIteratorStreamer
 
 
 auth_token = os.environ.get("HUGGINGFACE_TOKEN")
 tokenizer = AutoTokenizer.from_pretrained(
-    "CarperAI/vicuna-13b-fine-tuned-rlhf",
+    "CarperAI/stable-vicuna-13b-fp16",
     use_auth_token=auth_token if auth_token else True,
 )
 model = AutoModelForCausalLM.from_pretrained(
-    "CarperAI/vicuna-13b-fine-tuned-rlhf-8bit",
+    "CarperAI/stable-vicuna-13b-fp16",
+    torch_dtype=torch.float16,
+    low_cpu_mem_usage=True,
+    device_map="auto",
     use_auth_token=auth_token if auth_token else True,
 )
+model.eval()
 
 
 max_context_length = model.config.max_position_embeddings
-max_new_tokens = 512
+max_new_tokens = 768
 
 
 prompt_template = Template("""\
@@ -28,10 +33,15 @@ prompt_template = Template("""\
 """)
 
 
+system_prompt = "### Assistant: I am StableVicuna, a large language model created by Stability AI. I am here to chat!"
+system_prompt_tokens = tokenizer([f"{system_prompt}\n\n"], return_tensors="pt")
+max_sys_tokens = system_prompt_tokens['input_ids'].size(-1)
+
+
 def bot(history):
-    # print(f"History:\n`{history}`")
     history = history or []
-    # Hack to inject prompt formatting into the history
+
+    # Inject prompt formatting into the history
     prompt_history = []
     for human, bot in history:
         if bot is not None:
@@ -42,54 +52,53 @@ def bot(history):
                 human=human, bot=bot if bot is not None else "")
         )
 
-    messages = "\n\n".join(prompt_history)
-    messages = messages.rstrip()
-    # print(f"Messages:\n{messages}")
+    msg_tokens = tokenizer(
+        "\n\n".join(prompt_history).strip(),
+        return_tensors="pt",
+        add_special_tokens=False  # Use <BOS> from the system prompt
+    )
 
-    # Use only the most recent context up to the maximum context length with room left over
-    # for the max new tokens
-    inputs = tokenizer(messages, return_tensors='pt').to('cuda')
-    inputs = {k: v[:, -max_context_length + max_new_tokens:]
-              for k, v in inputs.items()}
+    # Take only the most recent context up to the max context length and prepend the
+    # system prompt with the messages
+    max_tokens = -max_context_length + max_new_tokens + max_sys_tokens
+    inputs = BatchEncoding({
+        k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
+        for k in msg_tokens
+    }).to('cuda')
+    # Remove `token_type_ids` b/c it's not yet supported for LLaMA `transformers` models
     if inputs.get("token_type_ids", None) is not None:
         inputs.pop("token_type_ids")
-    # print(f"Inputs: {inputs}")
+
     streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
-
-    # Generate the response
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
+        top_p=1.0,
        temperature=1.0,
-        top_p=0.9999,
    )
-
-    # print(f"Generating with kwargs: {generate_kwargs}")
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
-        # Process out the prompt separator. NOTE: we should tune with special tokens for this
+        # Process out the prompt separator
        new_text = new_text.replace("<br>", "\n")
-        # print(f"New text: `{new_text}`")
        if "###" in new_text:
            new_text = new_text.split("###")[0]
            partial_text += new_text.strip()
            history[-1][1] = partial_text
            break
        else:
-            # Filter empty trailing whitespaces
-            if new_text.isspace():
+            # Filter empty trailing new lines
+            if new_text == "\n":
                new_text = new_text.strip()
            partial_text += new_text
            history[-1][1] = partial_text
            yield history
-
    return partial_text
 
 
@@ -98,28 +107,32 @@ def user(user_message, history):
 
 
 with gr.Blocks() as demo:
-    gr.Markdown("Chat-RLHF by CarperAI")
-    gr.HTML("<a href='https://huggingface.co/CarperAI/vicuna-13b-fine-tuned-rlhf'><code>CarperAI/vicuna-13b-fine-tuned-rlhf</a>")
-    gr.HTML('''<center><a href="https://huggingface.co/spaces/CarperAI/chat-rlhf?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
+    gr.Markdown("StableVicuna by Stability AI")
+    gr.HTML("<a href='https://huggingface.co/stabilityai/stable-vicuna-13b-delta'><code>stabilityai/stable-vicuna-13b-delta</a>")
+    gr.HTML('''<center><a href="https://huggingface.co/spaces/stabilityai/stable-vicuna?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
 
-    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=512)
+    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
     state = gr.State([])
     with gr.Row():
         with gr.Column():
-            msg = gr.Textbox(label="Chat Message Box", placeholder="Chat Message Box",
-                             show_label=False).style(container=False)
+            msg = gr.Textbox(
+                label="Send a message",
+                placeholder="Send a message",
+                show_label=False
+            ).style(container=False)
         with gr.Column():
             with gr.Row():
-                submit = gr.Button("Submit")
+                submit = gr.Button("Send")
                 stop = gr.Button("Stop")
-                clear = gr.Button("Clear")
-    submit_event = msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=True).then(
-        bot, chatbot, chatbot)
-    submit_click_event = submit.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=True).then(
-        bot, chatbot, chatbot)
-    stop.click(fn=None, inputs=None, outputs=None, cancels=[
-               submit_event, submit_click_event], queue=False)
-    clear.click(lambda: None, None, chatbot, queue=True)
+                clear = gr.Button("Clear History")
+
+    submit_event = msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
+        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
+    submit_click_event = submit.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
+        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
+
+    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event, submit_click_event], queue=False)
+    clear.click(lambda: None, None, [chatbot], queue=True)
 
     demo.queue(max_size=32, concurrency_count=2)
-    demo.launch()
+    demo.launch(share=True)
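
For reference, a minimal sketch (not part of the commit) of what the new prompt assembly in bot() does: the system-prompt tokens are concatenated in front of a left-truncated chat history so that the prompt plus max_new_tokens always fits the model's context window. The 2048-token context length and the toy tensor sizes below are illustrative assumptions, not values read from the model.

# Minimal sketch of the new prompt assembly in bot(), using toy sizes in place
# of the real tokenizer/model outputs.
import torch
from transformers import BatchEncoding

max_context_length = 2048   # assumed stand-in for model.config.max_position_embeddings
max_new_tokens = 768        # value set in app.py
max_sys_tokens = 30         # assumed length of the tokenized system prompt

# Stand-ins for system_prompt_tokens and msg_tokens from app.py.
system_prompt_tokens = {"input_ids": torch.arange(max_sys_tokens).unsqueeze(0),
                        "attention_mask": torch.ones(1, max_sys_tokens, dtype=torch.long)}
msg_tokens = {"input_ids": torch.arange(5000).unsqueeze(0),   # an over-long chat history
              "attention_mask": torch.ones(1, 5000, dtype=torch.long)}

# Same arithmetic as app.py: a negative slice start that keeps only the most recent
# (max_context_length - max_new_tokens - max_sys_tokens) history tokens.
max_tokens = -max_context_length + max_new_tokens + max_sys_tokens   # -1250 here
inputs = BatchEncoding({
    k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
    for k in msg_tokens
})

print(inputs["input_ids"].shape)  # torch.Size([1, 1280])
assert inputs["input_ids"].size(-1) + max_new_tokens == max_context_length

When the history is shorter than the remaining budget, the negative slice simply keeps all of it; the system prompt is the only part of the input guaranteed to survive truncation.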