import gradio as gr
import threading
import codecs
from datetime import datetime
from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch
import time

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]

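# Lazily-populated cache of (tokenizer, model) pairs per model name, plus per-model output
# buffers shared with the generation threads. The Event is set by the Stop button to ask
# running generation threads to exit.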
models = {MODEL_NAMES[0]: None, MODEL_NAMES[1]: None}
output = {MODEL_NAMES[0]: "", MODEL_NAMES[1]: ""}
kill = threading.Event()

def stop_threads():
    global kill
    print("Force stopping threads")
    kill.set()

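# Generation worker: one thread per model, generating a single token per step inside a Petals
# inference session so partial output can be streamed to the UI. It stops when a stop word
# appears in the output, the token budget is reached, or the kill event is set.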
def gen_thread(model_name, prompt, max_tokens, temperature, top_p, repetition_penalty, stop):
    global output
    
    if kill.is_set():
        return
    
    token_cnt = 0
    with models[model_name][1].inference_session(max_length=512) as sess:
        print(f"Thread Start -> {threading.get_ident()}")
        output[model_name] = ""
        inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        n_input_tokens = inputs.shape[1]
        done = False
        while not done and not kill.is_set():
            outputs = models[model_name][1].generate(
                inputs, 
                max_new_tokens=1, 
                do_sample=True, 
                top_p=top_p, 
                temperature=temperature, 
                repetition_penalty=repetition_penalty,
                session=sess
            )
            output[model_name] += models[model_name][0].decode(outputs[0, n_input_tokens:])
            token_cnt += 1
            print("\n["+ str(threading.get_ident()) + "]" + output[model_name], end="", flush=True)

            for stop_word in stop:
                stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
                if stop_word != '' and stop_word in output[model_name]:
                    print(f"\nDONE (stop) -> {threading.get_ident()}")
                    done = True
            if token_cnt >= max_tokens:
                print(f"\nDONE (max tokens) -> {threading.get_ident()}")
                done = True
            inputs = None  # Prefix is passed only for the 1st token of the bot's response
            n_input_tokens = 0
        print(f"\nThread End -> {threading.get_ident()}")

def to_md(text):
    return text.replace("\n", "<br />")

threads = list()

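# Gradio event handler: lazily loads the selected models, starts one generation thread per
# selected model, then polls the threads and yields the partial outputs so both textboxes
# update while generation is still in progress.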
def infer(
        prompt,
        model_idx = ["BLOOM","BLOOMZ"],
        max_new_tokens=10,
        temperature=0.1,
        top_p=1.0,
        repetition_penalty = 1.0,
        stop="\n",
        num_completions=1,
        seed=42,
):
    global threads
    global output
    global models

    if len(model_idx) == 0:
        return

    kill.clear()
    threads = []
    print("Loading Models\n")
    for idx in model_idx:
        model_name = MODEL_NAMES[idx]
        if models[model_name] is None:
            print("Initializing " + model_name)
            tokenizer = BloomTokenizerFast.from_pretrained(model_name)
            model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
            model = model.to(DEVICE)
            models[model_name] = tokenizer, model
            output[model_name] = ""

    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    stop = [x.strip(' ') for x in stop.split(',')]
    repetition_penalty = float(repetition_penalty)
    # seed parameter is currently unused

    assert 1 <= max_new_tokens <= 384
    assert 1 <= num_completions <= 5
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0
    assert 0.9 <= repetition_penalty <= 3.0

    if temperature == 0.0:
        temperature = 0.01
    if prompt == "":
        prompt = " "
    
    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")
    for idx in model_idx:
        model_name = MODEL_NAMES[idx]
        x = threading.Thread(target=gen_thread, args=(model_name, prompt, max_new_tokens, temperature, top_p, repetition_penalty, stop))
        threads.append(x)
        x.start()
    
    # Join Threads
    for thread in threads:
        while thread.is_alive():
            thread.join(timeout=0.2)
            yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]


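# Example rows follow the order of the gr.Examples inputs below:
# prompt, model checkboxes, max tokens, temperature, top p, repetition penalty, stop tokens.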
examples = [
    [
        # Question Answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''',["BLOOM","BLOOMZ"] , 3, 0.2, 1.0, 1.0, "\\n,</s>", ["BLOOM","BLOOMZ"]],
    [
        # Natural Language Interface
        '''Given a pair of sentences, choose whether the two sentences agree (entailment)/disagree (contradiction) with each other.
Possible labels: 1. entailment 2. contradiction
Sentence 1: The skier was on the edge of the ramp. Sentence 2: The skier was dressed in winter clothes.
Label: entailment
Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
Label: contradiction
Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
Label:''',["BLOOM","BLOOMZ"] , 2, 0.2, 1.0, 1.0, "\\n,</s>"]
]

def clear_prompt():
    return "","",""

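# UI: prompt and sampling controls on the left, one output textbox per model on the right,
# with Clear/Run/Stop buttons and canned example prompts below.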
with gr.Blocks() as demo:
    gr.Markdown("# <p style='text-align: center;'>BLOOM vs BLOOMZ Comparison</p>")
    gr.Markdown("")
    gr.Markdown("Test Inference on the [BLOOM](https://huggingface.co/bigscience/bloom) and [BLOOMZ](https://huggingface.co/bigscience/bloomz) 176 Billion Parameter models using Petals.  \
        Please consider contributing your unused GPU cycles to the [Petals Swarm](https://github.com/bigscience-workshop/petals) to speed up inference. <br />\n \
        Due to heavy resource requirements of these large models, token generation can take upwards of 3-5 seconds per token. Try to keep Max Tokens to a minimum.")
    gr.Markdown("")
    gr.Markdown("Special thanks to [RFT Capital](https://www.rftcapital.com/) for supporting our experiments with compute time dontations.")
    gr.Markdown("Type a Prompt and then click **Run** to see the output.")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=17,label="Prompt",placeholder="Enter Prompt", interactive=True)
            with gr.Box():
                chk_boxes = gr.CheckboxGroup(choices=["BLOOM","BLOOMZ"],value=["BLOOM","BLOOMZ"], type="index", label="Model")
                #min_length = gr.Slider(minimum=0, maximum=256, value=1, label="Minimum Length") #min_length
                max_tokens = gr.Slider(minimum=1, maximum=256, value=15, label="Max Tokens")  # max_tokens
                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.2, label="Temperature")  # temperature
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.9, label="Top P")  # top_p
                rep_penalty = gr.Slider(minimum=0.9, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty")  # repetition penalty
                stop = gr.Textbox(lines=1, value="\\n,</s>", label="Stop Token") # stop
        with gr.Column():
            bloom_out = gr.Textbox(lines=7, label="BLOOM OUTPUT:") 
            bloomz_out = gr.Textbox(lines=7,label="BLOOMZ OUTPUT:")
    with gr.Row():
        btn_clear = gr.Button("Clear", variant="secondary")
        btn_run = gr.Button("Run", variant="primary")
        btn_stop = gr.Button("Stop", variant="stop")
        click_run = btn_run.click(infer, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop], outputs=[bloom_out,bloomz_out])   
        btn_clear.click(clear_prompt, outputs=[prompt, bloom_out, bloomz_out])
        btn_stop.click(stop_threads,cancels=click_run)
    gr.Examples(examples, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop])

demo.queue(concurrency_count=1)
demo.launch()