#!/usr/bin/env python3
"""Precomputes CLIP embeddings for Simulacra Aesthetic Captions."""
import argparse
from pathlib import Path
import sqlite3

from PIL import Image
import torch
from torch import multiprocessing as mp
from torch.utils import data
from tqdm import tqdm

from CLIP import clip


class SimulacraDataset(data.Dataset):
"""Simulacra dataset
Args:
images_dir: directory
transform: preprocessing and augmentation of the training images
"""
def __init__(self, images_dir, db, transform=None):
self.images_dir = Path(images_dir)
self.transform = transform
self.conn = sqlite3.connect(db)
self.ratings = []
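        # One row per image: its generation id, index, file path, and the
        # mean of all ratings it received.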
        query = """
            SELECT generations.id, images.idx, paths.path, AVG(ratings.rating)
            FROM images
            JOIN generations ON images.gid = generations.id
            JOIN ratings ON images.id = ratings.iid
            JOIN paths ON images.id = paths.iid
            GROUP BY images.id
        """
        for row in self.conn.execute(query):
            self.ratings.append(row)

    def __len__(self):
        return len(self.ratings)

    def __getitem__(self, key):
gid, idx, filename, rating = self.ratings[key]
image = Image.open(self.images_dir / filename).convert('RGB')
if self.transform:
image = self.transform(image)
return image, torch.tensor(rating)


def main():
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument('--batch-size', '-bs', type=int, default=10,
                   help='the batch size')
p.add_argument('--clip-model', type=str, default='ViT-B/16',
help='the CLIP model')
p.add_argument('--db', type=str, required=True,
help='the database location')
p.add_argument('--device', type=str,
help='the device to use')
p.add_argument('--images-dir', type=str, required=True,
help='the dataset images directory')
p.add_argument('--num-workers', type=int, default=8,
help='the number of data loader workers')
p.add_argument('--output', type=str, required=True,
help='the output file')
p.add_argument('--start-method', type=str, default='spawn',
choices=['fork', 'forkserver', 'spawn'],
help='the multiprocessing start method')
args = p.parse_args()
mp.set_start_method(args.start_method)
    if args.device:
        device = torch.device(args.device)
else:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
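    # Load the CLIP model and its matching image preprocessing transform,
    # then freeze the model for inference.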
clip_model, clip_tf = clip.load(args.clip_model, device=device, jit=False)
clip_model = clip_model.eval().requires_grad_(False)
dataset = SimulacraDataset(args.images_dir, args.db, transform=clip_tf)
loader = data.DataLoader(dataset, args.batch_size, num_workers=args.num_workers)
embeds, ratings = [], []
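    # Encode each batch of images with CLIP, moving the embeddings to the
    # CPU as they are produced to bound GPU memory use.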
for batch in tqdm(loader):
images_batch, ratings_batch = batch
embeds.append(clip_model.encode_image(images_batch.to(device)).cpu())
ratings.append(ratings_batch.clone())
obj = {'clip_model': args.clip_model,
'embeds': torch.cat(embeds),
'ratings': torch.cat(ratings)}
torch.save(obj, args.output)


if __name__ == '__main__':
main()
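
# A minimal sketch of reloading the saved embeddings for downstream use. The
# 'embeddings.pt' filename is illustrative; it is whatever --output was set to.
#
#   obj = torch.load('embeddings.pt')
#   embeds, ratings = obj['embeds'], obj['ratings']  # [N, D] embeddings, [N] mean ratings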