|
import time |
|
import os |
|
import gradio as gr |
|
from text_generation import Client |
|
from conversation import get_default_conv_template |
|
from transformers import AutoTokenizer |
|
|
|
|
|
# Inference backend: a text-generation-inference server; override via ENDPOINT_URL.
endpoint_url = os.environ.get("ENDPOINT_URL", "http://127.0.0.1:8080")

client = Client(endpoint_url, timeout=120)

# NOTE(review): eos_token is never used below — no stop sequence is passed to
# client.generate_stream; confirm whether it should be wired in.
eos_token = "</s>"

# Maximum prompt length in tokens; longer prompts are truncated from the left.
max_prompt_length = 4000


# Tokenizer is used only to measure and truncate the prompt before sending it
# to the server (generation itself happens remotely).
tokenizer = AutoTokenizer.from_pretrained("yentinglin/Taiwan-LLaMa-v1.0")
|
|
|
with gr.Blocks() as demo: |
|
chatbot = gr.Chatbot() |
|
msg = gr.Textbox() |
|
clear = gr.Button("Clear") |
|
|
|
def user(user_message, history): |
|
return "", history + [[user_message, None]] |
|
|
|
def bot(history): |
|
conv = get_default_conv_template("vicuna").copy() |
|
roles = {"human": conv.roles[0], "gpt": conv.roles[1]} |
|
for user, bot in history: |
|
conv.append_message(roles['human'], user) |
|
conv.append_message(roles["gpt"], bot) |
|
msg = conv.get_prompt() |
|
prompt_tokens = tokenizer.encode(msg) |
|
length_of_prompt = len(prompt_tokens) |
|
if length_of_prompt > max_prompt_length: |
|
msg = tokenizer.decode(prompt_tokens[-max_prompt_length:]) |
|
|
|
history[-1][1] = "" |
|
for response in client.generate_stream( |
|
msg, |
|
max_new_tokens=512, |
|
): |
|
if not response.token.special: |
|
character = response.token.text |
|
history[-1][1] += character |
|
yield history |
|
|
|
|
|
def generate_response(history, max_new_token=512, top_p=0.9, temperature=0.8, do_sample=True): |
|
conv = get_default_conv_template("vicuna").copy() |
|
roles = {"human": conv.roles[0], "gpt": conv.roles[1]} |
|
for user, bot in history: |
|
conv.append_message(roles['human'], user) |
|
conv.append_message(roles["gpt"], bot) |
|
msg = conv.get_prompt() |
|
|
|
for response in client.generate_stream( |
|
msg, |
|
max_new_tokens=max_new_token, |
|
top_p=top_p, |
|
temperature=temperature, |
|
do_sample=do_sample, |
|
): |
|
history[-1][1] = "" |
|
|
|
character = response.token.text |
|
history[-1][1] += character |
|
print(history[-1][1]) |
|
time.sleep(0.05) |
|
yield history |
|
|
|
|
|
    # On Enter: user() stores the message and clears the box (queue=False keeps
    # the echo instant), then bot() streams the reply into the chatbot.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )

    # Clear the chat window: the lambda returns None, which empties the Chatbot.
    clear.click(lambda: None, None, chatbot, queue=False)
|
|
|
# Queuing is required for generator callbacks (bot) to stream to the browser.
demo.queue()

demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|