Spaces:
Runtime error
Runtime error
cwkuo
committed on
Commit
·
ef2dc13
1
Parent(s):
d8c6a57
Disable beam search as it may cause out-of-memory (OOM) errors
Browse files
app.py
CHANGED
@@ -159,7 +159,7 @@ def retrieve_knowledge(image):
|
|
159 |
|
160 |
|
161 |
@torch.inference_mode()
|
162 |
-
def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl, do_sampling, do_beam_search):
|
163 |
if state.skip_next: # This generate call is skipped due to invalid inputs
|
164 |
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 3 + knwl_unchange
|
165 |
return
|
@@ -210,36 +210,24 @@ def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl,
|
|
210 |
prompt = prompt.split("USER:")[-1].replace("ASSISTANT:", "")
|
211 |
image_pt = gptk_trans(image).to(device).unsqueeze(0)
|
212 |
samples = {"image": image_pt, "knowledge": knwl_embd, "prompt": prompt}
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
215 |
samples=samples,
|
216 |
use_nucleus_sampling=bool(do_sampling),
|
217 |
max_length=min(int(max_new_tokens), 1024),
|
218 |
top_p=float(top_p),
|
219 |
temperature=float(temperature),
|
|
|
|
|
220 |
length_penalty=0.0,
|
221 |
auto_cast=True
|
222 |
-
)[0]
|
223 |
-
streamer = [new_text, ]
|
224 |
-
else:
|
225 |
-
streamer = TextIteratorStreamer(
|
226 |
-
gptk_model.llm_tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
|
227 |
)
|
228 |
-
|
229 |
-
|
230 |
-
kwargs=dict(
|
231 |
-
samples=samples,
|
232 |
-
use_nucleus_sampling=bool(do_sampling),
|
233 |
-
max_length=min(int(max_new_tokens), 1024),
|
234 |
-
top_p=float(top_p),
|
235 |
-
temperature=float(temperature),
|
236 |
-
streamer=streamer,
|
237 |
-
num_beams=1,
|
238 |
-
length_penalty=0.0,
|
239 |
-
auto_cast=True
|
240 |
-
)
|
241 |
-
)
|
242 |
-
thread.start()
|
243 |
|
244 |
generated_text = ""
|
245 |
for new_text in streamer:
|
@@ -301,7 +289,6 @@ def build_demo():
|
|
301 |
with gr.Row():
|
302 |
add_knwl = gr.Checkbox(value=True, interactive=True, label="Knowledge")
|
303 |
do_sampling = gr.Checkbox(value=False, interactive=True, label="Sampling")
|
304 |
-
do_beam_search = gr.Checkbox(value=False, interactive=True, label="Beam search")
|
305 |
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
|
306 |
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
|
307 |
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
@@ -331,7 +318,7 @@ def build_demo():
|
|
331 |
regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
|
332 |
).then(
|
333 |
generate,
|
334 |
-
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
|
335 |
[state, chatbot] + btn_list + knwl_vis
|
336 |
)
|
337 |
|
@@ -343,7 +330,7 @@ def build_demo():
|
|
343 |
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
344 |
).then(
|
345 |
generate,
|
346 |
-
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
|
347 |
[state, chatbot] + btn_list + knwl_vis
|
348 |
)
|
349 |
|
@@ -351,7 +338,7 @@ def build_demo():
|
|
351 |
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
352 |
).then(
|
353 |
generate,
|
354 |
-
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling, do_beam_search],
|
355 |
[state, chatbot] + btn_list + knwl_vis
|
356 |
)
|
357 |
|
|
|
159 |
|
160 |
|
161 |
@torch.inference_mode()
|
162 |
+
def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl, do_sampling):
|
163 |
if state.skip_next: # This generate call is skipped due to invalid inputs
|
164 |
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 3 + knwl_unchange
|
165 |
return
|
|
|
210 |
prompt = prompt.split("USER:")[-1].replace("ASSISTANT:", "")
|
211 |
image_pt = gptk_trans(image).to(device).unsqueeze(0)
|
212 |
samples = {"image": image_pt, "knowledge": knwl_embd, "prompt": prompt}
|
213 |
+
streamer = TextIteratorStreamer(
|
214 |
+
gptk_model.llm_tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
|
215 |
+
)
|
216 |
+
thread = Thread(
|
217 |
+
target=gptk_model.generate,
|
218 |
+
kwargs=dict(
|
219 |
samples=samples,
|
220 |
use_nucleus_sampling=bool(do_sampling),
|
221 |
max_length=min(int(max_new_tokens), 1024),
|
222 |
top_p=float(top_p),
|
223 |
temperature=float(temperature),
|
224 |
+
streamer=streamer,
|
225 |
+
num_beams=1,
|
226 |
length_penalty=0.0,
|
227 |
auto_cast=True
|
|
|
|
|
|
|
|
|
|
|
228 |
)
|
229 |
+
)
|
230 |
+
thread.start()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
|
232 |
generated_text = ""
|
233 |
for new_text in streamer:
|
|
|
289 |
with gr.Row():
|
290 |
add_knwl = gr.Checkbox(value=True, interactive=True, label="Knowledge")
|
291 |
do_sampling = gr.Checkbox(value=False, interactive=True, label="Sampling")
|
|
|
292 |
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature",)
|
293 |
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
|
294 |
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
|
|
318 |
regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
|
319 |
).then(
|
320 |
generate,
|
321 |
+
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling],
|
322 |
[state, chatbot] + btn_list + knwl_vis
|
323 |
)
|
324 |
|
|
|
330 |
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
331 |
).then(
|
332 |
generate,
|
333 |
+
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling],
|
334 |
[state, chatbot] + btn_list + knwl_vis
|
335 |
)
|
336 |
|
|
|
338 |
add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
|
339 |
).then(
|
340 |
generate,
|
341 |
+
[state, temperature, top_p, max_output_tokens, add_knwl, do_sampling],
|
342 |
[state, chatbot] + btn_list + knwl_vis
|
343 |
)
|
344 |
|