import os
from string import Template
from threading import Thread

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer


# Optional access token for the Hub; when unset, `use_auth_token=True` is passed
# so transformers falls back to any locally saved login token.
auth_token = os.environ.get("HUGGINGFACE_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(
    "CarperAI/vicuna-13b-fine-tuned-rlhf",
    use_auth_token=auth_token if auth_token else True,
)
model = AutoModelForCausalLM.from_pretrained(
    "CarperAI/vicuna-13b-fine-tuned-rlhf-8bit",
    use_auth_token=auth_token if auth_token else True,
)

max_context_length = model.config.max_position_embeddings
max_new_tokens = 512

# One conversational turn. The trailing backslashes suppress the newline after
# the opening and before the closing triple quote, so a turn renders exactly as
# "### Human: <msg>\n### Assistant: <reply>".
prompt_template = Template("""\
### Human: $human
### Assistant: $bot\
""")

# Line-break markup that may appear in Chatbot history values (presumably from
# Gradio's HTML rendering); it is turned back into "\n" before the text is used
# as model input and when reading model output.
HTML_LINE_BREAK = "<br>"


def _build_prompt(history):
    """Render ``[human, reply]`` pairs as one "### Human / ### Assistant" prompt.

    A turn whose reply is ``None`` (the pending one) renders with an empty
    reply, so a prompt whose last turn is pending ends with "### Assistant:"
    for the model to continue. Turns are separated by a blank line.
    """
    turns = []
    for human, reply in history:
        if reply is None:
            reply = ""
        else:
            reply = reply.replace(HTML_LINE_BREAK, "\n").rstrip()
        turns.append(prompt_template.substitute(human=human, bot=reply))
    # Drop the space the template leaves after a pending "### Assistant:".
    return "\n\n".join(turns).rstrip()


def bot(history):
    """Stream the model's reply to the last message in ``history``.

    ``history`` is the Chatbot value: a list of ``[user_message, bot_reply]``
    pairs whose last reply is ``None`` (appended by ``user``). The reply is
    written into ``history[-1][1]`` as it grows, and the whole history is
    yielded after every streamed chunk so the UI updates live. Generation stops
    being read at the first "###", which marks the start of a new turn.
    """
    history = history or []
    if not history:
        # Nothing to answer; ``history[-1]`` below would raise IndexError.
        return

    messages = _build_prompt(history)

    # Use only the most recent context up to the maximum context length, with
    # room left over for the max new tokens.
    inputs = tokenizer(messages, return_tensors="pt").to(model.device)
    inputs = {k: v[:, -max_context_length + max_new_tokens:] for k, v in inputs.items()}
    # Removed before generate(); presumably the model does not accept it.
    inputs.pop("token_type_ids", None)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=1.0,
        top_p=0.9999,
    )
    # generate() runs on a worker thread while this generator drains the
    # streamer. NOTE: nothing here signals that thread to stop early (no
    # stopping criteria), so it runs on after we stop reading at "###".
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        # NOTE: the turn separator is plain text rather than a special token,
        # so it has to be detected in the decoded stream.
        new_text = new_text.replace(HTML_LINE_BREAK, "\n")
        if "###" in new_text:
            # Keep only what precedes the separator and trim the reply's tail;
            # rstrip (not strip) keeps any space joining it to earlier text.
            partial_text = (partial_text + new_text.split("###")[0]).rstrip()
            history[-1][1] = partial_text
            # Publish the final fragment before leaving the loop.
            yield history
            break
        # Skip whitespace-only chunks so the reply does not accumulate blank
        # trailing lines.
        if not new_text.isspace():
            partial_text += new_text
        history[-1][1] = partial_text
        yield history
    return partial_text


def user(user_message, history):
    """Append ``user_message`` as a pending turn and clear the textbox.

    Returns ``("", new_history)``: the empty string resets the message box and
    the ``None`` reply marks the turn that ``bot`` will fill in. A ``None``
    history (e.g. after Clear) is treated as empty.
    """
    return "", (history or []) + [[user_message, None]]


with gr.Blocks() as demo:
    gr.Markdown("Chat-RLHF by CarperAI")
    gr.HTML("CarperAI/vicuna-13b-fine-tuned-rlhf")
    # NOTE(review): this banner reads like HTML whose tags were stripped (a
    # "Duplicate Space" badge followed by its caption); restore the markup and
    # link if that was intended.
    gr.HTML('''
Duplicate SpaceDuplicate the Space to skip the queue and run in a private space
''')
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=512)
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Chat Message Box",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")

    # Enter and the Submit button share the same two-step pipeline: record the
    # user turn, then stream the bot reply into the chatbot.
    submit_event = msg.submit(
        user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=True
    ).then(bot, chatbot, chatbot)
    submit_click_event = submit.click(
        user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=True
    ).then(bot, chatbot, chatbot)
    # Stop cancels the streaming UI events (the generate() thread is not
    # signalled by this).
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=True)

demo.queue(max_size=32, concurrency_count=2)
demo.launch()