feizhengcong's picture
Upload 198 files
074c857
raw
history blame
2.33 kB
import os
from argparse import ArgumentParser
from tqdm import tqdm
from PIL import Image
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
import torch
from simulacra_fit_linear_model import AestheticMeanPredictionLinearModel
from CLIP import clip
parser = ArgumentParser()
parser.add_argument("directory")
parser.add_argument("-t", "--top-n", default=50)
args = parser.parse_args()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
clip_model_name = 'ViT-B/16'
clip_model = clip.load(clip_model_name, jit=False, device=device)[0]
clip_model.eval().requires_grad_(False)
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711])
# 512 is embed dimension for ViT-B/16 CLIP
model = AestheticMeanPredictionLinearModel(512)
model.load_state_dict(
torch.load("models/sac_public_2022_06_29_vit_b_16_linear.pth")
)
model = model.to(device)
def get_filepaths(parentpath, filepaths):
paths = []
for path in filepaths:
try:
new_parent = os.path.join(parentpath, path)
paths += get_filepaths(new_parent, os.listdir(new_parent))
except NotADirectoryError:
paths.append(os.path.join(parentpath, path))
return paths
filepaths = get_filepaths(args.directory, os.listdir(args.directory))
scores = []
for path in tqdm(filepaths):
# This is obviously a flawed way to check for an image but this is just
# a demo script anyway.
if path[-4:] not in (".png", ".jpg"):
continue
img = Image.open(path).convert('RGB')
img = TF.resize(img, 224, transforms.InterpolationMode.LANCZOS)
img = TF.center_crop(img, (224,224))
img = TF.to_tensor(img).to(device)
img = normalize(img)
clip_image_embed = F.normalize(
clip_model.encode_image(img[None, ...]).float(),
dim=-1)
score = model(clip_image_embed)
if len(scores) < args.top_n:
scores.append((score.item(),path))
scores.sort()
else:
if scores[0][0] < score:
scores.append((score.item(),path))
scores.sort(key=lambda x: x[0])
scores = scores[1:]
for score, path in scores:
print(f"{score}: {path}")