import os
import gc
from string import Template
from threading import Thread
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BatchEncoding, TextIteratorStreamer
# Hub credentials: an explicit token from the environment if one is set,
# otherwise ``True`` is passed as ``use_auth_token`` to both loaders below.
auth_token = os.environ.get("HUGGINGFACE_TOKEN")
_MODEL_ID = "stabilityai/stable-vicuna-13b-fp16"
_hub_auth = auth_token or True

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, use_auth_token=_hub_auth)

# Half-precision weights, placed by `device_map="auto"`.
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    use_auth_token=_hub_auth,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
model.eval()

# Token budget: the context window comes from the model config; a fixed part
# of it is reserved for the generated reply.
max_context_length = model.config.max_position_embeddings
max_new_tokens = 768

# One conversation turn rendered as prompt text (no trailing newline).
prompt_template = Template("### Human: $human\n### Assistant: $bot")

# The system prompt is tokenized once up front; its length is subtracted from
# the history budget when the conversation is truncated.
system_prompt = "### Assistant: I am StableVicuna, a large language model created by Stability AI. I am here to chat!"
system_prompt_tokens = tokenizer([system_prompt + "\n\n"], return_tensors="pt")
max_sys_tokens = system_prompt_tokens["input_ids"].shape[-1]
def bot(history):
    """Stream the assistant's reply for the latest turn in ``history``.

    Formats every turn with ``prompt_template``, prepends the pre-tokenized
    system prompt, runs ``model.generate`` on a background thread and yields
    ``history`` each time the streamer delivers text, with the last turn's
    reply slot holding everything generated so far.

    Args:
        history: List of ``[human, reply]`` pairs. The reply slot (index 1)
            of the last pair is overwritten in place. ``None`` or an empty
            list yields an empty history and generates nothing.

    Yields:
        The same ``history`` list after each update.

    Returns:
        The accumulated reply text (a generator return value, so it is only
        visible through ``StopIteration.value``).
    """
    history = history or []
    if not history:
        # No turn to answer: writing into history[-1] below would raise
        # IndexError after generation had already been started.
        yield history
        return ""

    # Inject prompt formatting into the history.
    # NOTE(review): the token replaced here (and in the streaming loop) was an
    # unterminated string literal in the file, apparently markup stripped by
    # extraction. "<br>" is assumed to be how the chat widget encodes
    # newlines -- confirm against the pinned gradio version.
    prompt_history = []
    for human, reply in history:
        if reply is not None:
            reply = reply.replace("<br>", "\n").rstrip()
        prompt_history.append(
            prompt_template.substitute(
                human=human, bot=reply if reply is not None else "")
        )

    msg_tokens = tokenizer(
        "\n\n".join(prompt_history).strip(),
        return_tensors="pt",
        add_special_tokens=False  # Use from the system prompt
    )

    # Take only the most recent context up to the max context length and
    # prepend the system prompt to the messages. `max_tokens` is a negative
    # start index, so the slice keeps the newest tokens.
    max_tokens = -max_context_length + max_new_tokens + max_sys_tokens
    inputs = BatchEncoding({
        k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
        for k in msg_tokens
    }).to('cuda')
    # Remove `token_type_ids` b/c it's not yet supported for LLaMA `transformers` models
    inputs.pop("token_type_ids", None)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=1.0,
        temperature=1.0,
    )
    # Generation runs on a worker thread so this generator can consume the
    # streamer and push partial text while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        # Process out the prompt separator
        new_text = new_text.replace("<br>", "\n")
        if "###" in new_text:
            # The model has started a new "###" turn marker: keep only the
            # text before it and stop consuming the stream.
            partial_text += new_text.split("###")[0].strip()
            history[-1][1] = partial_text
            # Yield before leaving the loop so this last fragment is
            # delivered; previously it was stored but never emitted.
            yield history
            break
        # Filter chunks that are nothing but a newline
        if new_text == "\n":
            new_text = ""
        partial_text += new_text
        history[-1][1] = partial_text
        yield history
    return partial_text
def user(user_message, history):
    """Append the submitted message to the chat history as a new open turn.

    Args:
        user_message: Text the user just submitted.
        history: Existing list of ``[human, reply]`` pairs. ``None`` is
            treated as an empty history, matching how ``bot`` handles it.

    Returns:
        A ``("", new_history)`` tuple: the empty string for the first output
        and a new list ending with ``[user_message, None]`` for the second.
        The ``history`` argument itself is not mutated.
    """
    return "", [*(history or []), [user_message, None]]
# Chat UI: header, chatbot pane, message box and Send / Stop / Clear buttons,
# followed by the event wiring.
with gr.Blocks() as demo:
    gr.Markdown("StableVicuna by Stability AI")
    # NOTE(review): this literal was split across two physical lines
    # (an unterminated string), and the header strings look like their HTML
    # markup was stripped. Only the visible text is kept here -- restore the
    # original markup if it is available.
    gr.HTML("stabilityai/stable-vicuna-13b-delta")
    gr.HTML('''Duplicate the Space to skip the queue and run in a private space''')
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Send a message",
                placeholder="Send a message",
                show_label=False
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Send")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear History")

    # Pressing Enter and clicking Send run the same two-step chain: `user`
    # records the message without queueing (queue=False), then `bot` streams
    # the reply through the queue.
    submit_event = msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    # Stop cancels either chain; it runs no function of its own.
    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event, submit_click_event], queue=False)
    # Clear resets the chatbot value to None.
    clear.click(lambda: None, None, [chatbot], queue=True)

demo.queue(max_size=32, concurrency_count=2)
demo.launch()