WangZeJun committed
Commit abce01a • 1 Parent(s): d53c868

Upload app.py

Files changed (1)
  1. app.py +270 -0

app.py ADDED
@@ -0,0 +1,270 @@
# Copyright 2023 MosaicML spaces authors
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import datetime
import os
from threading import Event, Thread
from uuid import uuid4

import gradio as gr
import requests
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)


model_name = "WangZeJun/bloom-3b-moss-chat"
max_new_tokens = 1024


print(f"Starting to load the model {model_name} into memory")

tok = AutoTokenizer.from_pretrained(model_name)
m = AutoModelForCausalLM.from_pretrained(model_name).eval()

# tok.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
stop_token_ids = [tok.eos_token_id]
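# Generation stops at the tokenizer's EOS token. The commented line above
# shows how additional stop strings such as "<|im_end|>" could be mapped to
# ids and appended to stop_token_ids if the tokenizer defined them.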

print(f"Successfully loaded the model {model_name} into memory")


class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
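
# model.generate() calls this criterion after every decoding step; returning
# True halts generation. Only the first sequence in the batch (input_ids[0])
# is checked, which suffices here since the app generates one reply at a time.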

def convert_history_to_text(history):
    user_input = history[-1][0]

    input_pattern = "{}</s>"
    text = input_pattern.format(user_input)
    return text
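# Note: only the latest user turn is sent to the model; earlier entries in
# `history` are dropped, so each reply is effectively single-turn. Illustrative
# example: convert_history_to_text([["Hello", ""]]) returns "Hello</s>".
# "</s>" is BLOOM's EOS token, presumably matching the chat fine-tuning format.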


def log_conversation(conversation_id, history, messages, generate_kwargs):
    logging_url = os.getenv("LOGGING_URL", None)
    if logging_url is None:
        return

    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

    data = {
        "conversation_id": conversation_id,
        "timestamp": timestamp,
        "history": history,
        "messages": messages,
        "generate_kwargs": generate_kwargs,
    }

    try:
        requests.post(logging_url, json=data)
    except requests.exceptions.RequestException as e:
        print(f"Error logging conversation: {e}")
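# Logging is opt-in: nothing is sent unless the LOGGING_URL environment
# variable is set, and request failures are caught so they cannot break chat.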


def user(message, history):
    # Clear the textbox and append the user's message to the conversation history
    return "", history + [[message, ""]]


def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    print(f"history: {history}")
    # Initialize a StopOnTokens object
    stop = StopOnTokens()

    # Build the prompt string for the model from the latest user message
    messages = convert_history_to_text(history)

    # Tokenize the prompt
    input_ids = tok(messages, return_tensors="pt").input_ids
    input_ids = input_ids.to(m.device)
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([stop]),
    )
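    # Sampling is enabled only when temperature > 0; at the slider minimum of
    # 0.0, generation falls back to greedy decoding (do_sample=False). Recent
    # transformers releases may warn when sampling knobs are passed alongside
    # do_sample=False.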

    stream_complete = Event()

    def generate_and_signal_complete():
        m.generate(**generate_kwargs)
        stream_complete.set()

    def log_after_stream_complete():
        stream_complete.wait()
        log_conversation(
            conversation_id,
            history,
            messages,
            {
                "top_k": top_k,
                "top_p": top_p,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty,
            },
        )

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    t2 = Thread(target=log_after_stream_complete)
    t2.start()
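
    # Generation runs in t1 while this handler consumes the streamer below,
    # yielding each partial string so Gradio can update the chat live. If no
    # new token arrives within timeout=10.0 seconds, the streamer raises
    # queue.Empty and the stream is aborted.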
    # Initialize an empty string to store the generated text
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history


def get_uuid():
    return str(uuid4())


with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    conversation_id = gr.State(get_uuid)
    gr.Markdown(
        """
        An AI assistant based on bloom-3b-moss-chat
        Model: https://huggingface.co/WangZeJun/bloom-3b-moss-chat
        """
    )
    chatbot = gr.Chatbot().style(height=500)
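    # Note: .style(...) is the legacy Gradio 3.x styling API; Gradio 4 moved
    # these options into constructor arguments (e.g. gr.Chatbot(height=500)).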
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Chat Message Box",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")
    with gr.Row():
        with gr.Accordion("Advanced Options:", open=False):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        temperature = gr.Slider(
                            label="Temperature",
                            value=0.1,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                with gr.Column():
                    with gr.Row():
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=1.0,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            interactive=True,
                            info=(
                                "Sample from the smallest possible set of tokens whose cumulative probability "
                                "exceeds top_p. Set to 1 to disable and sample from all tokens."
                            ),
                        )
                with gr.Column():
                    with gr.Row():
                        top_k = gr.Slider(
                            label="Top-k",
                            value=0,
                            minimum=0,
                            maximum=200,
                            step=1,
                            interactive=True,
                            info="Sample from a shortlist of top-k tokens; set to 0 to disable and sample from all tokens.",
                        )
                with gr.Column():
                    with gr.Row():
                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            value=1.2,
                            minimum=1.0,
                            maximum=2.0,
                            step=0.1,
                            interactive=True,
                            info="Penalize repetition; set to 1.0 to disable.",
                        )
    # with gr.Row():
    #     gr.Markdown(
    #         "demo 2",
    #         elem_classes=["disclaimer"],
    #     )

    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            conversation_id,
        ],
        outputs=chatbot,
        queue=True,
    )
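    # Both the Enter key (msg.submit) and the Submit button use this two-step
    # chain: user() runs unqueued so the message appears immediately, then
    # bot() streams the reply through the queue.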
    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            conversation_id,
        ],
        outputs=chatbot,
        queue=True,
    )
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue(max_size=128, concurrency_count=2)
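# Note: concurrency_count is the Gradio 3.x queue API; Gradio 4 replaced it
# with default_concurrency_limit.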
demo.launch()