|
import os |
|
from argparse import ArgumentParser |
|
from tqdm import tqdm |
|
from PIL import Image |
|
from torch.nn import functional as F |
|
from torchvision import transforms |
|
from torchvision.transforms import functional as TF |
|
import torch |
|
from simulacra_fit_linear_model import AestheticMeanPredictionLinearModel |
|
from CLIP import clip |
|
|
|
parser = ArgumentParser() |
|
parser.add_argument("directory") |
|
parser.add_argument("-t", "--top-n", default=50) |
|
args = parser.parse_args() |
|
|
|
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') |
|
|
|
clip_model_name = 'ViT-B/16' |
|
clip_model = clip.load(clip_model_name, jit=False, device=device)[0] |
|
clip_model.eval().requires_grad_(False) |
|
|
|
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], |
|
std=[0.26862954, 0.26130258, 0.27577711]) |
|
|
|
|
|
model = AestheticMeanPredictionLinearModel(512) |
|
model.load_state_dict( |
|
torch.load("models/sac_public_2022_06_29_vit_b_16_linear.pth") |
|
) |
|
model = model.to(device) |
|
|
|
def get_filepaths(parentpath, filepaths): |
|
paths = [] |
|
for path in filepaths: |
|
try: |
|
new_parent = os.path.join(parentpath, path) |
|
paths += get_filepaths(new_parent, os.listdir(new_parent)) |
|
except NotADirectoryError: |
|
paths.append(os.path.join(parentpath, path)) |
|
return paths |
|
|
|
filepaths = get_filepaths(args.directory, os.listdir(args.directory)) |
|
scores = [] |
|
for path in tqdm(filepaths): |
|
|
|
|
|
if path[-4:] not in (".png", ".jpg"): |
|
continue |
|
img = Image.open(path).convert('RGB') |
|
img = TF.resize(img, 224, transforms.InterpolationMode.LANCZOS) |
|
img = TF.center_crop(img, (224,224)) |
|
img = TF.to_tensor(img).to(device) |
|
img = normalize(img) |
|
clip_image_embed = F.normalize( |
|
clip_model.encode_image(img[None, ...]).float(), |
|
dim=-1) |
|
score = model(clip_image_embed) |
|
if len(scores) < args.top_n: |
|
scores.append((score.item(),path)) |
|
scores.sort() |
|
else: |
|
if scores[0][0] < score: |
|
scores.append((score.item(),path)) |
|
scores.sort(key=lambda x: x[0]) |
|
scores = scores[1:] |
|
|
|
for score, path in scores: |
|
print(f"{score}: {path}") |
|
|