paralym commited on
Commit
e2029e4
·
verified ·
1 Parent(s): c10811c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -339
app.py CHANGED
@@ -1,22 +1,20 @@
1
- # from .demo_modelpart import InferenceDemo
2
  import gradio as gr
3
  import os
4
  from threading import Thread
5
-
6
- # import time
7
  import cv2
8
-
9
  import datetime
10
- # import copy
11
  import torch
12
-
13
  import spaces
14
  import numpy as np
 
 
 
 
15
 
16
  from llava import conversation as conversation_lib
17
  from llava.constants import DEFAULT_IMAGE_TOKEN
18
-
19
-
20
  from llava.constants import (
21
  IMAGE_TOKEN_INDEX,
22
  DEFAULT_IMAGE_TOKEN,
@@ -37,16 +35,7 @@ from serve_constants import html_header
37
  import requests
38
  from PIL import Image
39
  from io import BytesIO
40
- from transformers import TextStreamer, TextIteratorStreamer
41
-
42
- import hashlib
43
- import PIL
44
- import base64
45
-
46
- import gradio as gr
47
- import gradio_client
48
- import subprocess
49
- import sys
50
 
51
  external_log_dir = "./logs"
52
  LOGDIR = external_log_dir
@@ -61,13 +50,9 @@ def install_gradio_4_35_0():
61
  else:
62
  print("Gradio 4.35.0 is already installed.")
63
 
64
- # Call the function to install Gradio 4.35.0 if needed
65
  install_gradio_4_35_0()
66
 
67
- import gradio as gr
68
- import gradio_client
69
  print(f"Gradio version: {gr.__version__}")
70
- print(f"Gradio-client version: {gradio_client.__version__}")
71
 
72
  def get_conv_log_filename():
73
  t = datetime.datetime.now()
@@ -80,12 +65,12 @@ class InferenceDemo(object):
80
  ) -> None:
81
  disable_torch_init()
82
 
83
- self.tokenizer, self.model, self.image_processor, self.context_len = (
84
- tokenizer,
85
- model,
86
- image_processor,
87
- context_len,
88
- )
89
 
90
  if "llama-2" in model_name.lower():
91
  conv_mode = "llava_llama_2"
@@ -108,31 +93,43 @@ class InferenceDemo(object):
108
  )
109
  else:
110
  args.conv_mode = conv_mode
 
111
  self.conv_mode = conv_mode
112
  self.conversation = conv_templates[args.conv_mode].copy()
113
  self.num_frames = args.num_frames
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  def is_valid_video_filename(name):
117
  video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
118
-
119
  ext = name.split(".")[-1].lower()
120
-
121
- if ext in video_extensions:
122
- return True
123
- else:
124
- return False
125
 
126
  def is_valid_image_filename(name):
127
- image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
128
-
129
  ext = name.split(".")[-1].lower()
130
-
131
- if ext in image_extensions:
132
- return True
133
- else:
134
- return False
135
-
136
 
137
  def sample_frames(video_file, num_frames):
138
  video = cv2.VideoCapture(video_file)
@@ -141,54 +138,33 @@ def sample_frames(video_file, num_frames):
141
  frames = []
142
  for i in range(total_frames):
143
  ret, frame = video.read()
144
- pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
145
  if not ret:
146
  continue
 
147
  if i % interval == 0:
148
  frames.append(pil_img)
149
  video.release()
150
  return frames
151
 
152
-
153
  def load_image(image_file):
154
- if image_file.startswith("http") or image_file.startswith("https"):
155
  response = requests.get(image_file)
156
  if response.status_code == 200:
157
  image = Image.open(BytesIO(response.content)).convert("RGB")
158
  else:
159
- print("failed to load the image")
 
160
  else:
161
- print("Load image from local file")
162
- print(image_file)
163
  image = Image.open(image_file).convert("RGB")
164
-
165
  return image
166
 
167
-
168
  def clear_history(history):
169
-
170
  our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
171
-
172
  return None
173
 
174
-
175
- def clear_response(history):
176
- for index_conv in range(1, len(history)):
177
- # loop until get a text response from our model.
178
- conv = history[-index_conv]
179
- if not (conv[0] is None):
180
- break
181
- question = history[-index_conv][0]
182
- history = history[:-index_conv]
183
- return history, question
184
-
185
-
186
- # def print_like_dislike(x: gr.LikeData):
187
- # print(x.index, x.value, x.liked)
188
-
189
-
190
  def add_message(history, message):
191
- # history=[]
192
  global our_chatbot
193
  if len(history) == 0:
194
  our_chatbot = InferenceDemo(
@@ -201,36 +177,46 @@ def add_message(history, message):
201
  history.append((message["text"], None))
202
  return history, gr.MultimodalTextbox(value=None, interactive=False)
203
 
204
-
205
  @spaces.GPU
206
  def bot(history):
 
 
 
207
  text = history[-1][0]
208
  images_this_term = []
209
- text_this_term = ""
210
- # import pdb;pdb.set_trace()
211
  num_new_images = 0
 
212
  for i, message in enumerate(history[:-1]):
213
- if type(message[0]) is tuple:
214
  images_this_term.append(message[0][0])
215
  if is_valid_video_filename(message[0][0]):
216
- # 不接受视频
217
  raise ValueError("Video is not supported")
218
- num_new_images += our_chatbot.num_frames
219
  elif is_valid_image_filename(message[0][0]):
220
- print("#### Load image from local file",message[0][0])
221
  num_new_images += 1
222
  else:
223
  raise ValueError("Invalid image file")
224
  else:
225
  num_new_images = 0
226
 
227
- # for message in history[-i-1:]:
228
- # images_this_term.append(message[0][0])
229
-
230
- assert len(images_this_term) > 0, "must have an image"
231
- # image_files = (args.image_file).split(',')
232
- # image = [load_image(f) for f in images_this_term if f]
233
-
 
 
 
 
 
 
 
 
 
 
 
 
234
  all_image_hash = []
235
  for image_path in images_this_term:
236
  with open(image_path, "rb") as image_file:
@@ -248,35 +234,12 @@ def bot(history):
248
  if not os.path.isfile(filename):
249
  os.makedirs(os.path.dirname(filename), exist_ok=True)
250
  image.save(filename)
251
-
252
- image_list = []
253
- for f in images_this_term:
254
- if is_valid_video_filename(f):
255
- image_list += sample_frames(f, our_chatbot.num_frames)
256
- elif is_valid_image_filename(f):
257
- image_list.append(load_image(f))
258
- else:
259
- raise ValueError("Invalid image file")
260
-
261
- image_tensor = [
262
- our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
263
- 0
264
- ]
265
- .half()
266
- .to(our_chatbot.model.device)
267
- for f in image_list
268
- ]
269
-
270
 
271
  image_tensor = torch.stack(image_tensor)
272
  image_token = DEFAULT_IMAGE_TOKEN * num_new_images
273
- # if our_chatbot.model.config.mm_use_im_start_end:
274
- # inp = DEFAULT_IM_START_TOKEN + image_token + DEFAULT_IM_END_TOKEN + "\n" + inp
275
- # else:
276
- inp = text
277
- inp = image_token + "\n" + inp
278
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
279
- # image = None
280
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
281
  prompt = our_chatbot.conversation.get_prompt()
282
 
@@ -287,10 +250,7 @@ def bot(history):
287
  .unsqueeze(0)
288
  .to(our_chatbot.model.device)
289
  )
290
- # input_ids = tokenizer_image_token(
291
- # prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
292
- # ).unsqueeze(0).to(our_chatbot.model.device)
293
- # print("### input_id",input_ids)
294
  stop_str = (
295
  our_chatbot.conversation.sep
296
  if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
@@ -300,16 +260,23 @@ def bot(history):
300
  stopping_criteria = KeywordsStoppingCriteria(
301
  keywords, our_chatbot.tokenizer, input_ids
302
  )
303
- streamer = TextStreamer(
304
- our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
 
 
 
 
 
 
 
 
 
 
 
305
  )
306
- # streamer = TextIteratorStreamer(
307
- # our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
308
- # )
309
- print(our_chatbot.model.device)
310
- print(input_ids.device)
311
- print(image_tensor.device)
312
- # import pdb;pdb.set_trace()
313
  with torch.inference_mode():
314
  output_ids = our_chatbot.model.generate(
315
  input_ids,
@@ -318,86 +285,29 @@ def bot(history):
318
  temperature=0.2,
319
  max_new_tokens=1024,
320
  streamer=streamer,
321
- use_cache=False,
322
  stopping_criteria=[stopping_criteria],
323
  )
324
 
325
- outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
326
- if outputs.endswith(stop_str):
327
- outputs = outputs[: -len(stop_str)]
328
- our_chatbot.conversation.messages[-1][-1] = outputs
329
-
330
- history[-1] = [text, outputs]
331
- print("#### history",history)
332
 
 
333
  with open(get_conv_log_filename(), "a") as fout:
334
  data = {
 
335
  "type": "chat",
336
  "model": "Pangea-7b",
 
 
337
  "state": history,
338
  "images": all_image_hash,
339
  }
340
  fout.write(json.dumps(data) + "\n")
341
- return history
342
- # generate_kwargs = dict(
343
- # inputs=input_ids,
344
- # streamer=streamer,
345
- # images=image_tensor,
346
- # max_new_tokens=1024,
347
- # do_sample=True,
348
- # temperature=0.2,
349
- # num_beams=1,
350
- # use_cache=False,
351
- # stopping_criteria=[stopping_criteria],
352
- # )
353
-
354
- # t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
355
- # t.start()
356
-
357
- # outputs = []
358
- # for text in streamer:
359
- # outputs.append(text)
360
- # yield "".join(outputs)
361
-
362
- # our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
363
- # history[-1] = [text, "".join(outputs)]
364
-
365
-
366
- txt = gr.Textbox(
367
- scale=4,
368
- show_label=False,
369
- placeholder="Enter text and press enter.",
370
- container=False,
371
- )
372
 
373
- with gr.Blocks(
374
- css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img {min-width: 40px}",
375
- ) as demo:
376
 
377
- # Informations
378
- title_markdown = """
379
- # LLaVA-NeXT Interleave
380
- [[Blog]](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/) [[Code]](https://github.com/LLaVA-VL/LLaVA-NeXT) [[Model]](https://huggingface.co/lmms-lab/llava-next-interleave-7b)
381
- Note: The internleave checkpoint is updated (Date: Jul. 24, 2024), the wrong checkpiont is used before.
382
- """
383
-
384
-
385
- tos_markdown = """
386
- ### TODO!. Terms of use
387
- By using this service, users are required to agree to the following terms:
388
- The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
389
- Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
390
- For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
391
- """
392
- learn_more_markdown = """
393
- ### TODO!. License
394
- The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
395
- """
396
- models = [
397
- "LLaVA-Interleave-7B",
398
- ]
399
- cur_dir = os.path.dirname(os.path.abspath(__file__))
400
- # gr.Markdown(title_markdown)
401
  gr.HTML(html_header)
402
 
403
  with gr.Column():
@@ -408,10 +318,8 @@ with gr.Blocks(
408
  upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
409
  downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
410
  flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
411
- # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=True)
412
  regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
413
  clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
414
-
415
 
416
  chat_input = gr.MultimodalTextbox(
417
  interactive=True,
@@ -421,11 +329,11 @@ with gr.Blocks(
421
  submit_btn="🚀"
422
  )
423
 
424
- print(cur_dir)
425
  gr.Examples(
426
- examples_per_page=20,
427
- examples=[
428
- [
429
  {
430
  "files": [
431
  f"{cur_dir}/examples/user_example_07.jpg",
@@ -449,160 +357,47 @@ with gr.Blocks(
449
  "text": "Why this image funny?",
450
  },
451
  ],
452
- [
453
- {
454
- "files": [
455
- f"{cur_dir}/examples/norway.jpg",
456
- ],
457
- "text": "Analysieren, in welchem Land diese Szene höchstwahrscheinlich gedreht wurde.",
458
- },
459
- ],
460
- [
461
- {
462
- "files": [
463
- f"{cur_dir}/examples/totoro.jpg",
464
- ],
465
- "text": "¿En qué anime aparece esta escena? ¿Puedes presentarlo?",
466
- },
467
- ],
468
- [
469
- {
470
- "files": [
471
- f"{cur_dir}/examples/africa.jpg",
472
- ],
473
- "text": "इस तस्वीर में हर एक दृश्य तत्व का क्या प्रतिनिधित्व करता है?",
474
- },
475
- ],
476
- [
477
- {
478
- "files": [
479
- f"{cur_dir}/examples/hot_ballon.jpg",
480
- ],
481
- "text": "ฉากบอลลูนลมร้อนในภาพนี้อาจอยู่ที่ไหน? สถานที่นี้มีความพิเศษอย่างไร?",
482
- },
483
- ],
484
- [
485
- {
486
- "files": [
487
- f"{cur_dir}/examples/bar.jpg",
488
- ],
489
- "text": "Você pode me dar ideias de design baseadas no tema de coquetéis deste letreiro?",
490
- },
491
- ],
492
- [
493
- {
494
- "files": [
495
- f"{cur_dir}/examples/pink_lake.jpg",
496
- ],
497
- "text": "Обясни защо езерото на този остров е в този цвят.",
498
- },
499
- ],
500
- [
501
- {
502
- "files": [
503
- f"{cur_dir}/examples/hanzi.jpg",
504
- ],
505
- "text": "Can you describe in Hebrew the evolution process of these four Chinese characters from pictographs to modern characters?",
506
- },
507
- ],
508
- [
509
- {
510
- "files": [
511
- f"{cur_dir}/examples/ballon.jpg",
512
- ],
513
- "text": "இந்த காட்சியை விவரிக்கவும், மேலும் இந்த படத்தின் அடிப்படையில் துருக்கியில் இந்த காட்சியுடன் தொடர்பான சில பிரபலமான நிகழ்வுகள் என்ன?",
514
- },
515
- ],
516
- [
517
- {
518
- "files": [
519
- f"{cur_dir}/examples/pie.jpg",
520
- ],
521
- "text": "Décrivez ce graphique. Quelles informations pouvons-nous en tirer?",
522
- },
523
- ],
524
- [
525
- {
526
- "files": [
527
- f"{cur_dir}/examples/camera.jpg",
528
- ],
529
- "text": "Apa arti dari dua angka di sebelah kiri yang ditampilkan di layar kamera?",
530
- },
531
- ],
532
- [
533
- {
534
- "files": [
535
- f"{cur_dir}/examples/dog.jpg",
536
- ],
537
- "text": "이 강아지의 표정을 보고 어떤 기분이나 감정을 느끼고 있는지 설명해 주시겠어요?",
538
- },
539
- ],
540
- [
541
- {
542
- "files": [
543
- f"{cur_dir}/examples/book.jpg",
544
- ],
545
- "text": "What language is the text in, and what does the title mean in English?",
546
- },
547
- ],
548
- [
549
- {
550
- "files": [
551
- f"{cur_dir}/examples/food.jpg",
552
- ],
553
- "text": "Unaweza kunipa kichocheo cha kutengeneza hii pancake?",
554
- },
555
- ],
556
- [
557
- {
558
- "files": [
559
- f"{cur_dir}/examples/line chart.jpg",
560
- ],
561
- "text": "Hãy trình bày những xu hướng mà bạn quan sát được từ biểu đồ và hiện tượng xã hội tiềm ẩn từ đó.",
562
- },
563
- ],
564
- [
565
- {
566
- "files": [
567
- f"{cur_dir}/examples/south africa.jpg",
568
- ],
569
- "text": "Waar is hierdie plek? Help my om ’n reisroete vir hierdie land te beplan.",
570
- },
571
- ],
572
- [
573
- {
574
- "files": [
575
- f"{cur_dir}/examples/girl.jpg",
576
- ],
577
- "text": "لماذا هذه الصورة مضحكة؟",
578
- },
579
- ],
580
- [
581
- {
582
- "files": [
583
- f"{cur_dir}/examples/eagles.jpg",
584
- ],
585
- "text": "Какой креатив должен быть в этом логотипе?",
586
- },
587
- ],
588
- ],
589
- inputs=[chat_input],
590
- label="Image",
591
- )
592
 
593
  chat_msg = chat_input.submit(
594
- add_message, [chatbot, chat_input], [chatbot, chat_input]
 
 
 
 
 
 
 
 
 
 
 
 
595
  )
596
- bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
597
- bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
598
 
599
- # chatbot.like(print_like_dislike, None, None)
600
  clear_btn.click(
601
- fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all"
 
 
 
 
602
  )
603
 
 
 
 
 
 
 
 
 
 
 
604
 
605
- demo.queue()
606
 
607
  if __name__ == "__main__":
608
  import argparse
 
 
1
  import gradio as gr
2
  import os
3
  from threading import Thread
4
+ from queue import Queue
5
+ import time
6
  import cv2
 
7
  import datetime
 
8
  import torch
 
9
  import spaces
10
  import numpy as np
11
+ import json
12
+ import hashlib
13
+ import PIL
14
+ from typing import Iterator
15
 
16
  from llava import conversation as conversation_lib
17
  from llava.constants import DEFAULT_IMAGE_TOKEN
 
 
18
  from llava.constants import (
19
  IMAGE_TOKEN_INDEX,
20
  DEFAULT_IMAGE_TOKEN,
 
35
  import requests
36
  from PIL import Image
37
  from io import BytesIO
38
+ from transformers import TextIteratorStreamer
 
 
 
 
 
 
 
 
 
39
 
40
  external_log_dir = "./logs"
41
  LOGDIR = external_log_dir
 
50
  else:
51
  print("Gradio 4.35.0 is already installed.")
52
 
 
53
  install_gradio_4_35_0()
54
 
 
 
55
  print(f"Gradio version: {gr.__version__}")
 
56
 
57
  def get_conv_log_filename():
58
  t = datetime.datetime.now()
 
65
  ) -> None:
66
  disable_torch_init()
67
 
68
+ self.tokenizer = tokenizer
69
+ self.model = model
70
+ self.image_processor = image_processor
71
+ self.context_len = context_len
72
+
73
+ model_name = get_model_name_from_path(model_path)
74
 
75
  if "llama-2" in model_name.lower():
76
  conv_mode = "llava_llama_2"
 
93
  )
94
  else:
95
  args.conv_mode = conv_mode
96
+
97
  self.conv_mode = conv_mode
98
  self.conversation = conv_templates[args.conv_mode].copy()
99
  self.num_frames = args.num_frames
100
 
101
+ def process_stream(streamer: TextIteratorStreamer, history: list, q: Queue):
102
+ """Process the output stream and put partial text into a queue"""
103
+ try:
104
+ current_message = ""
105
+ for new_text in streamer:
106
+ current_message += new_text
107
+ history[-1][1] = current_message
108
+ q.put(history.copy())
109
+ time.sleep(0.02) # Add a small delay to prevent overloading
110
+ except Exception as e:
111
+ print(f"Error in process_stream: {e}")
112
+ finally:
113
+ q.put(None) # Signal that we're done
114
+
115
+ def stream_output(history: list, q: Queue) -> Iterator[list]:
116
+ """Yield updated history as it comes through the queue"""
117
+ while True:
118
+ val = q.get()
119
+ if val is None:
120
+ break
121
+ yield val
122
+ q.task_done()
123
 
124
  def is_valid_video_filename(name):
125
  video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
 
126
  ext = name.split(".")[-1].lower()
127
+ return ext in video_extensions
 
 
 
 
128
 
129
  def is_valid_image_filename(name):
130
+ image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
 
131
  ext = name.split(".")[-1].lower()
132
+ return ext in image_extensions
 
 
 
 
 
133
 
134
  def sample_frames(video_file, num_frames):
135
  video = cv2.VideoCapture(video_file)
 
138
  frames = []
139
  for i in range(total_frames):
140
  ret, frame = video.read()
 
141
  if not ret:
142
  continue
143
+ pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
144
  if i % interval == 0:
145
  frames.append(pil_img)
146
  video.release()
147
  return frames
148
 
 
149
  def load_image(image_file):
150
+ if image_file.startswith(("http://", "https://")):
151
  response = requests.get(image_file)
152
  if response.status_code == 200:
153
  image = Image.open(BytesIO(response.content)).convert("RGB")
154
  else:
155
+ print("Failed to load the image")
156
+ return None
157
  else:
158
+ print("Load image from local file:", image_file)
 
159
  image = Image.open(image_file).convert("RGB")
 
160
  return image
161
 
 
162
  def clear_history(history):
163
+ global our_chatbot
164
  our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
 
165
  return None
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  def add_message(history, message):
 
168
  global our_chatbot
169
  if len(history) == 0:
170
  our_chatbot = InferenceDemo(
 
177
  history.append((message["text"], None))
178
  return history, gr.MultimodalTextbox(value=None, interactive=False)
179
 
 
180
  @spaces.GPU
181
  def bot(history):
182
+ global start_tstamp, finish_tstamp
183
+
184
+ start_tstamp = time.time()
185
  text = history[-1][0]
186
  images_this_term = []
 
 
187
  num_new_images = 0
188
+
189
  for i, message in enumerate(history[:-1]):
190
+ if isinstance(message[0], tuple):
191
  images_this_term.append(message[0][0])
192
  if is_valid_video_filename(message[0][0]):
 
193
  raise ValueError("Video is not supported")
 
194
  elif is_valid_image_filename(message[0][0]):
 
195
  num_new_images += 1
196
  else:
197
  raise ValueError("Invalid image file")
198
  else:
199
  num_new_images = 0
200
 
201
+ assert len(images_this_term) > 0, "Must have an image"
202
+
203
+ image_list = []
204
+ for f in images_this_term:
205
+ if is_valid_video_filename(f):
206
+ image_list += sample_frames(f, our_chatbot.num_frames)
207
+ elif is_valid_image_filename(f):
208
+ image_list.append(load_image(f))
209
+ else:
210
+ raise ValueError("Invalid image file")
211
+
212
+ image_tensor = [
213
+ our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][0]
214
+ .half()
215
+ .to(our_chatbot.model.device)
216
+ for f in image_list
217
+ ]
218
+
219
+ # Process image hashes
220
  all_image_hash = []
221
  for image_path in images_this_term:
222
  with open(image_path, "rb") as image_file:
 
234
  if not os.path.isfile(filename):
235
  os.makedirs(os.path.dirname(filename), exist_ok=True)
236
  image.save(filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
  image_tensor = torch.stack(image_tensor)
239
  image_token = DEFAULT_IMAGE_TOKEN * num_new_images
240
+ inp = image_token + "\n" + text
241
+
 
 
 
242
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
 
243
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
244
  prompt = our_chatbot.conversation.get_prompt()
245
 
 
250
  .unsqueeze(0)
251
  .to(our_chatbot.model.device)
252
  )
253
+
 
 
 
254
  stop_str = (
255
  our_chatbot.conversation.sep
256
  if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
 
260
  stopping_criteria = KeywordsStoppingCriteria(
261
  keywords, our_chatbot.tokenizer, input_ids
262
  )
263
+
264
+ # Set up streaming
265
+ q = Queue()
266
+ streamer = TextIteratorStreamer(
267
+ our_chatbot.tokenizer,
268
+ skip_prompt=True,
269
+ skip_special_tokens=True
270
+ )
271
+
272
+ # Start generation in a separate thread
273
+ thread = Thread(
274
+ target=process_stream,
275
+ args=(streamer, history, q)
276
  )
277
+ thread.start()
278
+
279
+ # Start the generation
 
 
 
 
280
  with torch.inference_mode():
281
  output_ids = our_chatbot.model.generate(
282
  input_ids,
 
285
  temperature=0.2,
286
  max_new_tokens=1024,
287
  streamer=streamer,
288
+ use_cache=True,
289
  stopping_criteria=[stopping_criteria],
290
  )
291
 
292
+ finish_tstamp = time.time()
 
 
 
 
 
 
293
 
294
+ # Log conversation
295
  with open(get_conv_log_filename(), "a") as fout:
296
  data = {
297
+ "tstamp": round(finish_tstamp, 4),
298
  "type": "chat",
299
  "model": "Pangea-7b",
300
+ "start": round(start_tstamp, 4),
301
+ "finish": round(finish_tstamp, 4),
302
  "state": history,
303
  "images": all_image_hash,
304
  }
305
  fout.write(json.dumps(data) + "\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
+ # Return a generator that will yield updated history
308
+ return stream_output(history, q)
 
309
 
310
+ with gr.Blocks(css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img {min-width: 40px}") as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  gr.HTML(html_header)
312
 
313
  with gr.Column():
 
318
  upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
319
  downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
320
  flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
 
321
  regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
322
  clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
 
323
 
324
  chat_input = gr.MultimodalTextbox(
325
  interactive=True,
 
329
  submit_btn="🚀"
330
  )
331
 
332
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
333
  gr.Examples(
334
+ examples_per_page=20,
335
+ examples=[
336
+ [
337
  {
338
  "files": [
339
  f"{cur_dir}/examples/user_example_07.jpg",
 
357
  "text": "Why this image funny?",
358
  },
359
  ],
360
+ ],
361
+ inputs=[chat_input],
362
+ label="Image",
363
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
  chat_msg = chat_input.submit(
366
+ add_message,
367
+ [chatbot, chat_input],
368
+ [chatbot, chat_input],
369
+ queue=False
370
+ ).then(
371
+ bot,
372
+ chatbot,
373
+ chatbot,
374
+ api_name="bot_response"
375
+ ).then(
376
+ lambda: gr.MultimodalTextbox(interactive=True),
377
+ None,
378
+ [chat_input]
379
  )
 
 
380
 
 
381
  clear_btn.click(
382
+ fn=clear_history,
383
+ inputs=[chatbot],
384
+ outputs=[chatbot],
385
+ api_name="clear_all",
386
+ queue=False
387
  )
388
 
389
+ regenerate_btn.click(
390
+ fn=lambda history: history[:-1],
391
+ inputs=[chatbot],
392
+ outputs=[chatbot],
393
+ queue=False
394
+ ).then(
395
+ bot,
396
+ chatbot,
397
+ chatbot
398
+ )
399
 
400
+ demo.queue(concurrency_count=5)
401
 
402
  if __name__ == "__main__":
403
  import argparse