ynhe committed on
Commit
e3db206
1 Parent(s): 7c64e6d

Create app.py

Files changed (1)
  1. app.py +384 -0
app.py ADDED
@@ -0,0 +1,384 @@
+ import os
+ try:
+     token = os.environ['HF_TOKEN']
+ except KeyError:
+     print("paste your hf token here!")
+     token = "hf_xxxxxxxxxxxxxxxxxxx"
+     os.environ['HF_TOKEN'] = token
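+ # Security note: avoid hard-coding a real token. On Hugging Face Spaces,
+ # HF_TOKEN can be set as a secret (Settings -> Variables and secrets) and is
+ # then picked up from os.environ at runtime.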
+ import torch
+ import gradio as gr
+ from gradio.themes.utils import colors, fonts, sizes
+
+ from transformers import AutoTokenizer, AutoModel
+
+ # ========================================
+ # Model Initialization
+ # ========================================
+
+ tokenizer = AutoTokenizer.from_pretrained('OpenGVLab/InternVideo2_chat_8B_HD',
+                                           trust_remote_code=True,
+                                           use_fast=False,
+                                           token=token)
+ # Load on GPU in bfloat16 when CUDA is available; otherwise fall back to CPU.
+ if torch.cuda.is_available():
+     model = AutoModel.from_pretrained(
+         'OpenGVLab/InternVideo2_chat_8B_HD',
+         torch_dtype=torch.bfloat16,
+         trust_remote_code=True).cuda()
+ else:
+     model = AutoModel.from_pretrained(
+         'OpenGVLab/InternVideo2_chat_8B_HD',
+         torch_dtype=torch.bfloat16,
+         trust_remote_code=True)
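+ # Note: bfloat16 on CPU is slow and not supported by every backend; if the CPU
+ # path is actually exercised, torch.float32 is likely the safer dtype choice.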
+
+ from decord import VideoReader, cpu
+ from PIL import Image
+ import numpy as np
+ import decord
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+ from torchvision.transforms import PILToTensor
+ from torchvision import transforms
+ from torchvision.transforms.functional import InterpolationMode
+ # Make decord return torch tensors directly from get_batch().
+ decord.bridge.set_bridge("torch")
+
+ # ========================================
+ # Define Utils
+ # ========================================
+ def get_index(num_frames, num_segments):
+     # Uniformly sample num_segments frame indices, one near the centre of each segment.
+     seg_size = float(num_frames - 1) / num_segments
+     start = int(seg_size / 2)
+     offsets = np.array([
+         start + int(np.round(seg_size * idx)) for idx in range(num_segments)
+     ])
+     return offsets
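+ # Sanity check of the sampling (doctest-style, not executed by the app):
+ #   get_index(100, 8) -> array([ 6, 18, 31, 43, 56, 68, 80, 93])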
+
+
+ def load_video(video_path, num_segments=8, return_msg=False, resolution=224, hd_num=4, padding=False):
+     decord.bridge.set_bridge("torch")
+     vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+     num_frames = len(vr)
+     frame_indices = get_index(num_frames, num_segments)
+
+     # ImageNet normalization statistics
+     mean = (0.485, 0.456, 0.406)
+     std = (0.229, 0.224, 0.225)
+
+     transform = transforms.Compose([
+         transforms.Lambda(lambda x: x.float().div(255.0)),
+         transforms.Normalize(mean, std)
+     ])
+
+     frames = vr.get_batch(frame_indices)  # torch tensor, thanks to the decord bridge
+     frames = frames.permute(0, 3, 1, 2)   # THWC -> TCHW
+
+     if padding:
+         frames = HD_transform_padding(frames.float(), image_size=resolution, hd_num=hd_num)
+     else:
+         frames = HD_transform_no_padding(frames.float(), image_size=resolution, hd_num=hd_num)
+
+     frames = transform(frames)
+     T_, C, H, W = frames.shape
+
+     # Cut every frame into an (H/resolution) x (W/resolution) grid of local tiles ...
+     sub_img = frames.reshape(
+         1, T_, 3, H//resolution, resolution, W//resolution, resolution
+     ).permute(0, 3, 5, 1, 2, 4, 6).reshape(-1, T_, 3, resolution, resolution).contiguous()
+
+     # ... and append one global view of the full frame, downsampled to resolution x resolution.
+     glb_img = F.interpolate(
+         frames.float(), size=(resolution, resolution), mode='bicubic', align_corners=False
+     ).to(sub_img.dtype).unsqueeze(0)
+
+     frames = torch.cat([sub_img, glb_img]).unsqueeze(0)
+
+     if return_msg:
+         fps = float(vr.get_avg_fps())
+         sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
+         msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
+         return frames, msg
+     else:
+         return frames
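+ # Minimal usage sketch (assumes a local clip 'example.mp4'; not run by the app):
+ #   vid, msg = load_video('example.mp4', num_segments=8, return_msg=True)
+ #   vid.shape == [1, num_tiles + 1, 8, 3, 224, 224]  # local tiles plus one global view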
+
+ def HD_transform_padding(frames, image_size=224, hd_num=6):
+     def _padding_224(frames):
+         _, _, H, W = frames.shape
+         tar = int(np.ceil(H / 224) * 224)  # pad height up to the next multiple of 224
+         top_padding = (tar - H) // 2
+         bottom_padding = tar - H - top_padding
+         left_padding = 0
+         right_padding = 0
+
+         padded_frames = F.pad(
+             frames,
+             pad=[left_padding, right_padding, top_padding, bottom_padding],
+             mode='constant', value=255
+         )
+         return padded_frames
+
+     _, _, H, W = frames.shape
+     trans = False
+     if W < H:
+         frames = frames.flip(-2, -1)
+         trans = True
+         width, height = H, W
+     else:
+         width, height = W, H
+
+     ratio = width / height
+     # largest scale whose tile grid still fits within hd_num tiles
+     scale = 1
+     while scale * np.ceil(scale / ratio) <= hd_num:
+         scale += 1
+     scale -= 1
+     new_w = int(scale * image_size)
+     new_h = int(new_w / ratio)
+
+     resized_frames = F.interpolate(
+         frames, size=(new_h, new_w),
+         mode='bicubic',
+         align_corners=False
+     )
+     padded_frames = _padding_224(resized_frames)
+
+     if trans:
+         padded_frames = padded_frames.flip(-2, -1)
+
+     return padded_frames
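+ # Note: padding uses value=255 (white) and is applied to raw pixel values,
+ # before the 1/255 scaling and ImageNet normalization in load_video.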
+
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+     best_ratio_diff = float('inf')
+     best_ratio = (1, 1)
+     area = width * height
+     for ratio in target_ratios:
+         target_aspect_ratio = ratio[0] / ratio[1]
+         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+         if ratio_diff < best_ratio_diff:
+             best_ratio_diff = ratio_diff
+             best_ratio = ratio
+         elif ratio_diff == best_ratio_diff:
+             # tie-break: prefer the larger grid when the source has enough pixels
+             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                 best_ratio = ratio
+     return best_ratio
+
+
+ def HD_transform_no_padding(frames, image_size=224, hd_num=6, fix_ratio=(2, 1)):
+     min_num = 1
+     max_num = hd_num
+     _, _, orig_height, orig_width = frames.shape
+     aspect_ratio = orig_width / orig_height
+
+     # enumerate candidate tile grids (i columns x j rows) with at most max_num tiles
+     target_ratios = set(
+         (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+         i * j <= max_num and i * j >= min_num)
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     # pick the grid closest to the video's aspect ratio, unless a ratio is fixed
+     if fix_ratio:
+         target_aspect_ratio = fix_ratio
+     else:
+         target_aspect_ratio = find_closest_aspect_ratio(
+             aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+     # calculate the target width and height
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     # resize the frames; without padding, the aspect ratio may be distorted slightly
+     resized_frame = F.interpolate(
+         frames, size=(target_height, target_width),
+         mode='bicubic', align_corners=False
+     )
+     return resized_frame
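+ # With the default fix_ratio=(2, 1) and image_size=224, every clip is resized to
+ # 224x448 (H x W), which load_video then splits into a 1x2 grid of 224x224 tiles.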
+
+ # ========================================
+ # Gradio Setting
+ # ========================================
+ def gradio_reset(chat_state, img_list):
+     # Clear the conversation and put the UI back into its "upload first" state.
+     if chat_state is not None:
+         chat_state = []
+     if img_list is not None:
+         img_list = None
+     return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
+
+
+ def upload_img(gr_video, num_segments, hd_num, padding):
+     if gr_video is None:
+         # Nothing uploaded: keep the textbox locked and the upload button active.
+         return None, gr.update(interactive=False, placeholder='Please upload video/image first!'), gr.update(interactive=True), None
+     video_tensor, msg = load_video(gr_video, num_segments=num_segments, return_msg=True, resolution=224, hd_num=hd_num, padding=padding)
+     video_tensor = video_tensor.to(model.device)
+     return gr.update(interactive=True), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), video_tensor
+
+
+ def clear_():
+     return [], []
+
+
+ def gradio_ask(user_message, chatbot):
+     if len(user_message) == 0:
+         return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot
+     # Append the user turn as a [message, pending-answer] pair for gr.Chatbot.
+     chatbot = chatbot + [[user_message, None]]
+     return '', chatbot
+
+
+ def gradio_answer(chatbot, sys_prompt, user_prompt, video_tensor, chat_state, num_beams, temperature, do_sample=False):
+     # gradio_ask clears the textbox before this chained event fires, so read the
+     # prompt from the last chatbot turn rather than the (now empty) user_prompt.
+     user_prompt = chatbot[-1][0]
+     response, chat_state = model.chat(tokenizer,
+                                       sys_prompt,
+                                       user_prompt,
+                                       media_type='video',
+                                       media_tensor=video_tensor,
+                                       chat_history=chat_state,
+                                       return_history=True,
+                                       generation_config={
+                                           "num_beams": num_beams,
+                                           "temperature": temperature,
+                                           "do_sample": do_sample})
+     print(response)
+     chatbot[-1][1] = response
+     return chatbot, chat_state
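+ # Note: the event wiring below never passes do_sample, so generation stays
+ # deterministic (greedy/beam search) and the Temperature slider only takes
+ # effect if do_sample=True is threaded through.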
+
+
+ class OpenGVLab(gr.themes.base.Base):
+     def __init__(
+         self,
+         *,
+         primary_hue=colors.blue,
+         secondary_hue=colors.sky,
+         neutral_hue=colors.gray,
+         spacing_size=sizes.spacing_md,
+         radius_size=sizes.radius_sm,
+         text_size=sizes.text_md,
+         font=(
+             fonts.GoogleFont("Noto Sans"),
+             "ui-sans-serif",
+             "sans-serif",
+         ),
+         font_mono=(
+             fonts.GoogleFont("IBM Plex Mono"),
+             "ui-monospace",
+             "monospace",
+         ),
+     ):
+         super().__init__(
+             primary_hue=primary_hue,
+             secondary_hue=secondary_hue,
+             neutral_hue=neutral_hue,
+             spacing_size=spacing_size,
+             radius_size=radius_size,
+             text_size=text_size,
+             font=font,
+             font_mono=font_mono,
+         )
+         super().set(
+             body_background_fill="*neutral_50",
+         )
+
+
+ gvlabtheme = OpenGVLab(primary_hue=colors.blue,
+                        secondary_hue=colors.sky,
+                        neutral_hue=colors.gray,
+                        spacing_size=sizes.spacing_md,
+                        radius_size=sizes.radius_sm,
+                        text_size=sizes.text_md,
+                        )
+
+ title = """<h1 align="center"><a href="https://github.com/OpenGVLab/Ask-Anything"><img src="https://s1.ax1x.com/2023/05/07/p9dBMOU.png" alt="Ask-Anything" border="0" style="margin: 0 auto; height: 100px;" /></a></h1>"""
+ description = """
+ VideoChat2 powered by InternVideo!<br><p><a href='https://github.com/OpenGVLab/Ask-Anything'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p>
+ """
+ SYS_PROMPT = ""
+
+ with gr.Blocks(title="InternVideo-VideoChat!", theme=gvlabtheme,
+                css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: hidden}") as demo:
+     gr.Markdown(title)
+     gr.Markdown(description)
+
+     with gr.Row():
+         with gr.Column(scale=0.5, visible=True) as video_upload:
+             with gr.Column(elem_id="image", scale=0.5) as img_part:
+                 up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload")
+             upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
+             restart = gr.Button("Restart")
+             sys_prompt = gr.State(SYS_PROMPT)
+
+             num_beams = gr.Slider(
+                 minimum=1,
+                 maximum=10,
+                 value=1,
+                 step=1,
+                 interactive=True,
+                 label="beam search numbers",
+             )
+
+             temperature = gr.Slider(
+                 minimum=0.1,
+                 maximum=2.0,
+                 value=1.0,
+                 step=0.1,
+                 interactive=True,
+                 label="Temperature",
+             )
+
+             num_segments = gr.Slider(
+                 minimum=8,
+                 maximum=64,
+                 value=8,
+                 step=1,
+                 interactive=True,
+                 label="Input Frames",
+             )
+
+             # effectively fixed at 224 (min == max); shown for reference only
+             resolution = gr.Slider(
+                 minimum=224,
+                 maximum=224,
+                 value=224,
+                 step=1,
+                 interactive=True,
+                 label="Vision encoder resolution",
+             )
+
+             hd_num = gr.Slider(
+                 minimum=1,
+                 maximum=10,
+                 value=4,
+                 step=1,
+                 interactive=True,
+                 label="HD num",
+             )
+
+             padding = gr.Checkbox(
+                 label="padding",
+                 info=""
+             )
+
+         with gr.Column(visible=True) as input_raws:
+             chat_state = gr.State([])
+             img_list = gr.State()
+             chatbot = gr.Chatbot(elem_id="chatbot", label='VideoChat')
+             with gr.Row():
+                 with gr.Column(scale=0.7):
+                     text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False)
+                 with gr.Column(scale=0.15, min_width=0):
+                     run = gr.Button("💭Send")
+                 with gr.Column(scale=0.15, min_width=0):
+                     clear = gr.Button("🔄Clear️")
+
+     # Despite its name, img_list holds the preprocessed video tensor from upload_img.
+     upload_button.click(upload_img, [up_video, num_segments, hd_num, padding], [up_video, text_input, upload_button, img_list])
+
+     text_input.submit(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
+         gradio_answer, [chatbot, sys_prompt, text_input, img_list, chat_state, num_beams, temperature], [chatbot, chat_state]
+     )
+     run.click(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
+         gradio_answer, [chatbot, sys_prompt, text_input, img_list, chat_state, num_beams, temperature], [chatbot, chat_state]
+     )
+     run.click(lambda: "", None, text_input)
+     clear.click(clear_, None, [chatbot, chat_state])
+     restart.click(gradio_reset, [chat_state, img_list], [chatbot, up_video, text_input, upload_button, chat_state, img_list], queue=False)
+
+ demo.launch()
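+ # When running locally (outside Spaces), demo.launch(share=True) can expose a
+ # temporary public URL; on Spaces the default launch() is sufficient.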