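# NOTE(review): "Spaces: Sleeping / Sleeping" appears to be page-extraction
# residue (a hosting-site status banner), not part of the script itself.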
import os | |
import sys | |
import numpy as np | |
import torch | |
from PIL import Image | |
import clip | |
import pickle | |
def do_batch(batch, embeddings):
    """Run the model on one batch of preprocessed images.

    Args:
        batch: Sequence of preprocessed images; they are combined with
            ``np.stack`` and converted to a tensor on ``device``.
        embeddings: List that is extended in place with the rows of the
            model output (converted to nested Python lists).

    Relies on the module-level ``model`` and ``device``.  Prints the running
    total of embeddings and flushes stdout so progress shows up immediately.
    """
    stacked = np.stack(batch)
    pixels = torch.tensor(stacked).to(device)
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        features = model.encode_image(pixels).float()
    embeddings.extend(features.cpu().numpy().tolist())
    print(f"{len(embeddings)} done", flush=True)
# --- command line -----------------------------------------------------------
# Validate the arguments before the model is loaded so that a usage mistake
# fails immediately with a readable message.  Explicit checks are used rather
# than ``assert`` because assertions are stripped under ``python -O``.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} OUTPUT.pkl (thumbs|no-thumbs) < filenames")
output_filename = sys.argv[1]
if not output_filename.endswith("pkl"):
    sys.exit("first argument is the output pickle")
if sys.argv[2] not in ("thumbs", "no-thumbs"):
    sys.exit("second argument either thumbs or no-thumbs")
# Whether a 128x128-bounded thumbnail is stored alongside each embedding.
do_thumbs = sys.argv[2] == "thumbs"

# --- tunables ---------------------------------------------------------------
# Stop after this many accepted images (large enough to be effectively "all").
limit = 1e9
# Number of preprocessed images handed to the model per call.
batch_size = 100

# --- model ------------------------------------------------------------------
# Use CUDA when it is available, even though it's hardly worth bothering with,
# because 98% of the run time is preprocessing on the cpu.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('RN50', device=device)
def save(output_filename, embeddings, filenames):
    """Pickle the results collected so far to ``output_filename``.

    The pickle holds a dict with keys ``"embeddings"`` (a NumPy array built
    from ``embeddings``), ``"filenames"`` and, when the module-level
    ``do_thumbs`` flag is set, ``"thumbs"`` (the module-level ``thumbs`` list).

    Args:
        output_filename: Path of the pickle file to (over)write.
        embeddings: Sequence of embedding rows, one per entry in ``filenames``.
        filenames: Image paths, in the same order as ``embeddings``.

    Raises:
        AssertionError: If the embeddings, filenames (and thumbnails, when
            enabled) do not all have the same length.
        OSError: If the file cannot be written or renamed.
    """
    embeddings = np.array(embeddings)
    assert len(embeddings) == len(filenames)
    print(f"processed {len(embeddings)} images")
    data = {"embeddings": embeddings, "filenames": filenames}
    if do_thumbs:
        assert len(embeddings) == len(thumbs)
        data["thumbs"] = thumbs
    # This function is also used for periodic checkpoints that overwrite the
    # same file.  Write to a sibling temp file and rename it over the target
    # so an interruption mid-write never destroys the previous good pickle.
    tmp_filename = f"{output_filename}.tmp"
    with open(tmp_filename, "wb") as f:
        pickle.dump(data, f)
    os.replace(tmp_filename, output_filename)
# Parallel result lists: index i of each refers to the same image.
embeddings = []
filenames = []
thumbs = []
print("starting processing")
batch = []        # preprocessed images waiting to be sent to the model
batch_count = 0   # number of batches run so far; drives checkpointing

# One image path per line on stdin; anything not ending in jpg/jpeg is ignored.
for filename in sys.stdin:
    filename = filename.rstrip()
    if not filename.lower().endswith(("jpg", "jpeg")):
        continue

    # Do ALL fallible per-image work before touching the shared lists, so a
    # bad file can never leave filenames / batch / thumbs out of step.
    try:
        # Close the file handle promptly; convert() returns an independent
        # in-memory image, so it stays usable after the file is closed.
        with Image.open(filename) as opened:
            rgb = opened.convert("RGB")
        img = preprocess(rgb)
        thumb = None
        if do_thumbs:
            rgb.thumbnail((128, 128))  # shrinks rgb in place
            thumb = np.array(rgb)
    except Exception as err:
        # Best effort: unreadable/corrupt images are reported and skipped.
        # KeyboardInterrupt/SystemExit are not Exception subclasses and so
        # still propagate.
        print(f"ERROR, skipping {filename}: {err}")
        sys.stdout.flush()
        continue

    batch.append(img)
    filenames.append(filename)
    # The thumbnail must be recorded BEFORE any checkpoint below, otherwise
    # save() sees one thumbnail fewer than embeddings and its length check
    # fails.
    if do_thumbs:
        thumbs.append(thumb)

    # Model and checkpoint failures are deliberately not swallowed: they are
    # not per-file problems and would otherwise be misreported as a skip.
    if len(batch) >= batch_size:
        do_batch(batch, embeddings)
        batch = []
        batch_count += 1
        # Periodic checkpoint every 200 batches.
        if batch_count % 200 == 0:
            save(output_filename, embeddings, filenames)

    if len(filenames) >= limit:
        break

# Flush the final partial batch, then write the complete result.
if batch:
    do_batch(batch, embeddings)
save(output_filename, embeddings, filenames)