"""Precomputes CLIP embeddings for Simulacra Aesthetic Captions."""

import argparse
from pathlib import Path
import sqlite3

from PIL import Image
import torch
from torch import multiprocessing as mp
from torch.utils import data
from tqdm import tqdm

from CLIP import clip


class SimulacraDataset(data.Dataset):
    """Simulacra dataset.

    Args:
        images_dir: the dataset images directory
        db: the SQLite ratings database location
        transform: preprocessing and augmentation of the training images
    """

    def __init__(self, images_dir, db, transform=None):
        self.images_dir = Path(images_dir)
        self.transform = transform
        self.conn = sqlite3.connect(db)
        self.ratings = []
        # One row per image: its generation id, index within the generation,
        # file path, and mean rating over all of its ratings.
        for row in self.conn.execute('SELECT generations.id, images.idx, paths.path, AVG(ratings.rating) FROM images JOIN generations ON images.gid=generations.id JOIN ratings ON images.id=ratings.iid JOIN paths ON images.id=paths.iid GROUP BY images.id'):
            self.ratings.append(row)

    def __len__(self):
        return len(self.ratings)

    def __getitem__(self, key):
        gid, idx, filename, rating = self.ratings[key]
        image = Image.open(self.images_dir / filename).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(rating)

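
# A quick sanity check of the dataset class (paths are illustrative):
#   ds = SimulacraDataset('images', 'sac.sqlite3')
#   image, rating = ds[0]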

def main():
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument('--batch-size', '-bs', type=int, default=10,
                   help='the batch size')
    p.add_argument('--clip-model', type=str, default='ViT-B/16',
                   help='the CLIP model')
    p.add_argument('--db', type=str, required=True,
                   help='the database location')
    p.add_argument('--device', type=str,
                   help='the device to use')
    p.add_argument('--images-dir', type=str, required=True,
                   help='the dataset images directory')
    p.add_argument('--num-workers', type=int, default=8,
                   help='the number of data loader workers')
    p.add_argument('--output', type=str, required=True,
                   help='the output file')
    p.add_argument('--start-method', type=str, default='spawn',
                   choices=['fork', 'forkserver', 'spawn'],
                   help='the multiprocessing start method')
    args = p.parse_args()

    mp.set_start_method(args.start_method)
    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    # Load CLIP in eval mode with gradients disabled; clip_tf is the matching
    # image preprocessing transform.
    clip_model, clip_tf = clip.load(args.clip_model, device=device, jit=False)
    clip_model = clip_model.eval().requires_grad_(False)

    dataset = SimulacraDataset(args.images_dir, args.db, transform=clip_tf)
    loader = data.DataLoader(dataset, args.batch_size, num_workers=args.num_workers)

    embeds, ratings = [], []

    for images_batch, ratings_batch in tqdm(loader):
        # Encode each batch, then move the embeddings to the CPU so GPU
        # memory use stays bounded by the batch size.
        embeds.append(clip_model.encode_image(images_batch.to(device)).cpu())
        ratings.append(ratings_batch.clone())

    obj = {'clip_model': args.clip_model,
           'embeds': torch.cat(embeds),
           'ratings': torch.cat(ratings)}

    torch.save(obj, args.output)


if __name__ == '__main__':
    main()
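
# A sketch of consuming the saved file (output name assumed from the example
# invocation above):
#   obj = torch.load('embeddings.pth')
#   print(obj['clip_model'], obj['embeds'].shape, obj['ratings'].shape)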