# VideoCrafterXtend/VBench/vbench/human_action.py
import os
import json
import numpy as np
import clip
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from vbench.utils import load_video, load_dimension_info
from vbench.third_party.umt.datasets.video_transforms import (
    Compose, Resize, CenterCrop, Normalize,
    create_random_augment, random_short_side_scale_jitter,
    random_crop, random_resized_crop_with_shift, random_resized_crop,
    horizontal_flip, uniform_crop,
)
from vbench.third_party.umt.datasets.volume_transforms import ClipToTensor
from timm.models import create_model
from vbench.third_party.umt.models.modeling_finetune import vit_large_patch16_224
from tqdm import tqdm


def build_dict():
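    """Map Kinetics-400 class indices (as strings) to lower-cased category names.

    The category file ships with the UMT third-party code and stores one
    category name and its numeric index per line, separated by a tab.
    """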
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
path = f'{CUR_DIR}/third_party/umt/kinetics_400_categories.txt'
results = {}
with open(path, 'r') as f:
cat_list = f.readlines()
cat_list = [c.strip() for c in cat_list]
for line in cat_list:
cat, number = line.split('\t')
results[number] = cat.lower()
    return results


def human_action(umt_path, video_list, device):
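    """Score each video by whether the action label parsed from its filename
    appears among the classifier's confident top-5 predictions.

    Args:
        umt_path: path to the UMT ViT-L/16 checkpoint fine-tuned on Kinetics-400.
        video_list: list of video file paths; the ground-truth action is parsed
            from each filename.
        device: torch device to run inference on.

    Returns:
        (accuracy, per-video results), where accuracy = correct / total.
    """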
state_dict = torch.load(umt_path, map_location='cpu')
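    # Build the UMT ViT-L/16 classifier: 400 Kinetics classes, 16 input frames.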
model = create_model(
"vit_large_patch16_224",
pretrained=False,
num_classes=400,
all_frames=16,
tubelet_size=1,
use_learnable_pos_emb=False,
fc_drop_rate=0.,
drop_rate=0.,
drop_path_rate=0.2,
attn_drop_rate=0.,
drop_block_rate=None,
use_checkpoint=False,
checkpoint_num=16,
use_mean_pooling=True,
init_scale=0.001,
)
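    # Preprocessing: resize the short side to 256, center-crop to 224x224,
    # convert the clip to a tensor, and apply ImageNet normalization.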
data_transform = Compose([
Resize(256, interpolation='bilinear'),
CenterCrop(size=(224, 224)),
ClipToTensor(),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
model = model.to(device)
model.load_state_dict(state_dict, strict=False)
model.eval()
cat_dict = build_dict()
    cnt = 0       # number of videos evaluated
    cor_num = 0   # videos whose filename label appears in the confident top-5
    video_results = []
for video_path in tqdm(video_list):
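        # Parse the ground-truth action from the filename: the text after
        # "person is " and before the first '-' or '_' (filenames are expected
        # to look like "a person is <action>-<suffix>").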
video_label_ls = video_path.split('/')[-1].lower().split('-')[0].split("person is ")[-1].split('_')[0]
cnt += 1
images = load_video(video_path, data_transform, num_frames=16)
images = images.unsqueeze(0)
images = images.to(device)
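        # Run the classifier and keep the top-5 classes whose sigmoid score is at least 0.85.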
with torch.no_grad():
logits = torch.sigmoid(model(images))
results, indices = torch.topk(logits, 5, dim=1)
indices = indices.squeeze().tolist()
results = results.squeeze().tolist()
results = [round(f, 4) for f in results]
cat_ls = []
for i in range(5):
if results[i] >= 0.85:
cat_ls.append(cat_dict[str(indices[i])])
flag = False
for cat in cat_ls:
if cat == video_label_ls:
cor_num += 1
flag = True
# print(f"{cnt}: {video_path} correct, top-5: {cat_ls}, logits: {results}", flush=True)
break
if flag is False:
# print(f"{cnt}: {video_path} false, gt: {video_label_ls}, top-5: {cat_ls}, logits: {results}", flush=True)
pass
video_results.append({'video_path': video_path, 'video_results': flag})
# print(f"cor num: {cor_num}, total: {cnt}")
acc = cor_num / cnt
    return acc, video_results


def compute_human_action(json_dir, device, submodules_list):
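    """Entry point for the 'human_action' evaluation dimension.

    submodules_list[0] must point to the UMT checkpoint; the video list is
    collected via load_dimension_info from json_dir.
    """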
umt_path = submodules_list[0]
video_list, _ = load_dimension_info(json_dir, dimension='human_action', lang='en')
all_results, video_results = human_action(umt_path, video_list, device)
return all_results, video_results
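

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the JSON and
    # checkpoint paths are placeholders supplied on the command line and are
    # simply forwarded to compute_human_action / load_dimension_info.
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate the human_action dimension.")
    parser.add_argument("--json_dir", required=True, help="passed through to load_dimension_info")
    parser.add_argument("--umt_path", required=True, help="path to the UMT Kinetics-400 checkpoint")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    acc, per_video = compute_human_action(args.json_dir, device, [args.umt_path])
    print(f"human_action accuracy: {acc:.4f} over {len(per_video)} videos")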