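"""BLOOM vs BLOOMZ comparison demo.

A Gradio app that runs the same prompt through the 176B-parameter BLOOM and
BLOOMZ models over the Petals swarm, streaming both outputs side by side.
Each selected model generates in its own worker thread, one token at a time.

Run locally with `python app.py` (assumes `gradio`, `torch`, `transformers`,
and `petals` are installed).
"""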
import gradio as gr
import threading
import codecs
from datetime import datetime
from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]
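# Shared state: lazily-initialized (tokenizer, model) pairs per model name,
# the text generated so far for each model, and an Event used to
# cooperatively cancel in-flight generation from the Stop button.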
models = {MODEL_NAMES[0]: None, MODEL_NAMES[1]: None}
output = {MODEL_NAMES[0]: "", MODEL_NAMES[1]: ""}
kill = threading.Event()
def stop_threads():
    print("Force stopping threads")
    kill.set()

def gen_thread(model_name, prompt, max_tokens, temperature, top_p, repetition_penalty, stop):
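    """Worker thread: stream tokens from one model into the shared `output` dict.

    Generates one token per call inside a Petals inference session so partial
    text can be surfaced immediately, stopping on a stop word, the token
    budget, or the global `kill` event.
    """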
    global output
    if kill.is_set():
        return
    token_cnt = 0
    tokenizer, model = models[model_name]
    with model.inference_session(max_length=512) as sess:
        print(f"Thread Start -> {threading.get_ident()}")
        output[model_name] = ""
        inputs = tokenizer(prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        n_input_tokens = inputs.shape[1]
        done = False
        while not done and not kill.is_set():
            # Generate a single token per call; the session carries the
            # attention cache, so only new tokens are fed on later iterations.
            outputs = model.generate(
                inputs,
                max_new_tokens=1,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                session=sess,
            )
            output[model_name] += tokenizer.decode(outputs[0, n_input_tokens:])
            token_cnt += 1
            print("\n[" + str(threading.get_ident()) + "]" + output[model_name], end="", flush=True)
            for stop_word in stop:
                # Stop words arrive escaped (e.g. "\\n"); decode them to real characters.
                stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
                if stop_word != "" and stop_word in output[model_name]:
                    print(f"\nDONE (stop) -> {threading.get_ident()}")
                    done = True
            if token_cnt >= max_tokens:
                print(f"\nDONE (max tokens) -> {threading.get_ident()}")
                done = True
            inputs = None  # The prompt is passed only for the first token of the response
            n_input_tokens = 0
    print(f"\nThread End -> {threading.get_ident()}")

def to_md(text):
    # Convert newlines to HTML line breaks for Markdown rendering (currently unused helper).
    return text.replace("\n", "<br />")

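# Handles for the generation threads of the current run.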
threads = list()
def infer(
    prompt,
    model_idx=["BLOOM", "BLOOMZ"],
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    repetition_penalty=1.0,
    stop="\n",
    num_completions=1,
    seed=42,
):
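    """Generator wired to the Run button.

    Lazily loads the selected models, spawns one `gen_thread` per model, and
    yields (BLOOM output, BLOOMZ output) tuples until all threads finish.
    `model_idx` arrives as checkbox indices (the CheckboxGroup uses
    type="index"); `num_completions` and `seed` are accepted but unused.
    """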
    global threads
    global output
    global models

    if len(model_idx) == 0:
        return
    kill.clear()
    threads = []  # drop handles left over from any previous run

    print("Loading Models\n")
    for idx in model_idx:
        model_name = MODEL_NAMES[idx]
        if models[model_name] is None:
            print("Initializing " + model_name)
            tokenizer = BloomTokenizerFast.from_pretrained(model_name)
            model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
            model = model.to(DEVICE)
            models[model_name] = tokenizer, model
        output[model_name] = ""

    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    stop = [x.strip(" ") for x in stop.split(",")]  # "stop" arrives as a comma-separated string
    repetition_penalty = float(repetition_penalty)
    assert 1 <= max_new_tokens <= 384
    assert 1 <= num_completions <= 5
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0
    assert 0.9 <= repetition_penalty <= 3.0

    if temperature == 0.0:
        temperature = 0.01  # sampling requires a strictly positive temperature
    if prompt == "":
        prompt = " "
    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")
    for idx in model_idx:
        model_name = MODEL_NAMES[idx]
        x = threading.Thread(
            target=gen_thread,
            args=(model_name, prompt, max_new_tokens, temperature, top_p, repetition_penalty, stop),
        )
        threads.append(x)
        x.start()

    # Poll the workers, streaming partial output back to the UI every 200 ms
    for thread in threads:
        while thread.is_alive():
            thread.join(timeout=0.2)
            yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
    # One final yield so the last tokens are not dropped between polls
    yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]

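# Prompt presets for the UI; each row supplies values for the seven input
# components passed to gr.Examples below.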
examples = [
    [
        # Question answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''', ["BLOOM", "BLOOMZ"], 3, 0.2, 1.0, 1.0, "\\n,</s>"],
    [
        # Natural language inference
        '''Given a pair of sentences, choose whether the two sentences agree (entailment)/disagree (contradiction) with each other.
Possible labels: 1. entailment 2. contradiction
Sentence 1: The skier was on the edge of the ramp. Sentence 2: The skier was dressed in winter clothes.
Label: entailment
Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
Label: contradiction
Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
Label:''', ["BLOOM", "BLOOMZ"], 2, 0.2, 1.0, 1.0, "\\n,</s>"]
]
def clear_prompt():
    # One empty string per cleared component: prompt, BLOOM output, BLOOMZ output
    return "", "", ""

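# Layout: prompt and sampling controls on the left, one output box per model
# on the right, with Clear/Run/Stop controls underneath.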
with gr.Blocks() as demo:
    gr.Markdown("# <p style='text-align: center;'>BLOOM vs BLOOMZ Comparison</p>")
    gr.Markdown("")
    gr.Markdown(
        "Test inference on the [BLOOM](https://huggingface.co/bigscience/bloom) and "
        "[BLOOMZ](https://huggingface.co/bigscience/bloomz) 176-billion-parameter models using Petals. "
        "Please consider contributing your unused GPU cycles to the "
        "[Petals Swarm](https://github.com/bigscience-workshop/petals) to speed up inference. <br />\n"
        "Due to the heavy resource requirements of these large models, generation can take "
        "upwards of 3-5 seconds per token, so try to keep Max Tokens to a minimum."
    )
    gr.Markdown("")
    gr.Markdown("Special thanks to [RFT Capital](https://www.rftcapital.com/) for supporting our experiments with compute-time donations.")
    gr.Markdown("Type a prompt, then click **Run** to see the output.")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=17, label="Prompt", placeholder="Enter Prompt", interactive=True)
            with gr.Box():
                chk_boxes = gr.CheckboxGroup(choices=["BLOOM", "BLOOMZ"], value=["BLOOM", "BLOOMZ"], type="index", label="Model")
                # min_length = gr.Slider(minimum=0, maximum=256, value=1, label="Minimum Length")
                max_tokens = gr.Slider(minimum=1, maximum=256, value=15, label="Max Tokens")
                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.2, label="Temperature")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.9, label="Top P")
                rep_penalty = gr.Slider(minimum=0.9, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty")
                stop = gr.Textbox(lines=1, value="\\n,</s>", label="Stop Tokens (comma-separated)")
        with gr.Column():
            bloom_out = gr.Textbox(lines=7, label="BLOOM OUTPUT:")
            bloomz_out = gr.Textbox(lines=7, label="BLOOMZ OUTPUT:")
    with gr.Row():
        btn_clear = gr.Button("Clear", variant="secondary")
        btn_run = gr.Button("Run", variant="primary")
        btn_stop = gr.Button("Stop", variant="stop")

    # Run streams infer()'s yields into the output boxes; Stop sets the kill
    # event and cancels the pending run.
    click_run = btn_run.click(infer, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop], outputs=[bloom_out, bloomz_out])
    btn_clear.click(clear_prompt, outputs=[prompt, bloom_out, bloomz_out])
    btn_stop.click(stop_threads, cancels=click_run)

    gr.Examples(examples, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop])
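# Allow only one run at a time: the model handles and output buffers are
# module-level globals, so concurrent runs would interleave their outputs.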
demo.queue(concurrency_count=1)
demo.launch()