Spaces:
Sleeping
Sleeping
File size: 2,444 Bytes
e7f1517 1ce3798 e7f1517 e296694 e7f1517 e296694 e7f1517 8424a77 e7f1517 e296694 1ce3798 e6e7ab0 e7f1517 e6e7ab0 1ce3798 8424a77 e6e7ab0 8424a77 e6e7ab0 8424a77 e7f1517 e6e7ab0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import os
import sys
import numpy as np
import torch
from PIL import Image
import clip
import pickle
def do_batch(batch, embeddings):
    """Embed one batch of preprocessed images and append the results.

    Uses the module-level CLIP ``model`` and the module-level ``device``.

    Args:
        batch: list of preprocessed images as produced by ``preprocess``
            (presumably equally-shaped (3, H, W) tensors — TODO confirm).
        embeddings: list extended in place with one embedding (a list of
            floats) per image in ``batch``, in the same order.
    """
    # Stack directly in torch instead of round-tripping through numpy, which
    # made an extra host-side copy; as_tensor is a no-op for tensors and
    # keeps array inputs working.
    image_batch = torch.stack([torch.as_tensor(img) for img in batch]).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_batch).float()
    embeddings.extend(image_features.cpu().tolist())
    print(f"{len(embeddings)} done")
    sys.stdout.flush()  # progress is often piped/logged; show it immediately
limit = 1e9  # stop after this many successfully read images
batch_size = 100  # images per forward pass through the model

# Validate the command line before the (slow) model load so that a usage
# mistake fails immediately. Explicit checks instead of `assert`, which is
# stripped under `python -O` (a typo like "thumb" would then silently
# disable thumbnails).
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} OUTPUT.pkl thumbs|no-thumbs < image_paths.txt")
output_filename = sys.argv[1]
if not output_filename.endswith("pkl"):
    sys.exit("first argument is the output pickle")
if sys.argv[2] not in ("thumbs", "no-thumbs"):
    sys.exit("second argument either thumbs or no-thumbs")
do_thumbs = sys.argv[2] == "thumbs"

# Use CUDA when available, even though it barely matters here:
# about 98% of the run time is image preprocessing on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('RN50', device=device)
def save(output_filename, embeddings, filenames):
    """Pickle the embeddings (and thumbnails, if enabled) to ``output_filename``.

    The pickle holds a dict with keys ``"embeddings"`` (numpy array built from
    ``embeddings``) and ``"filenames"``; when the module-level ``do_thumbs`` is
    true it also holds ``"thumbs"`` (the module-level ``thumbs`` list).

    The data is first written to ``output_filename + ".tmp"`` and then moved
    over the target with ``os.replace``. This function is also used for
    periodic checkpoints, so an interrupted write must not destroy the
    previous complete pickle.

    Raises:
        ValueError: if the numbers of embeddings, filenames and (when
            enabled) thumbnails disagree.
    """
    embeddings = np.array(embeddings)
    # Explicit checks rather than `assert`, which `python -O` strips.
    if len(embeddings) != len(filenames):
        raise ValueError(
            f"{len(embeddings)} embeddings but {len(filenames)} filenames"
        )
    print(f"processed {len(embeddings)} images")
    data = {"embeddings": embeddings, "filenames": filenames}
    if do_thumbs:
        if len(embeddings) != len(thumbs):
            raise ValueError(
                f"{len(embeddings)} embeddings but {len(thumbs)} thumbs"
            )
        data["thumbs"] = thumbs
    tmp_filename = output_filename + ".tmp"
    with open(tmp_filename, "wb") as f:
        pickle.dump(data, f)
    os.replace(tmp_filename, output_filename)  # atomic on the same filesystem
# Main loop: read one image path per line from stdin, embed the JPEGs in
# batches of `batch_size`, checkpoint every 200 batches, and save at the end.
embeddings = []
filenames = []
thumbs = []
print("starting processing")
batch = []
batch_count = 0
for filename in sys.stdin:
    filename = filename.rstrip()
    if not filename.lower().endswith(("jpg", "jpeg")):
        continue
    # Only per-image work may fail and be skipped. Everything is computed
    # before any list is touched, so a bad image cannot leave `batch`,
    # `filenames` and `thumbs` out of step.
    try:
        with Image.open(filename) as im:
            rgb = im.convert("RGB")
        img = preprocess(rgb)
        if do_thumbs:
            # thumbnail() resizes in place, so it must run after preprocess
            # has consumed the full-size image.
            rgb.thumbnail((128, 128))
            thumb = np.array(rgb)
    except Exception:  # unreadable/corrupt image; KeyboardInterrupt still propagates
        print(f"ERROR, skipping {filename}")
        sys.stdout.flush()
        continue
    batch.append(img)
    filenames.append(filename)
    if do_thumbs:
        thumbs.append(thumb)
    # Batching and checkpointing run outside the try. Their errors are not
    # attributable to one file, and swallowing them would corrupt the
    # alignment that save() checks. The thumbnail is already appended here,
    # so the checkpoint sees consistent lengths.
    if len(batch) >= batch_size:
        do_batch(batch, embeddings)
        batch = []
        batch_count += 1
        if batch_count % 200 == 0:
            save(output_filename, embeddings, filenames)
    if len(filenames) >= limit:
        break
# remaining partial batch
if batch:
    do_batch(batch, embeddings)
save(output_filename, embeddings, filenames)
|