WangZeJun committed
Commit
13b388c
1 Parent(s): e5c2ced

Update app.py

Files changed (1): app.py (+117 -21)
app.py CHANGED
@@ -19,7 +19,6 @@ from transformers import (
 
 
 model_name = "WangZeJun/bloom-3b-moss-chat"
-max_new_tokens = 1024
 
 
 print(f"Starting to load the model {model_name} into memory")
@@ -43,14 +42,20 @@ class StopOnTokens(StoppingCriteria):
 
 
 def convert_history_to_text(history):
-
     user_input = history[-1][0]
-
     input_pattern = "{}</s>"
     text = input_pattern.format(user_input)
     return text
 
-
+def convert_all_history_to_text(history):
+    text = ""
+    for instance in history:
+        text += instance[0]
+        text += "</s>"
+        if instance[1]:
+            text += instance[1]
+            text += "</s>"
+    return text
 
 def log_conversation(conversation_id, history, messages, generate_kwargs):
     logging_url = os.getenv("LOGGING_URL", None)
@@ -78,7 +83,7 @@ def user(message, history):
     return "", history + [[message, ""]]
 
 
-def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
+def bot(history, temperature, top_p, top_k, repetition_penalty, max_new_tokens, conversation_id):
     print(f"history: {history}")
     # Initialize a StopOnTokens object
     stop = StopOnTokens()
@@ -136,6 +141,64 @@ def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id)
         history[-1][1] = partial_text
         yield history
 
+def multi_bot(history, temperature, top_p, top_k, repetition_penalty, max_new_tokens, conversation_id):
+    print(f"history: {history}")
+    # Initialize a StopOnTokens object
+    stop = StopOnTokens()
+
+    # Construct the input message string for the model by concatenating the current system message and conversation history
+    messages = convert_all_history_to_text(history)
+
+    # Tokenize the messages string
+    input_ids = tok(messages, return_tensors="pt").input_ids
+    input_ids = input_ids.to(m.device)
+    streamer = TextIteratorStreamer(
+        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        do_sample=temperature > 0.0,
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+        streamer=streamer,
+        stopping_criteria=StoppingCriteriaList([stop]),
+    )
+
+    stream_complete = Event()
+
+    def generate_and_signal_complete():
+        m.generate(**generate_kwargs)
+        stream_complete.set()
+
+    def log_after_stream_complete():
+        stream_complete.wait()
+        log_conversation(
+            conversation_id,
+            history,
+            messages,
+            {
+                "top_k": top_k,
+                "top_p": top_p,
+                "temperature": temperature,
+                "repetition_penalty": repetition_penalty,
+            },
+        )
+
+    t1 = Thread(target=generate_and_signal_complete)
+    t1.start()
+
+    t2 = Thread(target=log_after_stream_complete)
+    t2.start()
+
+    # Initialize an empty string to store the generated text
+    partial_text = ""
+    for new_text in streamer:
+        partial_text += new_text
+        history[-1][1] = partial_text
+        yield history
+
 
 def get_uuid():
     return str(uuid4())
@@ -162,7 +225,8 @@ with gr.Blocks(
             ).style(container=False)
         with gr.Column():
             with gr.Row():
-                submit = gr.Button("Submit")
+                single_submit = gr.Button("单轮对话")
+                multi_submit = gr.Button("多轮对话")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")
    with gr.Row():
@@ -172,18 +236,30 @@ with gr.Blocks(
                with gr.Row():
                    temperature = gr.Slider(
                        label="Temperature",
-                        value=0.1,
+                        value=0.3,
                        minimum=0.0,
                        maximum=1.0,
-                        step=0.1,
+                        step=0.05,
                        interactive=True,
                        info="Higher values produce more diverse outputs",
                    )
+            with gr.Column():
+                with gr.Row():
+                    repetition_penalty = gr.Slider(
+                        label="Repetition Penalty",
+                        value=1.2,
+                        minimum=1.0,
+                        maximum=2.0,
+                        step=0.05,
+                        interactive=True,
+                        info="Penalize repetition — 1.0 to disable.",
+                    )
+        with gr.Row():
            with gr.Column():
                with gr.Row():
                    top_p = gr.Slider(
                        label="Top-p (nucleus sampling)",
-                        value=1.0,
+                        value=0.85,
                        minimum=0.0,
                        maximum=1,
                        step=0.01,
@@ -204,17 +280,16 @@
                        interactive=True,
                        info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                    )
-            with gr.Column():
-                with gr.Row():
-                    repetition_penalty = gr.Slider(
-                        label="Repetition Penalty",
-                        value=1.2,
-                        minimum=1.0,
-                        maximum=2.0,
-                        step=0.1,
-                        interactive=True,
-                        info="Penalize repetition — 1.0 to disable.",
-                    )
+        with gr.Row():
+            max_new_tokens = gr.Slider(
+                label="Maximum new tokens",
+                value=1024,
+                minimum=0,
+                maximum=2048,
+                step=1,
+                interactive=True,
+            )
+
    # with gr.Row():
    #     gr.Markdown(
    #         "demo 2",
@@ -234,12 +309,13 @@
                top_p,
                top_k,
                repetition_penalty,
+                max_new_tokens,
                conversation_id,
            ],
            outputs=chatbot,
            queue=True,
        )
-    submit_click_event = submit.click(
+    submit_click_event = single_submit.click(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
@@ -252,6 +328,26 @@
            top_p,
            top_k,
            repetition_penalty,
+            max_new_tokens,
+            conversation_id,
+        ],
+        outputs=chatbot,
+        queue=True,
+    )
+    multi_click_event = multi_submit.click(
+        fn=user,
+        inputs=[msg, chatbot],
+        outputs=[msg, chatbot],
+        queue=False,
+    ).then(
+        fn=multi_bot,
+        inputs=[
+            chatbot,
+            temperature,
+            top_p,
+            top_k,
+            repetition_penalty,
+            max_new_tokens,
            conversation_id,
        ],
        outputs=chatbot,
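
Note on the change: the single-turn path (bot, via convert_history_to_text) still prompts with only the latest user message, while the new multi-turn path (multi_bot, via convert_all_history_to_text) concatenates the whole conversation, closing each user message and each completed bot reply with the model's </s> end-of-sequence token. A minimal sketch of the resulting prompt; the function body is copied from the diff, but the sample history values are illustrative, not from the commit:

    # history is Gradio chatbot state: a list of [user_message, bot_reply]
    # pairs, where the last reply is still "" before generation starts.
    history = [
        ["Hello", "Hi! How can I help you?"],
        ["Tell me about Beijing", ""],
    ]

    def convert_all_history_to_text(history):
        text = ""
        for instance in history:
            text += instance[0]
            text += "</s>"
            if instance[1]:  # skip the empty, not-yet-generated reply
                text += instance[1]
                text += "</s>"
        return text

    print(convert_all_history_to_text(history))
    # Hello</s>Hi! How can I help you?</s>Tell me about Beijing</s>

The if instance[1] guard matters because user() appends the pending turn as [message, ""] before generation begins, so the empty placeholder reply is never serialized into the prompt.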
 
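
Both bot and multi_bot stream tokens to the UI with the transformers TextIteratorStreamer pattern: model.generate() blocks, so it runs on a worker thread while the Gradio generator function consumes the streamer and yields a progressively longer reply. The sketch below restates that loop in isolation; stream_reply is a hypothetical helper, not code from the commit, and m and tok stand for the model and tokenizer loaded at the top of app.py:

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_reply(m, tok, prompt, **gen_kwargs):
        input_ids = tok(prompt, return_tensors="pt").input_ids.to(m.device)
        streamer = TextIteratorStreamer(
            tok, skip_prompt=True, skip_special_tokens=True)
        # Run the blocking generate() call on a background thread; the
        # streamer is an iterator the caller can consume as text arrives.
        Thread(
            target=m.generate,
            kwargs=dict(input_ids=input_ids, streamer=streamer, **gen_kwargs),
        ).start()
        partial = ""
        for new_text in streamer:  # decoded text chunks, in generation order
            partial += new_text
            yield partial  # each yield lets Gradio re-render the chatbot

app.py additionally sets a threading.Event after generate() returns so a second thread can log the finished conversation; that detail is omitted from the sketch.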