merve (HF staff) committed
Commit 1cabf49 · verified · 1 Parent(s): be70f48
Files changed (1):
  1. app.py (+61, -36)
app.py CHANGED
@@ -30,66 +30,91 @@ def sample_frames(video_file, num_frames):
     video.release()
     return frames
 
-@spaces.GPU
 def bot_streaming(message, history):
 
-    txt = message.text
-    ext_buffer = f"user\n{txt} assistant"
+    txt = message["text"]
+    ext_buffer = f"USER: {txt} ASSISTANT: "
 
-    if message.files:
-        if len(message.files) == 1:
+    if message["files"]:
+        if len(message["files"]) == 1:
             image = [message.files[0].path]
         # interleaved images or video
-        elif len(message.files) > 1:
-            image = [msg.path for msg in message.files]
+        elif len(message["files"]) > 1:
+            image = [msg["path"] for msg in message["files"]]
     else:
-        # if there's no image uploaded for this turn, look for images in the past turns
-        # kept inside tuples, take the last one
-        for hist in history:
-            if type(hist[0])==tuple:
-                image = hist[0][0]
+
+        def has_file_data(lst):
+            return any(isinstance(item, FileData) for sublist in lst if isinstance(sublist, tuple) for item in sublist)
+
+        def extract_paths(lst):
+            return [item["path"] for sublist in lst if isinstance(sublist, tuple) for item in sublist if isinstance(item, FileData)]
+
+        latest_text_only_index = -1
+
+        for i, item in enumerate(history):
+            if all(isinstance(sub_item, str) for sub_item in item):
+                latest_text_only_index = i
 
-    if message.files is None:
+        image = [path for i, item in enumerate(history) if i < latest_text_only_index and has_file_data(item) for path in extract_paths(item)]
+
+    if message["files"] is None:
         gr.Error("You need to upload an image or video for LLaVA to work.")
 
     video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg")
     image_extensions = Image.registered_extensions()
     image_extensions = tuple([ex for ex, f in image_extensions.items()])
+    image_list = []
+    video_list = []
+
+    print("media", image)
     if len(image) == 1:
         if image[0].endswith(video_extensions):
 
-            video = sample_frames(image[0], 32)
-            image = None
-            prompt = f"<|im_start|>user <video>\n{message.text}<|im_end|><|im_start|>assistant"
+            video_list = sample_frames(image[0], 12)
+
+            prompt = f"USER: <video> {message.text} ASSISTANT:"
         elif image[0].endswith(image_extensions):
-            image = Image.open(image[0]).convert("RGB")
-            video = None
-            prompt = f"<|im_start|>user <image>\n{message.text}<|im_end|><|im_start|>assistant"
+            image_list.append(Image.open(image[0]).convert("RGB"))
+            msg = message["text"]
+            prompt = f"USER: <image> {message.text} ASSISTANT:"
 
     elif len(image) > 1:
-        image_list = []
-        user_prompt = message.text
+        user_prompt = message["text"]
 
         for img in image:
             if img.endswith(image_extensions):
                 img = Image.open(img).convert("RGB")
                 image_list.append(img)
 
-            elif img.endswith(video_extensions):
-                frames = sample_frames(img, 6)
-                for frame in frames:
-                    image_list.append(frame)
-
-        toks = "<image>" * len(image_list)
-        prompt = "<|im_start|>user"+ toks + f"\n{user_prompt}<|im_end|><|im_start|>assistant"
+            elif img.endswith(video_extensions):
+                video_list.append(sample_frames(img, 7))
+                #for frame in sample_frames(img, 6):
+                    #video_list.append(frame)
+
+        image_tokens = ""
+        video_tokens = ""
 
-        image = image_list
-        video = None
+        if image_list != []:
+            image_tokens = "<image>" * len(image_list)
+        if video_list != []:
+            toks = len(video_list)
+            video_tokens = "<video>" * toks
 
+        prompt = f"USER: {image_tokens}{video_tokens} {user_prompt} ASSISTANT:"
 
-    inputs = processor(text=prompt, images=image, videos=video, return_tensors="pt").to("cuda", torch.float16)
-    streamer = TextIteratorStreamer(processor, **{"max_new_tokens": 200, "skip_special_tokens": True})
-    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=100)
+    if image_list != [] and video_list != []:
+        inputs = processor(text=prompt, images=image_list, videos=video_list, padding=True, return_tensors="pt").to("cuda", torch.float16)
+    elif image_list != [] and video_list == []:
+        inputs = processor(text=prompt, images=image_list, padding=True, return_tensors="pt").to("cuda", torch.float16)
+    elif image_list == [] and video_list != []:
+        inputs = processor(text=prompt, videos=video_list, padding=True, return_tensors="pt").to("cuda", torch.float16)
+
+    streamer = TextIteratorStreamer(processor, **{"max_new_tokens": 200, "skip_special_tokens": True, "clean_up_tokenization_spaces": True})
+    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=200)
     generated_text = ""
 
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
@@ -101,10 +126,10 @@ def bot_streaming(message, history):
     for new_text in streamer:
 
         buffer += new_text
-
-        generated_text_without_prompt = buffer[len(ext_buffer):]
+        print("new_text", new_text)
+        #generated_text_without_prompt = buffer[len(ext_buffer):][:-1]
         time.sleep(0.01)
-        yield generated_text_without_prompt
+        yield buffer #generated_text_without_prompt
 
 
 demo = gr.ChatInterface(fn=bot_streaming, title="LLaVA Onevision", examples=[
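
Note: the hunk headers reference sample_frames(video_file, num_frames), whose body sits above the diffed region; only its last two lines (video.release() and return frames) are visible as context. A minimal sketch of such a helper, assuming an OpenCV-based uniform sampler that returns RGB PIL images — the cv2 calls and the sampling stride are assumptions, not taken from this commit:

# Hypothetical reconstruction of the sample_frames helper referenced in the
# hunk header; only its final two lines appear as diff context above.
import cv2
from PIL import Image

def sample_frames(video_file, num_frames):
    video = cv2.VideoCapture(video_file)
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    interval = max(total_frames // num_frames, 1)  # uniform sampling stride
    frames = []
    for i in range(total_frames):
        ret, frame = video.read()
        if not ret:
            continue
        if i % interval == 0:
            # cv2 decodes frames as BGR; convert to RGB for PIL / the processor
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    video.release()
    return frames

The new code calls this helper with 12 frames for a lone video upload and 7 frames per clip when images and videos are interleaved.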
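
For reference, a worked instance of the interleaved prompt assembly added in this commit: with two images and one sampled clip collected, the media placeholder tokens are concatenated ahead of the user text (the variable values here are illustrative only).

# Illustrative values for the multi-file branch's prompt layout.
image_tokens = "<image>" * 2           # len(image_list) == 2
video_tokens = "<video>" * 1           # len(video_list) == 1 clip
user_prompt = "Compare these inputs."  # placeholder user text
prompt = f"USER: {image_tokens}{video_tokens} {user_prompt} ASSISTANT:"
# -> "USER: <image><image><video> Compare these inputs. ASSISTANT:"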