cwkuo committed
Commit ef2dc13 · 1 Parent(s): d8c6a57

disable beam search as it may cause OoM
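Beam search keeps num_beams candidate sequences alive per input, each with its own decoder KV cache, so peak GPU memory grows roughly linearly with the beam count; single-beam streaming avoids that multiplier. A back-of-the-envelope sketch of the effect (all model dimensions below are assumed, generic 7B-class values, not taken from this repo's config):

# Rough KV-cache sizing: why multi-beam decoding can OoM where
# single-beam decoding fits. Dimensions are assumed illustrative
# values, not this app's actual model config.
def kv_cache_bytes(beams, layers=32, heads=32, head_dim=128,
                   seq_len=1024, bytes_per_elem=2):
    # 2x for keys and values; one cache per live hypothesis.
    return 2 * beams * layers * heads * head_dim * seq_len * bytes_per_elem

print(f"num_beams=1: {kv_cache_bytes(1) / 2**30:.1f} GiB")  # 0.5 GiB
print(f"num_beams=5: {kv_cache_bytes(5) / 2**30:.1f} GiB")  # 2.5 GiB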

Files changed (1): app.py (+14 -27)
app.py CHANGED
@@ -159,7 +159,7 @@ def retrieve_knowledge(image):
 
 
 @torch.inference_mode()
-def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl, do_sampling, do_beam_search):
+def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl, do_sampling):
     if state.skip_next: # This generate call is skipped due to invalid inputs
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 3 + knwl_unchange
         return
@@ -210,36 +210,24 @@ def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl,
     prompt = prompt.split("USER:")[-1].replace("ASSISTANT:", "")
     image_pt = gptk_trans(image).to(device).unsqueeze(0)
     samples = {"image": image_pt, "knowledge": knwl_embd, "prompt": prompt}
-    if bool(do_beam_search):
-        new_text = gptk_model.generate(
+    streamer = TextIteratorStreamer(
+        gptk_model.llm_tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
+    )
+    thread = Thread(
+        target=gptk_model.generate,
+        kwargs=dict(
             samples=samples,
             use_nucleus_sampling=bool(do_sampling),
             max_length=min(int(max_new_tokens), 1024),
             top_p=float(top_p),
             temperature=float(temperature),
+            streamer=streamer,
+            num_beams=1,
             length_penalty=0.0,
             auto_cast=True
-        )[0]
-        streamer = [new_text, ]
-    else:
-        streamer = TextIteratorStreamer(
-            gptk_model.llm_tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
         )
-        thread = Thread(
-            target=gptk_model.generate,
-            kwargs=dict(
-                samples=samples,
-                use_nucleus_sampling=bool(do_sampling),
-                max_length=min(int(max_new_tokens), 1024),
-                top_p=float(top_p),
-                temperature=float(temperature),
-                streamer=streamer,
-                num_beams=1,
-                length_penalty=0.0,
-                auto_cast=True
-            )
-        )
-        thread.start()
+    )
+    thread.start()
 
     generated_text = ""
     for new_text in streamer:
@@ -301,7 +289,6 @@ def build_demo():
     with gr.Row():
         add_knwl = gr.Checkbox(value=True, interactive=True, label="Knowledge")
         do_sampling = gr.Checkbox(value=False, interactive=True, label="Sampling")
-        do_beam_search = gr.Checkbox(value=False, interactive=True, label="Beam search")
         temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
         top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
         max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
@@ -331,7 +318,7 @@ def build_demo():
         regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
     ).then(
         generate,
-        [state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
+        [state, temperature, top_p, max_output_tokens, add_knwl, do_sampling],
         [state, chatbot] + btn_list + knwl_vis
     )
 
@@ -343,7 +330,7 @@ def build_demo():
         add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
     ).then(
         generate,
-        [state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
+        [state, temperature, top_p, max_output_tokens, add_knwl, do_sampling],
         [state, chatbot] + btn_list + knwl_vis
     )
 
@@ -351,7 +338,7 @@ def build_demo():
         add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
     ).then(
         generate,
-        [state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
+        [state, temperature, top_p, max_output_tokens, add_knwl, do_sampling],
         [state, chatbot] + btn_list + knwl_vis
     )
 
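For reference, the replacement follows the standard transformers streaming recipe: generate runs in a worker thread while decoded text is consumed incrementally from a TextIteratorStreamer. transformers' streamers do not support beam search, which is why the streaming path pins num_beams=1. A minimal self-contained sketch of the same pattern (gpt2 stands in for this app's gptk_model and its tokenizer; the prompt and sampling values are illustrative):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# gpt2 is a stand-in for the app's underlying LLM; any causal LM works here.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The image shows", return_tensors="pt")
streamer = TextIteratorStreamer(
    tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
)
# Run generation in a worker thread so the caller can consume tokens as
# they arrive instead of blocking until the full sequence is finished.
thread = Thread(
    target=model.generate,
    kwargs=dict(
        **inputs,
        max_new_tokens=64,  # illustrative cap
        do_sample=True,
        top_p=0.7,
        temperature=1.0,
        num_beams=1,        # single hypothesis: streamable and memory-light
        streamer=streamer,
    ),
)
thread.start()

generated_text = ""
for new_text in streamer:   # yields decoded chunks as they are produced
    generated_text += new_text
    print(new_text, end="", flush=True)
thread.join()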