import os
from string import Template
from threading import Thread

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

auth_token = os.environ.get("HUGGINGFACE_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(
    "CarperAI/vicuna-13b-fine-tuned-rlhf",
    use_auth_token=auth_token if auth_token else True,
)
model = AutoModelForCausalLM.from_pretrained(
    "CarperAI/vicuna-13b-fine-tuned-rlhf-8bit",
    use_auth_token=auth_token if auth_token else True,
)
model.cuda()
max_context_length = model.config.max_position_embeddings
max_new_tokens = 512
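# Together these bound the token budget: the prompt is truncated later so that
# prompt tokens plus up to max_new_tokens of generated text fit in the context window.
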
prompt_template = Template("""\
### Human: $human
### Assistant: $bot\
""")

def bot(history):
    # print(f"History:\n`{history}`")
    history = history or []
    # Hack to inject prompt formatting into the history
    prompt_history = []
    for human, bot in history:
        if bot is not None:
            # Gradio's Chatbot stores newlines as <br>; undo that for the prompt
            bot = bot.replace("<br>", "\n")
            bot = bot.rstrip()
        prompt_history.append(
            prompt_template.substitute(
                human=human, bot=bot if bot is not None else "")
        )
    messages = "\n\n".join(prompt_history)
    messages = messages.rstrip()
    # print(f"Messages:\n{messages}")
    # Use only the most recent context up to the maximum context length,
    # with room left over for the max new tokens
    inputs = tokenizer(messages, return_tensors='pt').to('cuda')
    inputs = {k: v[:, -max_context_length + max_new_tokens:]
              for k, v in inputs.items()}
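    # e.g. assuming a 2048-token context window (the LLaMA-family default),
    # the slice -2048 + 512 = -1536 keeps at most the last 1536 prompt tokens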
    if inputs.get("token_type_ids", None) is not None:
        # Some tokenizers emit token_type_ids, which generate() does not accept
        inputs.pop("token_type_ids")
    # print(f"Inputs: {inputs}")
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    # Generate the response
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=1.0,
        top_p=0.9999,
    )
    # print(f"Generating with kwargs: {generate_kwargs}")
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
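    # generate() blocks until decoding finishes, so it runs in the background
    # thread while this thread consumes tokens from the streamer as they arrive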
    partial_text = ""
    for new_text in streamer:
        # Process out the prompt separator. NOTE: we should tune with special tokens for this
        new_text = new_text.replace("<br>", "\n")
        # print(f"New text: `{new_text}`")
        if "###" in new_text:
            # A "###" separator means the model has started a new turn; keep
            # only the text before it, surface the final update, and stop
            new_text = new_text.split("###")[0]
            partial_text += new_text.strip()
            history[-1][1] = partial_text
            yield history
            break
        else:
            # Filter empty trailing whitespaces
            if new_text.isspace():
                new_text = new_text.strip()
            partial_text += new_text
            history[-1][1] = partial_text
            yield history
    return partial_text

def user(user_message, history):
    # Clear the textbox and append the new turn with an empty assistant slot
    return "", history + [[user_message, None]]
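
# The usual Gradio wiring for these callbacks (a sketch; assumes a `msg`
# Textbox and a `chatbot` Chatbot component defined inside the Blocks below):
#   msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
#       bot, chatbot, chatbot)
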
with gr.Blocks() as demo:
    gr.Markdown("Chat-RLHF by CarperAI")
    gr.HTML("CarperAI/vicuna-13b-fine-tuned-rlhf<br>")
    gr.HTML('''