import os
import spaces

# Read the Hugging Face token from the environment; fall back to a manually
# pasted token only when running outside a configured Space.
try:
    token = os.environ['HF_TOKEN']
except KeyError:
    print("HF_TOKEN is not set -- paste your Hugging Face token here!")
    token = "hf_xxxxxxxxxxxxxxxxxxx"
    os.environ['HF_TOKEN'] = token
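# Note: on a hosted Space the token is normally supplied as a repository secret
# named HF_TOKEN, so the hardcoded fallback above should only ever trigger locally.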
import torch
import gradio as gr
from gradio.themes.utils import colors, fonts, sizes
from transformers import AutoTokenizer, AutoModel
# ========================================
# Model Initialization
# ========================================
tokenizer = AutoTokenizer.from_pretrained(
    'OpenGVLab/InternVideo2_chat_8B_HD',
    trust_remote_code=True,
    use_fast=False,
    token=token)
if torch.cuda.is_available():
    model = AutoModel.from_pretrained(
        'OpenGVLab/InternVideo2_chat_8B_HD',
        torch_dtype=torch.bfloat16,
        trust_remote_code=True).cuda()
else:
    model = AutoModel.from_pretrained(
        'OpenGVLab/InternVideo2_chat_8B_HD',
        torch_dtype=torch.bfloat16,
        trust_remote_code=True)
import numpy as np
import decord
from decord import VideoReader, cpu
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import PILToTensor
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

decord.bridge.set_bridge("torch")

# ========================================
# Define Utils
# ========================================
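# get_index splits the video into `num_segments` equal windows and returns the
# (roughly) centre frame index of each window, e.g. num_frames=100 and
# num_segments=8 gives [6, 18, 31, 43, 56, 68, 80, 93].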
def get_index(num_frames, num_segments):
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    offsets = np.array([
        start + int(np.round(seg_size * idx)) for idx in range(num_segments)
    ])
    return offsets
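# load_video samples `num_segments` frames with decord, applies the HD
# transform (tiling into resolution x resolution crops plus one globally
# resized view) and ImageNet normalisation, and returns a tensor of shape
# [1, blocks + 1, T, 3, resolution, resolution].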
def load_video(video_path, num_segments=8, return_msg=False, resolution=224, hd_num=4, padding=False):
    decord.bridge.set_bridge("torch")
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    num_frames = len(vr)
    frame_indices = get_index(num_frames, num_segments)

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    transform = transforms.Compose([
        transforms.Lambda(lambda x: x.float().div(255.0)),
        transforms.Normalize(mean, std)
    ])

    # With the torch bridge set, get_batch already returns a torch tensor.
    frames = vr.get_batch(frame_indices)
    frames = frames.permute(0, 3, 1, 2)

    if padding:
        frames = HD_transform_padding(frames.float(), image_size=resolution, hd_num=hd_num)
    else:
        frames = HD_transform_no_padding(frames.float(), image_size=resolution, hd_num=hd_num)

    frames = transform(frames)
    T_, C, H, W = frames.shape

    sub_img = frames.reshape(
        1, T_, 3, H // resolution, resolution, W // resolution, resolution
    ).permute(0, 3, 5, 1, 2, 4, 6).reshape(-1, T_, 3, resolution, resolution).contiguous()

    glb_img = F.interpolate(
        frames.float(), size=(resolution, resolution), mode='bicubic', align_corners=False
    ).to(sub_img.dtype).unsqueeze(0)

    frames = torch.cat([sub_img, glb_img]).unsqueeze(0)

    if return_msg:
        fps = float(vr.get_avg_fps())
        sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
        msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
        return frames, msg
    else:
        return frames
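# HD_transform_padding rescales the frames so the long side covers up to
# `hd_num` tiles of `image_size` pixels while keeping the aspect ratio, then
# pads up to a multiple of 224 with white (255) pixels.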
def HD_transform_padding(frames, image_size=224, hd_num=6):
    def _padding_224(frames):
        _, _, H, W = frames.shape
        tar = int(np.ceil(H / 224) * 224)
        top_padding = (tar - H) // 2
        bottom_padding = tar - H - top_padding
        left_padding = 0
        right_padding = 0

        padded_frames = F.pad(
            frames,
            pad=[left_padding, right_padding, top_padding, bottom_padding],
            mode='constant', value=255
        )
        return padded_frames

    _, _, H, W = frames.shape
    trans = False
    if W < H:
        frames = frames.flip(-2, -1)
        trans = True
        width, height = H, W
    else:
        width, height = W, H

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_w = int(scale * image_size)
    new_h = int(new_w / ratio)

    resized_frames = F.interpolate(
        frames, size=(new_h, new_w),
        mode='bicubic',
        align_corners=False
    )
    padded_frames = _padding_224(resized_frames)

    if trans:
        padded_frames = padded_frames.flip(-2, -1)

    return padded_frames
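# find_closest_aspect_ratio picks the (columns, rows) tile grid from
# `target_ratios` whose aspect ratio is closest to the input frame's; ties are
# broken in favour of grids large enough to cover the original area.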
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio
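# HD_transform_no_padding resizes the frames directly onto a grid of
# image_size x image_size tiles: a fixed 2x1 grid by default (fix_ratio=(2, 1));
# when fix_ratio is falsy the grid is chosen by find_closest_aspect_ratio.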
def HD_transform_no_padding(frames, image_size=224, hd_num=6, fix_ratio=(2, 1)):
    min_num = 1
    max_num = hd_num
    _, _, orig_height, orig_width = frames.shape
    aspect_ratio = orig_width / orig_height

    # calculate the existing video aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1)
        for i in range(1, n + 1) for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    if fix_ratio:
        target_aspect_ratio = fix_ratio
    else:
        target_aspect_ratio = find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the frames
    resized_frame = F.interpolate(
        frames, size=(target_height, target_width),
        mode='bicubic', align_corners=False
    )
    return resized_frame
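# A minimal standalone usage sketch (commented out so it does not run on
# import). The video path is a placeholder; model.chat is called with the same
# arguments as in gradio_answer below.
#
#     video_tensor, msg = load_video("example.mp4", num_segments=8,
#                                    return_msg=True, resolution=224, hd_num=4)
#     video_tensor = video_tensor.to(model.device)
#     response, history = model.chat(tokenizer, '', 'Describe the video.',
#                                    media_type='video', media_tensor=video_tensor,
#                                    chat_history=[], return_history=True,
#                                    generation_config={'num_beams': 1,
#                                                       'temperature': 1.0,
#                                                       'do_sample': False})
#     print(response)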
# ========================================
# Gradio Setting
# ========================================
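# gradio_reset clears the conversation state and returns one update per output
# component wired to the Restart button: chatbot, video, textbox, upload
# button, chat_state and img_list.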
def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state = []
    if img_list is not None:
        img_list = None
    return (None,
            gr.update(value=None, interactive=True),
            gr.update(placeholder='Please upload your video first', interactive=False),
            gr.update(value="Upload & Start Chat", interactive=True),
            chat_state,
            img_list)
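# upload_img preprocesses the uploaded video into the HD frame tensor and
# unlocks the chat textbox; its outputs map to [up_video, text_input,
# upload_button, img_list].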
def upload_img(gr_video, num_segments, hd_num, padding):
    if gr_video is None:
        # Nothing uploaded yet: keep the textbox locked and the upload button active.
        return (None,
                gr.update(interactive=False, placeholder='Please upload your video first!'),
                gr.update(interactive=True),
                None)
    video_tensor, msg = load_video(gr_video, num_segments=num_segments, return_msg=True,
                                   resolution=224, hd_num=hd_num, padding=padding)
    video_tensor = video_tensor.to(model.device)
    return (gr.update(interactive=True),
            gr.update(interactive=True, placeholder='Type and press Enter'),
            gr.update(value="Start Chatting", interactive=False),
            video_tensor)
    # if gr_img:
    #     llm_message, img_list, chat_state = chat.upload_img(gr_img, chat_state, img_list)
    #     return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False)
def clear_():
    return [], []


def gradio_ask(user_message, chatbot):
    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot
    chatbot = chatbot + [[user_message, None]]
    return '', chatbot
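# Assumption: on a ZeroGPU Space the GPU-bound callback is decorated with
# @spaces.GPU so a GPU is attached for the duration of the call; the `spaces`
# import above is otherwise unused.
@spaces.GPU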
def gradio_answer(chatbot, sys_prompt, user_prompt, video_tensor, chat_state, num_beams, temperature, do_sample=False):
    # The textbox is cleared by gradio_ask before this chained call runs, so
    # fall back to the last user turn recorded in the chatbot history.
    if not user_prompt and chatbot:
        user_prompt = chatbot[-1][0]
    video_tensor = video_tensor.to(model.device)
    response, chat_state = model.chat(
        tokenizer,
        sys_prompt,
        user_prompt,
        media_type='video',
        media_tensor=video_tensor,
        chat_history=chat_state,
        return_history=True,
        generation_config={
            "num_beams": num_beams,
            "temperature": temperature,
            "do_sample": do_sample,
        },
    )
    print(response)
    chatbot[-1][1] = response
    return chatbot, chat_state
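# OpenGVLab is a light Gradio theme: blue/sky hues, Noto Sans / IBM Plex Mono
# fonts, and a neutral page background.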
class OpenGVLab(gr.themes.base.Base):
    def __init__(
        self,
        *,
        primary_hue=colors.blue,
        secondary_hue=colors.sky,
        neutral_hue=colors.gray,
        spacing_size=sizes.spacing_md,
        radius_size=sizes.radius_sm,
        text_size=sizes.text_md,
        font=(
            fonts.GoogleFont("Noto Sans"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono=(
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            body_background_fill="*neutral_50",
        )
gvlabtheme = OpenGVLab(primary_hue=colors.blue,
                       secondary_hue=colors.sky,
                       neutral_hue=colors.gray,
                       spacing_size=sizes.spacing_md,
                       radius_size=sizes.radius_sm,
                       text_size=sizes.text_md,
                       )

title = """<h1 align="center"><a href="https://github.com/OpenGVLab/Ask-Anything"><img src="https://s1.ax1x.com/2023/05/07/p9dBMOU.png" alt="Ask-Anything" border="0" style="margin: 0 auto; height: 100px;" /></a> </h1>"""
description = """
VideoChat2 powered by InternVideo!<br><p><a href='https://github.com/OpenGVLab/Ask-Anything'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p>
"""

SYS_PROMPT = ""
with gr.Blocks(title="InternVideo-VideoChat!", theme=gvlabtheme,
               css="#chatbot {overflow: auto; height: 500px;} "
                   "#InputVideo {overflow: visible; height: 320px;} "
                   "footer {visibility: hidden}") as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=0.5, visible=True) as video_upload:
            with gr.Column(elem_id="image", scale=0.5) as img_part:
                # with gr.Tab("Video", elem_id='video_tab'):
                up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload")
                # with gr.Tab("Image", elem_id='image_tab'):
                #     up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            restart = gr.Button("Restart")
            sys_prompt = gr.State(f"{SYS_PROMPT}")
            num_beams = gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                interactive=True,
                label="Beam search numbers",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            num_segments = gr.Slider(
                minimum=8,
                maximum=64,
                value=8,
                step=1,
                interactive=True,
                label="Input Frames",
            )
            resolution = gr.Slider(
                minimum=224,
                maximum=224,
                value=224,
                step=1,
                interactive=True,
                label="Vision encoder resolution",
            )
            hd_num = gr.Slider(
                minimum=1,
                maximum=10,
                value=4,
                step=1,
                interactive=True,
                label="HD num",
            )
            padding = gr.Checkbox(
                label="padding",
                info="",
            )

        with gr.Column(visible=True) as input_raws:
            chat_state = gr.State([])
            img_list = gr.State()
            chatbot = gr.Chatbot(elem_id="chatbot", label='VideoChat')
            with gr.Row():
                with gr.Column(scale=0.7):
                    text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False)
                with gr.Column(scale=0.15, min_width=0):
                    run = gr.Button("💭Send")
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("🔄Clear️")
    upload_button.click(upload_img, [up_video, num_segments, hd_num, padding],
                        [up_video, text_input, upload_button, img_list])

    text_input.submit(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer,
        [chatbot, sys_prompt, text_input, img_list, chat_state, num_beams, temperature],
        [chatbot, chat_state]
    )
    run.click(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer,
        [chatbot, sys_prompt, text_input, img_list, chat_state, num_beams, temperature],
        [chatbot, chat_state]
    )
    run.click(lambda: "", None, text_input)
    clear.click(clear_, None, [chatbot, chat_state])
    restart.click(gradio_reset, [chat_state, img_list],
                  [chatbot, up_video, text_input, upload_button, chat_state, img_list],
                  queue=False)

demo.launch()