# VideoCrafterXtend / VBench / vbench / dynamic_degree.py
# ychenhq's picture
# Upload folder using huggingface_hub
# 04fbff5 verified
# raw
# history blame
# 5.04 kB
import argparse
import os
import cv2
import glob
import numpy as np
import torch
from tqdm import tqdm
from easydict import EasyDict as edict
from vbench.utils import load_dimension_info
from vbench.third_party.RAFT.core.raft import RAFT
from vbench.third_party.RAFT.core.utils_core.utils import InputPadder
class DynamicDegree:
    """Decide whether a video contains motion, based on RAFT optical flow.

    Consecutive sampled frames are fed to a RAFT model; each pair gets a
    score (mean magnitude of the strongest 5% of flow vectors). A video is
    reported as "moving" once enough pairs exceed a resolution-dependent
    threshold (see ``set_params`` and ``check_move``).
    """

    # Lower-cased image extensions accepted by ``get_frames_from_img_folder``.
    _IMG_EXTS = frozenset({'jpg', 'png', 'jpeg', 'bmp', 'tif', 'tiff'})

    def __init__(self, args, device):
        """Store the configuration and load the RAFT model.

        Args:
            args: attribute-style config forwarded to ``RAFT``; ``args.model``
                is the checkpoint path handed to ``torch.load``.
            device: torch device the model and the frames are moved to.
        """
        self.args = args
        self.device = device
        self.load_model()

    def load_model(self):
        """Build RAFT, load its weights and switch it to eval mode."""
        # The state dict is loaded through a DataParallel wrapper and the
        # wrapper is dropped right afterwards (only ``.module`` is kept).
        self.model = torch.nn.DataParallel(RAFT(self.args))
        # map_location="cpu" lets a checkpoint saved from a GPU be loaded on
        # a host without CUDA; the model is moved to ``self.device`` below.
        state_dict = torch.load(self.args.model, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model = self.model.module
        self.model.to(self.device)
        self.model.eval()

    def get_score(self, img, flo):
        """Return the mean magnitude of the top 5% strongest flow vectors.

        Args:
            img: unused; kept so existing callers keep working.
            flo: flow tensor indexed as ``flo[0]`` -> (2, H, W), with the
                two channels being the u and v flow components.

        Returns:
            The score as a Python float.
        """
        flow = flo[0].permute(1, 2, 0).cpu().numpy()
        u = flow[:, :, 0]
        v = flow[:, :, 1]
        rad = np.sqrt(np.square(u) + np.square(v))
        h, w = rad.shape
        rad_flat = rad.flatten()
        # Keep at least one element: for maps with fewer than 20 pixels
        # int(h*w*0.05) is 0 and the mean of an empty slice would be NaN.
        cut_index = max(1, int(h * w * 0.05))
        # Sorting the negated magnitudes yields descending order.
        strongest = np.abs(np.sort(-rad_flat))[:cut_index]
        return np.mean(strongest).item()

    def set_params(self, frame, count):
        """Derive the decision parameters from frame size and frame count.

        ``thres`` scales linearly with the smaller of the last two frame
        dimensions (6.0 at 256); ``count_num`` scales with the number of
        sampled frames (4 at 16 frames).
        """
        scale = min(frame.shape[-2:])
        self.params = {
            "thres": 6.0 * (scale / 256.0),
            "count_num": round(4 * (count / 16.0)),
        }

    def infer(self, video_path):
        """Return True if the video at ``video_path`` is judged to move.

        Args:
            video_path: path to an ``.mp4`` file (extension matched
                case-insensitively) or to a directory of image frames.

        Raises:
            NotImplementedError: if the path is neither of the above.
        """
        with torch.no_grad():
            if video_path.lower().endswith('.mp4'):
                frames = self.get_frames(video_path)
            elif os.path.isdir(video_path):
                frames = self.get_frames_from_img_folder(video_path)
            else:
                raise NotImplementedError(
                    f"unsupported input (expected .mp4 file or image folder): {video_path}"
                )
            self.set_params(frame=frames[0], count=len(frames))
            static_score = []
            for image1, image2 in zip(frames[:-1], frames[1:]):
                padder = InputPadder(image1.shape)
                image1, image2 = padder.pad(image1, image2)
                _, flow_up = self.model(image1, image2, iters=20, test_mode=True)
                static_score.append(self.get_score(image1, flow_up))
            return self.check_move(static_score)

    def check_move(self, score_list):
        """Return True once ``count_num`` scores have exceeded ``thres``.

        Reads ``self.params`` as set by ``set_params``. The count is only
        checked right after a score exceeds the threshold, so an empty list
        (or no score above the threshold) always yields False.
        """
        thres = self.params["thres"]
        count_num = self.params["count_num"]
        count = 0
        for score in score_list:
            if score > thres:
                count += 1
            if count >= count_num:
                return True
        return False

    def _to_tensor(self, frame_bgr):
        """Convert a BGR HxWx3 array into a 1x3xHxW float tensor on the device."""
        rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(rgb.astype(np.uint8)).permute(2, 0, 1).float()
        return tensor[None].to(self.device)

    def get_frames(self, video_path):
        """Decode a video and return every ``round(fps / 8)``-th frame.

        Returns:
            A non-empty list of 1x3xHxW float tensors on ``self.device``.
        """
        video = cv2.VideoCapture(video_path)
        frame_list = []
        try:
            fps = video.get(cv2.CAP_PROP_FPS)
            # Clamp to 1: for fps < 4 (or missing fps metadata) the rounded
            # interval would be 0, which is not a valid sampling step.
            interval = max(1, round(fps / 8)) if fps > 0 else 1
            index = 0
            while video.isOpened():
                success, frame = video.read()
                if not success:
                    break
                # Subsample while decoding so that only the frames actually
                # used are converted and moved to the device.
                if index % interval == 0:
                    frame_list.append(self._to_tensor(frame))
                index += 1
        finally:
            video.release()
        assert frame_list, f"no frames could be read from {video_path}"
        return frame_list

    def extract_frame(self, frame_list, interval=1):
        """Return every ``interval``-th element of ``frame_list``, starting at index 0."""
        return [frame_list[i] for i in range(0, len(frame_list), interval)]

    def get_frames_from_img_folder(self, img_folder):
        """Load all images of a folder, in sorted path order, as frame tensors.

        Files are selected by extension (case-insensitive). Every image is
        used; no temporal subsampling is applied.

        Returns:
            A non-empty list of 1x3xHxW float tensors on ``self.device``.
        """
        imgs = sorted(
            p for p in glob.glob(os.path.join(img_folder, "*"))
            if os.path.splitext(p)[1][1:].lower() in self._IMG_EXTS
        )
        # NOTE(review): an unreadable image presumably makes cv2.imread
        # return None, and the colour conversion then fails — confirm.
        frame_list = [
            self._to_tensor(cv2.imread(img, cv2.IMREAD_COLOR)) for img in imgs
        ]
        assert frame_list, f"no image frames found in {img_folder}"
        return frame_list
def dynamic_degree(dynamic, video_list):
    """Run the motion check on every video and average the outcomes.

    Args:
        dynamic: object exposing ``infer(video_path)``; its return value is
            recorded per video and averaged.
        video_list: iterable of video paths.

    Returns:
        ``(avg_score, video_results)`` where ``avg_score`` is ``np.mean`` of
        the per-video results and ``video_results`` is a list of dicts with
        the keys ``'video_path'`` and ``'video_results'``.
    """
    video_results = []
    for path in tqdm(video_list):
        moved = dynamic.infer(path)
        video_results.append({'video_path': path, 'video_results': moved})
    per_video_flags = [entry['video_results'] for entry in video_results]
    return np.mean(per_video_flags), video_results
def compute_dynamic_degree(json_dir, device, submodules_list):
    """Evaluate the 'dynamic_degree' dimension for the videos listed in ``json_dir``.

    Args:
        json_dir: forwarded to ``load_dimension_info`` to obtain the video list.
        device: torch device handed to ``DynamicDegree``.
        submodules_list: mapping whose ``"model"`` entry is the RAFT
            checkpoint path.

    Returns:
        The ``(average score, per-video results)`` pair produced by
        ``dynamic_degree``.
    """
    # Fixed RAFT configuration: only the checkpoint path comes from outside.
    raft_args = edict({
        "model": submodules_list["model"],
        "small": False,
        "mixed_precision": False,
        "alternate_corr": False,
    })
    # The model is built before the video list is resolved.
    evaluator = DynamicDegree(raft_args, device)
    videos, _ = load_dimension_info(json_dir, dimension='dynamic_degree', lang='en')
    return dynamic_degree(evaluator, videos)