VideoCrafterXtend / VBench /vbench /aesthetic_quality.py
ychenhq's picture
Upload folder using huggingface_hub
04fbff5 verified
raw
history blame
2.83 kB
import os
import clip
import torch
import torch.nn as nn
import torch.nn.functional as F
import subprocess
from urllib.request import urlretrieve
from vbench.utils import load_video, load_dimension_info, clip_transform
from tqdm import tqdm
def get_aesthetic_model(cache_folder):
"""load the aethetic model"""
path_to_model = cache_folder + "/sa_0_4_vit_l_14_linear.pth"
if not os.path.exists(path_to_model):
os.makedirs(cache_folder, exist_ok=True)
url_model = (
"https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_vit_l_14_linear.pth?raw=true"
)
# download aesthetic predictor
if not os.path.isfile(path_to_model):
try:
print(f'trying urlretrieve to download {url_model} to {path_to_model}')
urlretrieve(url_model, path_to_model) # unable to download https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_vit_l_14_linear.pth?raw=true to pretrained/aesthetic_model/emb_reader/sa_0_4_vit_l_14_linear.pth
except:
print(f'unable to download {url_model} to {path_to_model} using urlretrieve, trying wget')
wget_command = ['wget', url_model, '-P', os.path.dirname(path_to_model)]
subprocess.run(wget_command)
m = nn.Linear(768, 1)
s = torch.load(path_to_model)
m.load_state_dict(s)
m.eval()
return m
def laion_aesthetic(aesthetic_model, clip_model, video_list, device):
aesthetic_model.eval()
clip_model.eval()
aesthetic_avg = 0.0
num = 0
video_results = []
for video_path in tqdm(video_list):
images = load_video(video_path)
image_transform = clip_transform(224)
images = image_transform(images)
images = images.to(device)
image_feats = clip_model.encode_image(images).to(torch.float32)
image_feats = F.normalize(image_feats, dim=-1, p=2)
aesthetic_scores = aesthetic_model(image_feats).squeeze()
normalized_aesthetic_scores = aesthetic_scores/10
cur_avg = torch.mean(normalized_aesthetic_scores, dim=0, keepdim=True)
aesthetic_avg += cur_avg.item()
num += 1
video_results.append({'video_path': video_path, 'video_results': cur_avg.item()})
aesthetic_avg /= num
return aesthetic_avg, video_results
def compute_aesthetic_quality(json_dir, device, submodules_list):
vit_path = submodules_list[0]
aes_path = submodules_list[1]
aesthetic_model = get_aesthetic_model(aes_path).to(device)
clip_model, preprocess = clip.load(vit_path, device=device)
video_list, _ = load_dimension_info(json_dir, dimension='aesthetic_quality', lang='en')
all_results, video_results = laion_aesthetic(aesthetic_model, clip_model, video_list, device)
return all_results, video_results