import json
import re
import time
from datetime import datetime
import gradio as gr
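# chat_client provides ModelClient, a small WebSocket client for the Petals chat API
# (open_session / is_session / generate / close_session are used below).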
import chat_client
CHAT_URL = "wss://chat.petals.dev/api/v2/generate"
#CHAT_URL='ws://localhost:8000/api/v2/generate'
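# Default per-browser-session state kept in gr.State and shared between callbacks.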
EMPTY_STATE = {
"generate": False,
"model": None,
"client": None,
"history": [],
}
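# Gradio event handler: streams partial results by repeatedly yielding
# (state, chat history, prompt box value, raw prompt log) tuples.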
def generate(state, prompt, model, context, output, *args):
    # Mark that the generating loop is running (clearing this flag stops it)
state["generate"] = True
try:
yield from _generate(state, prompt, model, context, output, *args)
except (json.decoder.JSONDecodeError, BrokenPipeError):
        # The session broke; try to renew it by re-running with the output so far
        # moved into the context.
        # TODO This is a bit fragile because of the recursive call...
print("Retrying session...")
context = output
output = ""
yield from generate(state, prompt, model, context, output, *args)
finally:
state["generate"] = False
def _generate(
state,
prompt,
model,
context,
output,
endseq,
max_length,
do_sample,
top_k,
top_p,
temperature,
):
start = time.time()
cnt = 0 # Tokens generated
def stats():
# Produces inline stats for generation speed
if cnt == 0:
return "\u2026 | ? sec/t"
if cnt > time.time() - start:
items_per_sec = cnt / (time.time() - start)
return f" | {items_per_sec:.1f} t/sec"
sec_per_item = (time.time() - start) / cnt
return f" | {sec_per_item:.1f} sec/t"
eos = "</s>\n" if "bloomz" in model else "\n\n"
if state["model"] != model and output:
        # When a broken connection is retried, generate() already moves the output
        # into the context, so this branch runs only when the user changes the model.
context = output
output = ""
# Update widgets even before we get the first response
print("prompt", prompt)
yield state, state["history"] + [[prompt, stats()]], "", output
if (
state["model"] != model
or state["client"] == None
or state["client"].is_session() == False
):
try:
state["client"] = chat_client.ModelClient(CHAT_URL)
state["client"].open_session(model, max_length)
state["model"] = model
except Exception as e:
print(datetime.now(), str(e)[-500:])
raise gr.Error(str(e)[-500:])
    else:
        # The existing session is being reused, so its context already lives on the
        # server; don't resend it.
        context = ""
client = state["client"]
context += eos
    # Fix a possible eos-token mismatch and add the eos token to the context and prompt
if "bloomz" in model:
context = context.replace("\n\n", eos)
prompt2 = prompt.replace("\n\n", eos) + "</s>\n"
else:
context = context.replace("</s>", eos)
context = re.sub(r"\n\n+", "\n\n", context)
prompt2 = prompt.replace("</s>", eos) + "\n\n"
prompt2 = f"{context}Human: {prompt2}AI:"
    # Translate checkbox labels into the actual stop sequences
    seq = []
    for s in endseq:
        if s in ("Human:", "AI:", "</s>"):
            seq.append(s)
        elif s == "\\n":
            seq.append("\n")
        elif s == "? (question mark)":
            seq.append("?")
        elif s == ". (dot)":
            seq.append(".")
    # Only one of top_k / top_p may be set; a value of 0 disables that option
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        # If both are given, prefer top_p
        top_k = None
    if temperature == 0:
        # A temperature of 0 is not meaningful for sampling; use the default of 1.0
        temperature = 1.0
output += prompt2
orig_history = state["history"]
new_line = ""
try:
for out in client.generate(
prompt2,
max_new_tokens=1,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_sequences=seq,
):
if not state["generate"]:
client.close_session()
yield state, [], "", ""
# Stopping generation
return
cnt += 1
new_line += out
# Detect end sequences and finish the generation
# prematurely if found.
for s in seq:
spl = new_line.split(s)
new_line = spl[0]
if len(spl) > 1:
state["history"] = orig_history + [[prompt, new_line]]
output += new_line
yield state, state["history"], "", output
# Stopping generation
return
            # Keep the original history untouched; we only append the chunk
            # generated so far on top of it.
state["history"] = orig_history + [[prompt, new_line + stats()]]
yield state, state["history"], "", output
            # Stop a few tokens early so the session doesn't run past max_length,
            # which would make generate() raise and surface as a UI error.
            if cnt >= max_length - 6:  # FIXME magic safety margin
break
# Final line w/o statistics
yield state, state["history"], "", output
except (json.decoder.JSONDecodeError, BrokenPipeError):
        # The session was interrupted; the retry is handled by the caller, generate()
client.close_session()
state["client"] = None
state["model"] = None
print("Broken session!")
raise
except Exception as e:
client.close_session()
state["client"] = None
state["model"] = None
print(datetime.now(), str(e)[-500:])
raise gr.Error(str(e)[-500:])
def reset(state):
"""Resets the session and clears the chat window."""
state.update(EMPTY_STATE)
return state, [], ""
# ---------------------------------------------------------
# Defining Gradio layout
with gr.Blocks() as iface_chat:
gr.Markdown("""**Let's talk to AI in a chat!**""")
with gr.Row():
model = gr.Radio(
["stabilityai/StableBeluga2", "meta-llama/Llama-2-70b-chat-hf", "bigscience/bloomz"], value="stabilityai/StableBeluga2", label="Use model"
)
        # Additional end sequences at which generation should stop
endseq = gr.CheckboxGroup(
["Human:", "AI:", "\\n", "</s>", "? (question mark)", ". (dot)"],
value=["Human:", "AI:", "</s>"],
label="Extra end sequences",
)
        # Maximum length of the inference session, in tokens
max_length = gr.Radio(
[64, 128, 256, 512, 1024, 2048],
value=1024,
interactive=True,
label="Max length",
)
with gr.Row():
with gr.Column():
# Switch between sampling and greedy generation
do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
context = gr.Textbox(
lines=3,
label="Initial context:",
interactive=True,
value="A Human talks to a powerful AI that follows "
"the Human's instructions.\n"
"AI is talkative, friendly, positive and provides "
"detailed answers to any question.</s>\n"
"Human: Hi!</s>\n"
"AI: How can I help you?",
)
# Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
# TODO num_beams
# Generation temperature
temperature = gr.Number(
value=0.75, precision=2, interactive=True, label="Temperature"
)
chat = gr.Chatbot(label="Chat window")
prompt = gr.Textbox(
show_label=False, label="Prompt", placeholder="Prompt Here and press Enter..."
).style(container=False)
with gr.Row():
button_generate = gr.Button("Generate")
button_reset = gr.Button("Reset session")
with gr.Accordion("Raw prompt log", open=False):
output = gr.Textbox(lines=3, show_label=False).style(container=False)
    # Per-session state shared between callbacks (model, client connection, chat history)
state = gr.State(EMPTY_STATE)
# Define button actions
inputs = [
state,
prompt,
model,
context,
output,
endseq,
max_length,
do_sample,
top_k,
top_p,
temperature,
]
outputs = [state, chat, prompt, output]
prompt.submit(generate, inputs=inputs, outputs=outputs)
button_generate.click(generate, inputs=inputs, outputs=outputs)
button_reset.click(reset, inputs=[state], outputs=[state, chat, output])
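    # Prefilled example conversations (the second is a Czech-language variant of the first)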
examples = gr.Examples(
inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
examples=[
[
"Human talks to a powerful AI that follows the Human's instructions. "
"AI is a smart, talkative, friendly, honest, helpful, harmless assistant to Human. "
"AI has instant access to an online encyclopedia containing all the facts about the world "
"and answers any question in detail. AI never says common misconceptions, "
"outdated information, lies, fiction, myths, jokes, or memes.</s>\n"
"AI: Hi! How can I help you?</s>\n",
"Could you remind me please who was Neil Armstrong?",
"stabilityai/StableBeluga2",
True,
0,
0.9,
0.75,
],
[
"Human mluví s mocnou, inteligentní a vševědoucí AI, která plní instrukce od Human. "
"AI je výřečná, přátelská, pozitivní a poskytuje detailní odpovědi na jakoukoliv otázku.</s>\n"
"Human: Ahoj!</s>\n"
"AI: Ahoj! Jak ti mohu pomoci?",
"Můžeš mi prosím připomenout, kdo byl Neil Armstrong?",
"stabilityai/StableBeluga2",
True,
0,
0.9,
0.75,
],
],
)
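# A minimal sketch for trying this tab on its own (assumption: in the full app,
# iface_chat is mounted alongside other tabs instead of launched directly).
if __name__ == "__main__":
    iface_chat.launch()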