import json
import re
import time
from datetime import datetime

import gradio as gr

import chat_client

CHAT_URL = "wss://chat.petals.dev/api/v2/generate"
# CHAT_URL = "ws://localhost:8000/api/v2/generate"

EMPTY_STATE = {
    "generate": False,
    "model": None,
    "client": None,
    "history": [],
}


def generate(state, prompt, model, context, output, *args):
    # Remember that we're inside the generating loop
    state["generate"] = True

    try:
        yield from _generate(state, prompt, model, context, output, *args)
    except (json.decoder.JSONDecodeError, BrokenPipeError):
        # Broken session, try to renew it.
        # TODO This is a bit fragile because of the recursive call...
        print("Retrying session...")
        context = output
        output = ""
        yield from generate(state, prompt, model, context, output, *args)
    finally:
        state["generate"] = False


def _generate(
    state,
    prompt,
    model,
    context,
    output,
    endseq,
    max_length,
    do_sample,
    top_k,
    top_p,
    temperature,
):
    start = time.time()
    cnt = 0  # Tokens generated

    def stats():
        # Produces inline stats on generation speed
        if cnt == 0:
            return "\u2026 | ? sec/t"
        if cnt > time.time() - start:
            items_per_sec = cnt / (time.time() - start)
            return f" | {items_per_sec:.1f} t/sec"
        sec_per_item = (time.time() - start) / cnt
        return f" | {sec_per_item:.1f} sec/t"

    eos = "</s>\n" if "bloomz" in model else "\n\n"

    if state["model"] != model and output:
        # If the connection is resumed, output is truncated in generate(),
        # so this branch executes when the user changes the model.
        context = output
        output = ""

    # Update widgets even before we get the first response
    print("prompt", prompt)
    yield state, state["history"] + [[prompt, stats()]], "", output

    if (
        state["model"] != model
        or state["client"] is None
        or not state["client"].is_session()
    ):
        try:
            state["client"] = chat_client.ModelClient(CHAT_URL)
            state["client"].open_session(model, max_length)
            state["model"] = model
        except Exception as e:
            print(datetime.now(), str(e)[-500:])
            raise gr.Error(str(e)[-500:])
    else:
        context = ""

    client = state["client"]

    context += eos
    # Fix an eventual eos token mismatch and add the eos token to context and prompt
    if "bloomz" in model:
        context = context.replace("\n\n", eos)
        prompt2 = prompt.replace("\n\n", eos) + "</s>\n"
    else:
        context = context.replace("</s>", eos)
        context = re.sub(r"\n\n+", "\n\n", context)
        prompt2 = prompt.replace("</s>", eos) + "\n\n"

    prompt2 = f"{context}Human: {prompt2}AI:"

    # Translate checkbox items to actual sequences
    seq = []
    for s in endseq:
        if s == "Human:":
            seq.append("Human:")
        elif s == "AI:":
            seq.append("AI:")
        elif s == "\\n":
            seq.append("\n")
        elif s == "</s>":
            seq.append("</s>")
        elif s == "? (question mark)":
            seq.append("?")
        elif s == ". (dot)":
            seq.append(".")

    # Only one of top_k and top_p can be set
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None
    if temperature == 0:
        temperature = 1.0

    output += prompt2
    orig_history = state["history"]
    new_line = ""
    try:
        for out in client.generate(
            prompt2,
            max_new_tokens=1,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            stop_sequences=seq,
        ):
            if not state["generate"]:
                client.close_session()
                yield state, [], "", ""
                # Stopping generation
                return

            cnt += 1
            new_line += out

            # Detect end sequences and finish the generation
            # prematurely if one is found.
            for s in seq:
                spl = new_line.split(s)
                new_line = spl[0]
                if len(spl) > 1:
                    state["history"] = orig_history + [[prompt, new_line]]
                    output += new_line
                    yield state, state["history"], "", output
                    # Stopping generation
                    return

            # Keep the original history untouched, as we're adding
            # just one chunk at a time.
            state["history"] = orig_history + [[prompt, new_line + stats()]]
            yield state, state["history"], "", output

            # Stop early so generate() doesn't throw an exception,
            # which would cause UI errors.
            if cnt >= max_length - 6:  # FIXME Magic constant
                break

        # Final line w/o statistics
        yield state, state["history"], "", output

    except (json.decoder.JSONDecodeError, BrokenPipeError):
        # Session was interrupted.
        # Handled in the upstream generate() function.
        client.close_session()
        state["client"] = None
        state["model"] = None

        print("Broken session!")
        raise
    except Exception as e:
        client.close_session()
        state["client"] = None
        state["model"] = None

        print(datetime.now(), str(e)[-500:])
        raise gr.Error(str(e)[-500:])


def reset(state):
    """Resets the session and clears the chat window."""
    state.update(EMPTY_STATE)
    return state, [], ""


# ---------------------------------------------------------
# Define the Gradio layout
with gr.Blocks() as iface_chat:
    gr.Markdown("""**Let's talk to an AI in a chat!**""")

    with gr.Row():
        model = gr.Radio(
            ["stabilityai/StableBeluga2", "meta-llama/Llama-2-70b-chat-hf", "bigscience/bloomz"],
            value="stabilityai/StableBeluga2",
            label="Use model",
        )

        # Additional end sequences at which generation should stop
        endseq = gr.CheckboxGroup(
            ["Human:", "AI:", "\\n", "</s>", "? (question mark)", ". (dot)"],
            value=["Human:", "AI:", "</s>"],
            label="Extra end sequences",
        )

        # Maximum length of the inference session
        max_length = gr.Radio(
            [64, 128, 256, 512, 1024, 2048],
            value=1024,
            interactive=True,
            label="Max length",
        )

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")

            context = gr.Textbox(
                lines=3,
                label="Initial context:",
                interactive=True,
                value="A Human talks to a powerful AI that follows "
                "the Human's instructions.\n"
                "AI is talkative, friendly, positive and provides "
                "detailed answers to any question.\n"
                "Human: Hi!\n"
                "AI: How can I help you?",
            )

        # Only one of top_k and top_p can be set. Requires do_sample=True to work.
        top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
        top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
        # TODO num_beams

        # Generation temperature
        temperature = gr.Number(
            value=0.75, precision=2, interactive=True, label="Temperature"
        )

    chat = gr.Chatbot(label="Chat window")
    prompt = gr.Textbox(
        show_label=False, label="Prompt", placeholder="Type a prompt here and press Enter..."
    ).style(container=False)

    with gr.Row():
        button_generate = gr.Button("Generate")
        button_reset = gr.Button("Reset session")

    with gr.Accordion("Raw prompt log", open=False):
        output = gr.Textbox(lines=3, show_label=False).style(container=False)

    # Chat history
    state = gr.State(EMPTY_STATE)

    # Define button actions
    inputs = [
        state,
        prompt,
        model,
        context,
        output,
        endseq,
        max_length,
        do_sample,
        top_k,
        top_p,
        temperature,
    ]
    outputs = [state, chat, prompt, output]

    prompt.submit(generate, inputs=inputs, outputs=outputs)
    button_generate.click(generate, inputs=inputs, outputs=outputs)
    button_reset.click(reset, inputs=[state], outputs=[state, chat, output])

    examples = gr.Examples(
        inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
        examples=[
            [
                "Human talks to a powerful AI that follows the Human's instructions. "
                "AI is a smart, talkative, friendly, honest, helpful, harmless assistant to Human. "
                "AI has instant access to an online encyclopedia containing all the facts about the world "
                "and answers any question in detail. AI never says common misconceptions, "
                "outdated information, lies, fiction, myths, jokes, or memes.\n"
                "AI: Hi! How can I help you?\n",
                "Could you remind me please who was Neil Armstrong?",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
            ],
            # The same kind of example in Czech (assistant setup + the Neil Armstrong question)
            [
                "Human mluví s mocnou, inteligentní a vševědoucí AI, která plní instrukce od Human. "
                "AI je výřečná, přátelská, pozitivní a poskytuje detailní odpovědi na jakoukoliv otázku.\n"
                "Human: Ahoj!\n"
                "AI: Ahoj! Jak ti mohu pomoci?",
                "Můžeš mi prosím připomenout, kdo byl Neil Armstrong?",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
            ],
        ],
    )
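

# ---------------------------------------------------------
# For reference only: a minimal sketch of the interface this app assumes
# chat_client.ModelClient provides, derived purely from the calls made above.
# The real implementation lives in chat_client.py and speaks the websocket
# protocol behind CHAT_URL; the argument names and keyword-only style below
# are assumptions, not the authoritative signature. The Protocol is purely
# descriptive and unused at runtime.
from typing import Iterable, List, Optional, Protocol


class _ModelClientInterface(Protocol):
    def open_session(self, model: str, max_length: int) -> None:
        """Starts an inference session for the chosen model."""

    def is_session(self) -> bool:
        """Returns True while the current session is still usable."""

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int,
        do_sample: bool,
        temperature: float,
        top_k: Optional[int],
        top_p: Optional[float],
        stop_sequences: List[str],
    ) -> Iterable[str]:
        """Streams generated text chunk by chunk."""

    def close_session(self) -> None:
        """Tears the inference session down."""


# Construction (chat_client.ModelClient(CHAT_URL)) is done above in _generate();
# instantiation signatures are not part of a Protocol check, so it is omitted here.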