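# Gradio Space that runs the same prompt through the 176B-parameter BLOOM and
# BLOOMZ models over the Petals swarm and streams both completions side by
# side, one generation thread per selected model.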
import gradio as gr
import threading
import codecs
from datetime import datetime
from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch
import time

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]

models = {MODEL_NAMES[0]: None, MODEL_NAMES[1]: None}
output = {MODEL_NAMES[0]: "", MODEL_NAMES[1]: ""}

kill = threading.Event()
def stop_threads():
    global kill
    print("Force stopping threads")
    kill.set()
def gen_thread(model_name, prompt, max_tokens, temperature, top_p, repetition_penalty, stop):
    """Generate tokens one at a time for a single model and stream them into `output`."""
    global output

    if kill.is_set():
        return

    token_cnt = 0
    with models[model_name][1].inference_session(max_length=512) as sess:
        print(f"Thread Start -> {threading.get_ident()}")
        output[model_name] = ""
        inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        n_input_tokens = inputs.shape[1]

        done = False
        while not done and not kill.is_set():
            outputs = models[model_name][1].generate(
                inputs,
                max_new_tokens=1,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                session=sess
            )
            output[model_name] += models[model_name][0].decode(outputs[0, n_input_tokens:])
            token_cnt += 1
            print("\n[" + str(threading.get_ident()) + "]" + output[model_name], end="", flush=True)

            # Stop early if any user-supplied stop word appears in the output
            for stop_word in stop:
                stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
                if stop_word != '' and stop_word in output[model_name]:
                    print(f"\nDONE (stop) -> {threading.get_ident()}")
                    done = True

            if token_cnt >= max_tokens:
                print(f"\nDONE (max tokens) -> {threading.get_ident()}")
                done = True

            inputs = None  # The prefix is passed only for the 1st token of the bot's response
            n_input_tokens = 0

    print(f"\nThread End -> {threading.get_ident()}")
def to_md(text):
    return text.replace("\n", "<br />")


threads = list()
def infer(
    prompt,
    model_idx=["BLOOM", "BLOOMZ"],
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    repetition_penalty=1.0,
    stop="\n",
    num_completions=1,
    seed=42,
):
    global threads
    global output
    global models

    if len(model_idx) == 0:
        return

    kill.clear()
    threads = []  # Reset the per-request thread list

    print("Loading Models\n")
    for idx in model_idx:
        model_name = MODEL_NAMES[idx]
        if models[model_name] is None:
            print("Initializing " + model_name)
            tokenizer = BloomTokenizerFast.from_pretrained(model_name)
            model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
            model = model.to(DEVICE)
            models[model_name] = tokenizer, model
        output[model_name] = ""

    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    stop = [x.strip(' ') for x in stop.split(',')]

    assert 1 <= max_new_tokens <= 384
    assert 1 <= num_completions <= 5
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0
    assert 0.9 <= repetition_penalty <= 3.0

    if temperature == 0.0:
        temperature = 0.01
    if prompt == "":
        prompt = " "

    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")

    # One generation thread per selected model
    for idx in model_idx:
        model_name = MODEL_NAMES[idx]
        x = threading.Thread(
            target=gen_thread,
            args=(model_name, prompt, max_new_tokens, temperature, top_p, repetition_penalty, stop),
        )
        threads.append(x)
        x.start()

    # Join the threads, streaming partial output to the UI while they run
    for thread in threads:
        while thread.is_alive():
            thread.join(timeout=0.2)
            yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]

    yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
examples = [
    [
        # Question Answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''', ["BLOOM", "BLOOMZ"], 3, 0.2, 1.0, 1.0, "\\n,</s>"],
    [
        # Natural Language Inference
        '''Given a pair of sentences, choose whether the two sentences agree (entailment)/disagree (contradiction) with each other.
Possible labels: 1. entailment 2. contradiction
Sentence 1: The skier was on the edge of the ramp. Sentence 2: The skier was dressed in winter clothes.
Label: entailment
Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
Label: contradiction
Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
Label:''', ["BLOOM", "BLOOMZ"], 2, 0.2, 1.0, 1.0, "\\n,</s>"]
]
def clear_prompt():
    return "", "", ""
with gr.Blocks() as demo:
    gr.Markdown("# <p style='text-align: center;'>BLOOM vs BLOOMZ Comparison</p>")
    gr.Markdown("")
    gr.Markdown("Test inference on the [BLOOM](https://huggingface.co/bigscience/bloom) and [BLOOMZ](https://huggingface.co/bigscience/bloomz) 176-billion-parameter models using Petals. \
                Please consider contributing your unused GPU cycles to the [Petals Swarm](https://github.com/bigscience-workshop/petals) to speed up inference. <br />\n \
                Due to the heavy resource requirements of these large models, token generation can take upwards of 3-5 seconds per token. Try to keep Max Tokens to a minimum.")
    gr.Markdown("")
    gr.Markdown("Special thanks to [RFT Capital](https://www.rftcapital.com/) for supporting our experiments with compute-time donations.")
    gr.Markdown("Type a prompt and then click **Run** to see the output.")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=17, label="Prompt", placeholder="Enter Prompt", interactive=True)
            with gr.Box():
                chk_boxes = gr.CheckboxGroup(choices=["BLOOM", "BLOOMZ"], value=["BLOOM", "BLOOMZ"], type="index", label="Model")
                # min_length = gr.Slider(minimum=0, maximum=256, value=1, label="Minimum Length")
                max_tokens = gr.Slider(minimum=1, maximum=256, value=15, label="Max Tokens")
                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.2, label="Temperature")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.9, label="Top P")
                rep_penalty = gr.Slider(minimum=0.9, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty")
                stop = gr.Textbox(lines=1, value="\\n,</s>", label="Stop Tokens (comma-separated)")
        with gr.Column():
            bloom_out = gr.Textbox(lines=7, label="BLOOM OUTPUT:")
            bloomz_out = gr.Textbox(lines=7, label="BLOOMZ OUTPUT:")

    with gr.Row():
        btn_clear = gr.Button("Clear", variant="secondary")
        btn_run = gr.Button("Run", variant="primary")
        btn_stop = gr.Button("Stop", variant="stop")

    click_run = btn_run.click(infer, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop], outputs=[bloom_out, bloomz_out])
    btn_clear.click(clear_prompt, outputs=[prompt, bloom_out, bloomz_out])
    btn_stop.click(stop_threads, cancels=click_run)

    gr.Examples(examples, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop])

demo.queue(concurrency_count=1)
demo.launch()