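# StableVicuna / app.py
# Scraped Hugging Face file-viewer header: jon-tow · "fix: add system prompt" · commit 1d56001 · 4.78 kB
# (viewer links "raw" / "history" / "blame" were part of the page chrome, not of the source file)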
import os
import gc
from string import Template
from threading import Thread
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BatchEncoding, TextIteratorStreamer
# --- Tokenizer / model setup ------------------------------------------------
# The Hub token is read from the environment; when it is missing, `True` is
# passed as `use_auth_token` instead.
auth_token = os.environ.get("HUGGINGFACE_TOKEN")
_hub_auth = auth_token or True
_checkpoint = "CarperAI/stable-vicuna-13b-fp16"

tokenizer = AutoTokenizer.from_pretrained(_checkpoint, use_auth_token=_hub_auth)

# Load the weights in half precision and let `device_map="auto"` place them.
model = AutoModelForCausalLM.from_pretrained(
    _checkpoint,
    use_auth_token=_hub_auth,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
model.eval()

# --- Generation budget ------------------------------------------------------
# Upper bound on generated tokens per reply, and the context size reported by
# the model config.
max_new_tokens = 768
max_context_length = model.config.max_position_embeddings

# --- Prompt pieces ----------------------------------------------------------
# One conversation turn: the human message followed by the assistant reply
# (left empty for the turn that is about to be generated).
prompt_template = Template("### Human: $human\n### Assistant: $bot")

# The system prompt is tokenized once, followed by a blank line; its token
# count is kept so it can be accounted for when the prompt is assembled.
system_prompt = "### Assistant: I am StableVicuna, a large language model created by Stability AI. I am here to chat!"
system_prompt_tokens = tokenizer([system_prompt + "\n\n"], return_tensors="pt")
max_sys_tokens = system_prompt_tokens['input_ids'].shape[-1]
def bot(history):
    """Stream the assistant's reply for the last turn of ``history``.

    This is a generator: after every streamed chunk it yields ``history`` with
    the reply slot of the last turn (``history[-1][1]``) updated in place.

    Args:
        history: List of ``[human_message, bot_reply]`` pairs. ``bot_reply`` is
            ``None`` for the pending turn. ``None`` or an empty list is
            accepted and produces a single yield of an empty history.

    Yields:
        The same ``history`` list, with the reply text accumulated so far.

    Returns:
        The final reply text (carried as the generator's return value).
    """
    history = history or []
    if not history:
        # There is no pending turn to write a reply into; without this guard
        # the `history[-1]` assignment below would raise IndexError.
        yield history
        return ""

    # Inject prompt formatting into the history
    prompt_turns = []
    for human, reply in history:
        if reply is None:
            reply = ""
        else:
            # Replies stored in the history may contain "<br>"; turn them back
            # into newlines before they are fed to the model again.
            reply = reply.replace("<br>", "\n").rstrip()
        prompt_turns.append(prompt_template.substitute(human=human, bot=reply))

    msg_tokens = tokenizer(
        "\n\n".join(prompt_turns).strip(),
        return_tensors="pt",
        add_special_tokens=False,  # Use <BOS> from the system prompt
    )

    # Take only the most recent context up to the max context length and
    # prepend the system prompt to the messages. The budget is what is left of
    # the context window after reserving room for the system prompt and the
    # tokens to be generated.
    history_budget = max_context_length - max_new_tokens - max_sys_tokens
    if history_budget > 0:
        recent = {k: msg_tokens[k][:, -history_budget:] for k in msg_tokens}
    else:
        # A non-positive budget must not be used as a slice start: `[-0:]`
        # would keep everything and a positive start would drop the *oldest*
        # tokens instead of keeping the newest. Keep no history tokens instead.
        recent = {k: msg_tokens[k][:, :0] for k in msg_tokens}
    inputs = BatchEncoding({
        k: torch.concat([system_prompt_tokens[k], recent[k]], dim=-1)
        for k in msg_tokens
    }).to('cuda')
    # Remove `token_type_ids` b/c it's not yet supported for LLaMA `transformers` models
    if inputs.get("token_type_ids", None) is not None:
        inputs.pop("token_type_ids")

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=1.0,
        temperature=1.0,
    )
    # Generation runs on a worker thread so this generator can consume the
    # streamer and yield partial text while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        new_text = new_text.replace("<br>", "\n")
        # Process out the prompt separator. The check runs on the accumulated
        # text so a "###" that arrives split across two chunks is still caught.
        candidate = partial_text + new_text
        if "###" in candidate:
            # The model started a new "### Human/Assistant" turn: keep only
            # what precedes the marker. Only trailing whitespace is trimmed so
            # the last chunk is not glued onto the previous word.
            partial_text = candidate.split("###")[0].rstrip()
            history[-1][1] = partial_text
            # Emit the final text before stopping; otherwise the UI would
            # never receive the part of the last chunk before the marker.
            yield history
            # NOTE(review): the worker thread keeps generating after this
            # break until the model stops on its own or hits max_new_tokens.
            break
        # Filter empty trailing new lines
        if new_text == "\n":
            new_text = ""
        partial_text += new_text
        history[-1][1] = partial_text
        yield history
    return partial_text
def user(user_message, history):
    """Append a pending user turn to the chat history.

    Args:
        user_message: Text the user just submitted.
        history: List of ``[human_message, bot_reply]`` pairs, or ``None``
            (treated as an empty history, matching how ``bot`` handles it).

    Returns:
        A ``(textbox_value, history)`` tuple: an empty string, and a new list
        made of the previous turns plus ``[user_message, None]``, where
        ``None`` marks the reply that still has to be generated. The input
        list is not mutated.
    """
    # `history` can be None (the Clear History handler writes None into the
    # chatbot), and `None + list` would raise TypeError.
    return "", (history or []) + [[user_message, None]]
# Build the chat UI. NOTE(review): the source's indentation was lost, so the
# Row/Column nesting below is reconstructed from the context-manager
# structure - confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("StableVicuna by Stability AI")
    # The <code> element is closed explicitly (the original markup left it open).
    gr.HTML("<a href='https://huggingface.co/stabilityai/stable-vicuna-13b-delta'><code>stabilityai/stable-vicuna-13b-delta</code></a>")
    gr.HTML('''<center><a href="https://huggingface.co/spaces/stabilityai/stable-vicuna?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
    # Not wired to any event below; kept so the component tree is unchanged.
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Send a message",
                placeholder="Send a message",
                show_label=False
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Send")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear History")

    # Pressing Enter and clicking "Send" run the same two-step chain:
    # 1. `user` (unqueued) clears the textbox and appends the pending turn;
    # 2. `bot` (queued) streams the reply into the chatbot.
    submit_event = msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    # "Stop" runs no function of its own; it only cancels in-flight generation.
    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event, submit_click_event], queue=False)
    # Clearing bypasses the queue (like the other lightweight handlers) so it
    # does not have to wait behind pending generation jobs.
    clear.click(lambda: None, None, [chatbot], queue=False)

# At most 32 waiting requests, two processed concurrently.
demo.queue(max_size=32, concurrency_count=2)
demo.launch(share=True)