# se / create_embeddings.py
# Daniel Varga
# PhotoLibrary. create_embeddings.py refactor, intermediate save.
# e6e7ab0
import os
import sys
import numpy as np
import torch
from PIL import Image
import clip
import pickle
def do_batch(batch, embeddings):
    """Encode a batch of preprocessed images with CLIP and collect the results.

    Args:
        batch: list of preprocessed images, as produced by the ``preprocess``
            transform returned from ``clip.load`` (all of the same shape).
        embeddings: list extended in place with one embedding (a list of
            Python floats) per image in ``batch``.

    Uses the module-level ``model`` and ``device``. Prints the running total
    of encoded images. An empty ``batch`` is a no-op.
    """
    if not batch:
        # torch.stack cannot build a tensor from an empty list.
        return
    # Stack the tensors directly instead of round-tripping through NumPy,
    # which copied every image twice.
    image_batch = torch.stack(batch).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_batch).float()
    embeddings += image_features.cpu().numpy().tolist()
    print(f"{len(embeddings)} done")
    sys.stdout.flush()
# Validate the command line before loading the model, so a usage mistake is
# reported immediately instead of after CLIP has been loaded. Explicit checks
# (not assert) so they still run under `python -O`.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} OUTPUT.pkl thumbs|no-thumbs < filenames.txt")
output_filename = sys.argv[1]
if not output_filename.endswith("pkl"):
    sys.exit("first argument is the output pickle")
if sys.argv[2] not in ("thumbs", "no-thumbs"):
    sys.exit("second argument either thumbs or no-thumbs")
do_thumbs = sys.argv[2] == "thumbs"

# Use CUDA when it is available, even though it barely matters here:
# about 98% of the run time is image preprocessing on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('RN50', device=device)

# Stop after this many images (effectively unlimited).
limit = 10**9
# Number of images encoded per model call.
batch_size = 100
def save(output_filename, embeddings, filenames):
    """Pickle the embeddings collected so far to ``output_filename``.

    The pickled object is a dict with keys:
      - "embeddings": ``np.array(embeddings)``
      - "filenames": ``filenames``
      - "thumbs": the module-level ``thumbs`` list (only when the
        module-level ``do_thumbs`` is set)

    Asserts that every list has one entry per image.

    The pickle is written to a temporary file next to the target and then
    moved into place with ``os.replace``. An interruption during a save
    therefore leaves the previous complete file intact instead of a
    truncated one.
    """
    embeddings = np.array(embeddings)
    assert len(embeddings) == len(filenames)
    print(f"processed {len(embeddings)} images")
    data = {"embeddings": embeddings, "filenames": filenames}
    if do_thumbs:
        assert len(embeddings) == len(thumbs)
        data["thumbs"] = thumbs
    # Same directory as the target, so os.replace is a rename on one filesystem.
    tmp_filename = output_filename + ".tmp"
    with open(tmp_filename, "wb") as f:
        pickle.dump(data, f)
    os.replace(tmp_filename, output_filename)
# Accumulators read by save(). The i-th embedding, filename and (in thumbs
# mode) thumbnail all describe the same image, so the lists must stay in step.
embeddings = []
filenames = []
thumbs = []
print("starting processing")
batch = []  # preprocessed images not yet encoded
batch_count = 0
# Image paths arrive one per line on stdin; only .jpg/.jpeg files are used.
for filename in sys.stdin:
    filename = filename.rstrip()
    if not filename.lower().endswith(("jpg", "jpeg")):
        continue
    # Only the per-image decode/preprocess/thumbnail work may fail and be
    # skipped. The shared accumulators are touched only after it succeeds, so
    # a bad image can never leave them with different lengths.
    try:
        with Image.open(filename) as im:
            rgb = im.convert("RGB")
        img = preprocess(rgb)
        thumb = None
        if do_thumbs:
            # thumbnail() shrinks rgb in place, so it must run after preprocess.
            rgb.thumbnail((128, 128))
            thumb = np.array(rgb)
    except Exception as e:  # KeyboardInterrupt is not an Exception, so it propagates
        print(f"ERROR, skipping {filename}")
        print(f"  {type(e).__name__}: {e}", file=sys.stderr)
        sys.stdout.flush()
        continue
    batch.append(img)
    filenames.append(filename)
    if do_thumbs:
        # Must be appended before the periodic save() below, which checks
        # len(thumbs) against len(embeddings).
        thumbs.append(thumb)
    if len(batch) >= batch_size:
        do_batch(batch, embeddings)
        batch = []
        batch_count += 1
        # Periodic checkpoint, every 200 batches.
        if batch_count % 200 == 0:
            save(output_filename, embeddings, filenames)
    if len(filenames) >= limit:
        break
# Encode the final, partially filled batch and write the complete result.
if len(batch) > 0:
    do_batch(batch, embeddings)
save(output_filename, embeddings, filenames)