File size: 2,444 Bytes
e7f1517
1ce3798
e7f1517
 
 
 
 
 
 
 
e296694
e7f1517
 
e296694
e7f1517
8424a77
e7f1517
 
e296694
 
 
 
 
1ce3798
 
 
 
 
 
 
 
 
 
e6e7ab0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7f1517
 
 
 
 
e6e7ab0
1ce3798
 
 
8424a77
 
 
 
e6e7ab0
8424a77
 
 
e6e7ab0
 
 
8424a77
 
 
 
 
 
 
 
 
 
 
e7f1517
 
 
 
 
e6e7ab0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import sys
import numpy as np
import torch
from PIL import Image
import clip
import pickle


def do_batch(batch, embeddings):
    """Encode one batch of preprocessed images and record the features.

    Args:
        batch: sequence of same-shaped preprocessed images; they are
            stacked along a new leading axis before being encoded.
        embeddings: list that is extended in place with one entry per
            image in ``batch`` (the encoded features converted to nested
            Python lists).

    Uses the module-level ``model`` and ``device``.
    """
    stacked = np.stack(batch)
    inputs = torch.tensor(stacked).to(device)
    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        features = model.encode_image(inputs).float()
        embeddings.extend(features.cpu().numpy().tolist())
    # Progress report: running total of encoded images so far.
    print(f"{len(embeddings)} done", flush=True)


# Validate the command line before loading the model, so a bad invocation
# fails immediately. Explicit checks are used instead of `assert` because
# assertions are removed when Python runs with -O.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} OUTPUT.pkl thumbs|no-thumbs  (image filenames on stdin)")
output_filename = sys.argv[1]
if not output_filename.endswith("pkl"):
    sys.exit("first argument is the output pickle")
if sys.argv[2] not in ("thumbs", "no-thumbs"):
    sys.exit("second argument either thumbs or no-thumbs")
do_thumbs = sys.argv[2] == "thumbs"

# Use cuda when it is available, even though it is hardly worth bothering
# with: 98% of the run time is preprocessing on the cpu.
device = "cuda" if torch.cuda.is_available() else "cpu"

model, preprocess = clip.load('RN50', device=device)

# Stop after this many images have been accepted (effectively unlimited).
limit = 1e9
# Number of images encoded per model call.
batch_size = 100


def save(output_filename, embeddings, filenames):
    """Pickle the results collected so far to ``output_filename``.

    The pickle holds a dict with keys ``"embeddings"`` (a numpy array built
    from ``embeddings``) and ``"filenames"``; when the module-level
    ``do_thumbs`` flag is set it also holds ``"thumbs"`` (the module-level
    ``thumbs`` list).

    The data is written to a temporary sibling file and then moved over the
    target with ``os.replace``, so an interruption while writing never
    destroys a previously saved file.

    Args:
        output_filename: path of the pickle to write.
        embeddings: one embedding per processed image.
        filenames: image paths, parallel to ``embeddings``.

    Raises:
        AssertionError: if the parallel lists differ in length.
        OSError: if the file cannot be written or moved into place.
    """
    embeddings = np.array(embeddings)
    assert len(embeddings) == len(filenames), "embeddings/filenames length mismatch"
    print(f"processed {len(embeddings)} images")

    data = {"embeddings": embeddings, "filenames": filenames}
    if do_thumbs:
        assert len(embeddings) == len(thumbs), "embeddings/thumbs length mismatch"
        data["thumbs"] = thumbs

    tmp_filename = f"{output_filename}.tmp"
    try:
        with open(tmp_filename, "wb") as f:
            pickle.dump(data, f)
        os.replace(tmp_filename, output_filename)
    finally:
        # After a successful replace the temp file is gone; on failure,
        # remove the partial file so it does not linger next to the output.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)


# Parallel result lists: entry i of each one describes the same image.
embeddings = []
filenames = []
thumbs = []
print("starting processing")
batch = []  # preprocessed images waiting to be encoded
batch_count = 0
for filename in sys.stdin:
    filename = filename.rstrip()
    # Only JPEG files are processed; every other input line is ignored.
    if not filename.lower().endswith(("jpg", "jpeg")):
        continue

    # Do everything that can fail for this image *before* touching any of the
    # parallel lists, so an unreadable file can never leave them misaligned.
    try:
        with Image.open(filename) as source:
            rgb = source.convert("RGB")
        img = preprocess(rgb)
        thumb = None
        if do_thumbs:
            # thumbnail() shrinks in place, so it must run after preprocess().
            rgb.thumbnail((128, 128))
            thumb = np.array(rgb)
    except Exception as exc:
        # Best effort: report the bad file and move on. KeyboardInterrupt and
        # SystemExit are not Exception subclasses, so they still propagate.
        print(f"ERROR, skipping {filename}: {exc!r}")
        sys.stdout.flush()
        continue

    batch.append(img)
    filenames.append(filename)
    if do_thumbs:
        thumbs.append(thumb)

    # Encoding and checkpointing sit outside the try block on purpose: a
    # failure there is not a per-file problem and must not be reported (and
    # swallowed) as a skipped image.
    if len(batch) >= batch_size:
        do_batch(batch, embeddings)
        batch = []
        batch_count += 1
        # Periodic checkpoint. At this point the batch has just been flushed
        # and the current image's thumbnail is already recorded, so all
        # result lists have the same length.
        if batch_count % 200 == 0:
            save(output_filename, embeddings, filenames)
    if len(filenames) >= limit:
        break

# Encode whatever is left in the final, partially filled batch.
if batch:
    do_batch(batch, embeddings)

save(output_filename, embeddings, filenames)