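# Gradio demo that runs the same prompt through BLOOM and BLOOMZ (176B) side by side.
# Generation is served by the public Petals swarm, so no local copy of the weights is needed.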
import gradio as gr
import threading
import codecs
from datetime import datetime
from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch
import time
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]
models = {MODEL_NAMES[0]: None, MODEL_NAMES[1]: None}
output = {MODEL_NAMES[0]: "", MODEL_NAMES[1]: ""}
kill = threading.Event()
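
# Generation runs in background threads; the `kill` event is the cooperative stop signal
# that the Stop button sets and every generation loop checks between tokens.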
def stop_threads():
    global kill
    print("Force stopping threads")
    kill.set()
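
# Worker that streams tokens for one model: it opens a Petals inference session,
# generates one token at a time, and appends the decoded text to the shared
# `output` dict so the main thread can yield partial results to the UI.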
def gen_thread(model_name, prompt, max_tokens, temperature, top_p, repetition_penalty, stop):
    global output
    if kill.is_set():
        return
    flag = False
    token_cnt = 0
    with models[model_name][1].inference_session(max_length=512) as sess:
        print(f"Thread Start -> {threading.get_ident()}")
        output[model_name] = ""
        inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        n_input_tokens = inputs.shape[1]
        done = False
        while not done and not kill.is_set():
            outputs = models[model_name][1].generate(
                inputs,
                max_new_tokens=1,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                session=sess
            )
            output[model_name] += models[model_name][0].decode(outputs[0, n_input_tokens:])
            token_cnt += 1
            print("\n[" + str(threading.get_ident()) + "]" + output[model_name], end="", flush=True)
            for stop_word in stop:
                stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
                if stop_word != '' and stop_word in output[model_name]:
                    print(f"\nDONE (stop) -> {threading.get_ident()}")
                    done = True
            if flag or (token_cnt >= max_tokens):
                print(f"\nDONE (max tokens) -> {threading.get_ident()}")
                done = True
            inputs = None  # Prefix is passed only for the 1st token of the bot's response
            n_input_tokens = 0
    print(f"\nThread End -> {threading.get_ident()}")
def to_md(text):
    return text.replace("\n", "<br />")
threads = list()
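
# Gradio callback (a generator): lazily loads each selected model on first use,
# starts one gen_thread per model, then polls the threads and yields the partial
# BLOOM / BLOOMZ outputs so the textboxes update while generation is still running.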
def infer(
    prompt,
    model_idx=["BLOOM", "BLOOMZ"],
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    repetition_penalty=1.0,
    stop="\n",
    num_completions=1,
    seed=42,
):
    global threads
    global output
    global models

    if len(model_idx) == 0:
        return
    kill.clear()

    print("Loading Models\n")
    for idx in model_idx:
        model_name = MODEL_NAMES[idx]
        if models[model_name] is None:
            print("Initializing " + model_name)
            tokenizer = BloomTokenizerFast.from_pretrained(model_name)
            model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
            model = model.to(DEVICE)
            models[model_name] = tokenizer, model
        output[model_name] = ""

    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    stop = [x.strip(' ') for x in stop.split(',')]
    # seed is accepted but not currently used; num_completions is only validated below

    assert 1 <= max_new_tokens <= 384
    assert 1 <= num_completions <= 5
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0
    assert 0.9 <= repetition_penalty <= 3.0

    if temperature == 0.0:
        temperature = 0.01
    if prompt == "":
        prompt = " "

    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")
    for idx in model_idx:
        model_name = MODEL_NAMES[idx]
        x = threading.Thread(target=gen_thread, args=(model_name, prompt, max_new_tokens, temperature, top_p, repetition_penalty, stop))
        threads.append(x)
        x.start()

    # Poll the worker threads, streaming partial outputs back to the UI until all are done
    for thread in threads:
        while thread.is_alive():
            thread.join(timeout=0.2)
            yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
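
# Example rows follow the same order as the gr.Examples inputs below:
# prompt, model checkboxes, max tokens, temperature, top_p, repetition penalty, stop tokens.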
examples = [
    [
        # Question Answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''', ["BLOOM", "BLOOMZ"], 3, 0.2, 1.0, 1.0, "\\n,</s>"],
    [
        # Natural Language Inference
        '''Given a pair of sentences, choose whether the two sentences agree (entailment)/disagree (contradiction) with each other.
Possible labels: 1. entailment 2. contradiction
Sentence 1: The skier was on the edge of the ramp. Sentence 2: The skier was dressed in winter clothes.
Label: entailment
Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
Label: contradiction
Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
Label:''', ["BLOOM", "BLOOMZ"], 2, 0.2, 1.0, 1.0, "\\n,</s>"]
]
def clear_prompt():
    return "", "", ""
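
# UI: prompt and sampling controls on the left, one output box per model on the right.
# Run starts the streaming infer generator, Stop sets the kill event and cancels the
# pending click event, and Clear empties the prompt and both output boxes.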
with gr.Blocks() as demo:
    gr.Markdown("# <p style='text-align: center;'>BLOOM vs BLOOMZ Comparison</p>")
    gr.Markdown("")
    gr.Markdown("Test inference on the [BLOOM](https://huggingface.co/bigscience/bloom) and [BLOOMZ](https://huggingface.co/bigscience/bloomz) 176-billion-parameter models using Petals. \
                Please consider contributing your unused GPU cycles to the [Petals swarm](https://github.com/bigscience-workshop/petals) to speed up inference. <br />\n \
                Because of the heavy resource requirements of these large models, generation can take 3-5 seconds or more per token, so try to keep Max Tokens to a minimum.")
    gr.Markdown("")
    gr.Markdown("Special thanks to [RFT Capital](https://www.rftcapital.com/) for supporting our experiments with compute time donations.")
    gr.Markdown("Type a prompt and then click **Run** to see the output.")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=17, label="Prompt", placeholder="Enter Prompt", interactive=True)
            with gr.Box():
                chk_boxes = gr.CheckboxGroup(choices=["BLOOM", "BLOOMZ"], value=["BLOOM", "BLOOMZ"], type="index", label="Model")
                # min_length = gr.Slider(minimum=0, maximum=256, value=1, label="Minimum Length")
                max_tokens = gr.Slider(minimum=1, maximum=256, value=15, label="Max Tokens")
                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.2, label="Temperature")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.9, label="Top P")
                rep_penalty = gr.Slider(minimum=0.9, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty")
                stop = gr.Textbox(lines=1, value="\\n,</s>", label="Stop Tokens (comma-separated)")
        with gr.Column():
            bloom_out = gr.Textbox(lines=7, label="BLOOM OUTPUT:")
            bloomz_out = gr.Textbox(lines=7, label="BLOOMZ OUTPUT:")
            with gr.Row():
                btn_clear = gr.Button("Clear", variant="secondary")
                btn_run = gr.Button("Run", variant="primary")
                btn_stop = gr.Button("Stop", variant="stop")

    click_run = btn_run.click(infer, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop], outputs=[bloom_out, bloomz_out])
    btn_clear.click(clear_prompt, outputs=[prompt, bloom_out, bloomz_out])
    btn_stop.click(stop_threads, cancels=click_run)

    gr.Examples(examples, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop])

demo.queue(concurrency_count=1)
demo.launch()