slush0 committed
Commit 3d3362f · 1 Parent(s): 68d1b54

Initial commit

Files changed (3)
  1. app.py +123 -0
  2. chat_client.py +78 -0
  3. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,123 @@
+ #!/usr/bin/env python
+ # or run with: gradio app.py
+
+ import gradio as gr
+
+ import chat_client
+
+ CHAT_URL = 'ws://chat.petals.ml/api/v2/generate'
+ # CHAT_URL = 'ws://localhost:8000/api/v2/generate'
+
+
+ def generate(prompt, model, endseq, max_length,
+              do_sample, top_k, top_p, temperature,
+              add_stoptoken, copy_output):
+
+     client = chat_client.ModelClient(CHAT_URL)
+     client.open_session(f"bigscience/{model}-petals", max_length)
+
+     if add_stoptoken:
+         prompt += "</s>" if "bloomz" in model else "\n\n"
+
+     # Translate checkbox items to the actual stop sequences
+     seq = []
+     for s in endseq:
+         if s == "\\n":
+             seq.append("\n")
+         elif s == "</s>":
+             seq.append("</s>")
+         elif s == "? (question mark)":
+             seq.append("?")
+         elif s == ". (dot)":
+             seq.append(".")
+
+     # Only one of top_k and top_p may be set
+     if top_k == 0:
+         top_k = None
+     if top_p == 0:
+         top_p = None
+     if top_p and top_k:
+         top_k = None
+
+     prompt2 = prompt
+     output = ''
+
+     # Render the prompt dialog immediately instead of waiting
+     # for the generator to return its first result
+     yield [prompt2, output]
+
+     for out in client.generate(prompt,
+                                max_new_tokens=1,
+                                do_sample=do_sample,
+                                temperature=temperature,
+                                top_k=top_k,
+                                top_p=top_p,
+                                extra_stop_sequences=seq):
+
+         output += out
+         if copy_output:
+             prompt2 += out
+
+         yield [prompt2, output]
+
+
+ with gr.Blocks() as iface:
+     gr.Markdown("""# Petals playground
+ **Let's play with prompts and inference settings for the BLOOM and BLOOMZ 176B models! This space uses the websocket API of [chat.petals.ml](https://chat.petals.ml).**
+
+ Do NOT talk to BLOOM as an entity; it's not a chatbot but a webpage/blog/article completion model.
+ For the best results: mimic a few sentences of a webpage similar to the content you want to generate.
+
+ BLOOMZ performs better in chat mode and understands instructions better.""")
+
+     with gr.Row():
+         model = gr.Radio(["bloom", "bloomz", "bloom-7b1"], value='bloom', label="Use model")
+
+         # Additional end sequences at which generation should stop
+         endseq = gr.CheckboxGroup(["\\n", "</s>", "? (question mark)", ". (dot)"],
+                                   value=["\\n", "</s>"], label='Extra end sequences')
+
+         # Maximum length of the inference session
+         max_length = gr.Radio([128, 256, 512, 1024, 2048], value=512, interactive=True, label="Max length")
+
+     with gr.Row():
+         with gr.Column():
+             # Switch between sampling and greedy generation
+             do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
+
+             # Should the app append the stop sequence to the prompt,
+             # or leave the prompt open-ended?
+             add_stoptoken = gr.Checkbox(value=True, interactive=True, label="Automatically add stop token to prompt.")
+
+         # Only one of top_k and top_p can be set. Requires do_sample=True to work.
+         top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
+         top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
+
+         # Generation temperature
+         temperature = gr.Number(value=0.75, precision=2, interactive=True, label="Temperature")
+
+     prompt = gr.Textbox(lines=2, label='Prompt', placeholder="Prompt Here...")
+
+     with gr.Row():
+         button_generate = gr.Button("Generate")
+         button_stop = gr.Button("Stop")  # TODO: not supported by the websocket API yet.
+
+     # Automatically copy the output to the end of the prompt
+     copy_output = gr.Checkbox(label="Output -> Prompt")
+
+     output = gr.Textbox(lines=3, label='Output')
+
+     button_generate.click(generate, inputs=[prompt, model, endseq,
+                           max_length, do_sample, top_k, top_p, temperature,
+                           add_stoptoken, copy_output], outputs=[prompt, output])
+
+     examples = gr.Examples(inputs=[prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken],
+                            examples=[
+                                ["The SQL command to extract all the users whose name starts with A is: ", "bloom", False, 0, 0, 1, False],
+                                ["The Spanish translation of thank you for your help is: ", "bloom", False, 0, 0, 1, False],
+                                ["A human talks to a powerful AI that follows the human's instructions "
+                                 "and writes exhaustive, very detailed answers.</s>\n"
+                                 "Human: Hi!</s>\n"
+                                 "AI: Hi! How can I help you?</s>\n"
+                                 "Human: What's the capital of Portugal?</s>\n"
+                                 "AI: ", "bloomz", True, 0, 0.9, 0.75, False]
+                            ])
+
+ iface.queue()
+ iface.launch()
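As a rough illustration of how the streaming generate() above behaves outside the Gradio UI, a minimal driver might look like the sketch below. The prompt and parameter values are made up for the example, and a reachable Petals chat endpoint is assumed.

# Minimal sketch: consume the streaming generator outside the Gradio UI.
# Prompt and parameter values are illustrative only, not app defaults.
for prompt_so_far, output_so_far in generate(
        prompt="The Spanish translation of thank you is: ",
        model="bloom",
        endseq=["\\n"],      # translated to "\n" inside generate()
        max_length=128,
        do_sample=True,
        top_k=0,             # 0 means "unset" in this app
        top_p=0.9,
        temperature=0.75,
        add_stoptoken=True,
        copy_output=False):
    print(output_so_far)   # each yield carries the cumulative output so far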
chat_client.py ADDED
@@ -0,0 +1,78 @@
+ #!/usr/bin/env python
+ import json
+ import sys
+
+ # pip install websocket-client
+ import websocket
+
+
+ class ModelClient(object):
+     def __init__(self, endpoint_url):
+         self.endpoint_url = endpoint_url
+         self.ws = None
+         self.model = None
+
+     def open_session(self, model, max_length):
+         self.ws = websocket.create_connection(self.endpoint_url)
+         self.model = model
+         payload = {
+             "type": "open_inference_session",
+             "model": self.model,
+             "max_length": max_length,
+         }
+         self.ws.send(json.dumps(payload))
+         assert json.loads(self.ws.recv())['ok']
+
+     def close_session(self):
+         if self.ws:
+             self.ws.close()
+
+     def generate(self, prompt, **kwargs):
+         payload = {
+             "type": "generate",
+             "inputs": prompt,
+             "max_new_tokens": 1,
+             "do_sample": 0,
+             "temperature": 0,
+             "stop_sequence": "</s>" if "bloomz" in self.model else "\n\n",
+         }
+         payload = {**payload, **kwargs}
+         self.ws.send(json.dumps(payload))
+
+         while True:
+             try:
+                 data = json.loads(self.ws.recv())
+             except json.decoder.JSONDecodeError:
+                 self.close_session()
+                 raise
+
+             if not data['ok']:
+                 raise Exception(data['traceback'])
+             yield data['outputs']
+             if data['stop']:
+                 break
+
+
+ def main():
+     client = ModelClient("ws://localhost:8000/api/v2/generate")
+     # client = ModelClient("ws://chat.petals.ml/api/v2/generate")
+     client.open_session("bigscience/bloom-petals", 128)
+
+     if len(sys.argv) > 1:
+         prompt = sys.argv[1]
+         # The bloomz variant uses </s> instead of \n\n as its EOS token
+         if not prompt.endswith("\n\n"):
+             prompt += "\n\n"
+     else:
+         prompt = "The SQL command to extract all the users whose name starts with A is: \n\n"
+     print(f"Prompt: {prompt}")
+
+     # May raise petals.client.routing.sequence_manager.MissingBlocksError
+     for out in client.generate(prompt,
+                                do_sample=True,
+                                temperature=0.75,
+                                top_p=0.9):
+         print(out, end="", flush=True)
+
+     client.close_session()
+
+
+ if __name__ == '__main__':
+     main()
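For orientation, the wire protocol implied by ModelClient can be sketched as below. Field names are taken from the code above; the exact shape of the server's replies beyond the 'ok', 'outputs', 'stop', and 'traceback' keys the client reads is an assumption.

# Sketch of the websocket exchange implied by ModelClient above.
import json

# 1. Client -> server: open an inference session.
open_msg = json.dumps({
    "type": "open_inference_session",
    "model": "bigscience/bloom-petals",
    "max_length": 128,
})
# Server -> client: {"ok": true}

# 2. Client -> server: request generation (defaults as set by ModelClient).
generate_msg = json.dumps({
    "type": "generate",
    "inputs": "A cat sat on",
    "max_new_tokens": 1,
    "do_sample": 0,
    "temperature": 0,
    "stop_sequence": "\n\n",
})
# Server -> client: a stream of chunks such as
#   {"ok": true, "outputs": " a", "stop": false}
# until a chunk with "stop": true arrives (or "ok": false with a "traceback").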
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ websocket-client
+ gradio
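With just these two dependencies, the Space should run locally with `pip install -r requirements.txt` followed by `python app.py` (or `gradio app.py`, per the comment at the top of app.py), assuming the websocket endpoint in CHAT_URL is reachable.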