import os

import torch
from timm.models import create_model
from tqdm import tqdm

from vbench.utils import load_video, load_dimension_info
from vbench.third_party.umt.datasets.video_transforms import (
    Compose, Resize, CenterCrop, Normalize,
)
from vbench.third_party.umt.datasets.volume_transforms import ClipToTensor
# Imported for its side effect: registers "vit_large_patch16_224" with timm's
# model registry so that create_model() below can resolve it by name.
from vbench.third_party.umt.models.modeling_finetune import vit_large_patch16_224


def build_dict():
    """Map Kinetics-400 label indices (as strings) to lower-cased category names."""
    CUR_DIR = os.path.dirname(os.path.abspath(__file__))
    path = f'{CUR_DIR}/third_party/umt/kinetics_400_categories.txt'
    results = {}
    with open(path, 'r') as f:
        cat_list = f.readlines()
    cat_list = [c.strip() for c in cat_list]
    for line in cat_list:
        cat, number = line.split('\t')
        results[number] = cat.lower()
    return results
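
# kinetics_400_categories.txt is assumed to hold one tab-separated
# "category<TAB>index" pair per line (e.g. "abseiling\t0"), which build_dict()
# inverts into {"0": "abseiling", ...}.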


def human_action(umt_path, video_list, device):
    """Classify each video with a UMT action recognizer and check the result
    against the action named in the video's filename.

    Returns the overall accuracy and a per-video list of pass/fail results.
    """
    state_dict = torch.load(umt_path, map_location='cpu')
    # UMT ViT-L/16 fine-tuned for Kinetics-400: 400 classes, 16 input frames.
    model = create_model(
        "vit_large_patch16_224",
        pretrained=False,
        num_classes=400,
        all_frames=16,
        tubelet_size=1,
        use_learnable_pos_emb=False,
        fc_drop_rate=0.,
        drop_rate=0.,
        drop_path_rate=0.2,
        attn_drop_rate=0.,
        drop_block_rate=None,
        use_checkpoint=False,
        checkpoint_num=16,
        use_mean_pooling=True,
        init_scale=0.001,
    )
    # Deterministic eval-time preprocessing; mean/std are the standard
    # ImageNet normalization statistics.
    data_transform = Compose([
        Resize(256, interpolation='bilinear'),
        CenterCrop(size=(224, 224)),
        ClipToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    model = model.to(device)
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    cat_dict = build_dict()
    cnt = 0
    cor_num = 0
    video_results = []
    for video_path in tqdm(video_list):
        # Ground-truth action parsed from the filename; see the note after
        # this function.
        video_label = video_path.split('/')[-1].lower().split('-')[0].split("person is ")[-1].split('_')[0]
        cnt += 1
        images = load_video(video_path, data_transform, num_frames=16)
        images = images.unsqueeze(0)  # add batch dimension
        images = images.to(device)
        with torch.no_grad():
            probs = torch.sigmoid(model(images))
            results, indices = torch.topk(probs, 5, dim=1)
            indices = indices.squeeze().tolist()
            results = results.squeeze().tolist()
            results = [round(f, 4) for f in results]
        # Keep only the top-5 predictions with confidence of at least 0.85.
        cat_ls = []
        for i in range(5):
            if results[i] >= 0.85:
                cat_ls.append(cat_dict[str(indices[i])])
        # The video passes if any confident prediction matches its label.
        flag = False
        for cat in cat_ls:
            if cat == video_label:
                cor_num += 1
                flag = True
                break
        video_results.append({'video_path': video_path, 'video_results': flag})
    acc = cor_num / cnt
    return acc, video_results
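
# Note on the filename convention assumed by human_action(): the ground-truth
# action is recovered purely from the video's file name. For a hypothetical
# file "a person is arm wrestling-0.mp4" the parsing chain yields
#     "a person is arm wrestling-0.mp4"
#     -> split('-')[0]            -> "a person is arm wrestling"
#     -> split("person is ")[-1]  -> "arm wrestling"
# and "arm wrestling" is then matched against the Kinetics-400 category names.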


def compute_human_action(json_dir, device, submodules_list):
    """VBench entry point for the human_action dimension."""
    umt_path = submodules_list[0]
    video_list, _ = load_dimension_info(json_dir, dimension='human_action', lang='en')
    all_results, video_results = human_action(umt_path, video_list, device)
    return all_results, video_results
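
# Minimal usage sketch (illustrative, not part of the VBench CLI). Both paths
# below are placeholders: point json_dir at the dimension-info JSON consumed
# by load_dimension_info, and submodules_list[0] at a UMT checkpoint
# fine-tuned on Kinetics-400.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    acc, per_video = compute_human_action(
        json_dir="path/to/human_action_info.json",             # placeholder
        device=device,
        submodules_list=["path/to/umt_kinetics400_ckpt.pth"],  # placeholder
    )
    print(f"human_action score: {acc:.4f} over {len(per_video)} videos")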