# koclip/executables/embed_images.py
# Last change: 587ab22 "Model list improvement" (Trent)
import argparse
import csv
import os
import jax.numpy as jnp
from jax import jit
from PIL import Image
from tqdm import tqdm
from config import MODEL_LIST
from utils import load_model
def main(args):
    """Embed every ``.jpg`` in ``args.image_path`` with each model in MODEL_LIST.

    For each ``model_name`` in ``MODEL_LIST``, ``load_model("koclip/<model_name>")``
    supplies a model and a processor. The images are run through them in
    batches of ``args.batch_size``, and one row per image is appended to
    ``<args.out_path>/<model_name>.tsv``. Each row holds the image file name,
    a tab, then the embedding values joined by commas.

    Args:
        args: Namespace with ``image_path`` (input directory), ``out_path``
            (output directory, created if missing) and ``batch_size``
            (positive int).

    Raises:
        ValueError: If ``args.image_path`` contains any entry whose name does
            not end in ``.jpg``.
    """
    root = args.image_path
    # Sorted so row order in the .tsv files is reproducible; os.listdir order
    # is arbitrary.
    files = sorted(os.listdir(root))
    # Explicit raise rather than assert, so the check survives `python -O`.
    unexpected = [name for name in files if not name.endswith(".jpg")]
    if unexpected:
        raise ValueError(f"expected only .jpg files in {root!r}, found {unexpected}")
    os.makedirs(args.out_path, exist_ok=True)

    for model_name in MODEL_LIST:
        model, processor = load_model(f"koclip/{model_name}")
        out_file = os.path.join(args.out_path, f"{model_name}.tsv")
        # Open once per model rather than once per batch. Append mode is kept,
        # so re-running the script adds duplicate rows to existing files.
        with open(out_file, "a", newline="", encoding="utf-8") as out:
            writer = csv.writer(out, delimiter="\t")
            with tqdm(total=len(files)) as pbar:
                for start in range(0, len(files), args.batch_size):
                    image_ids = files[start : start + args.batch_size]
                    images = []
                    for image_id in image_ids:
                        # Close the file handle as soon as the RGB copy exists.
                        with Image.open(os.path.join(root, image_id)) as img:
                            images.append(img.convert("RGB"))
                    try:
                        # text=[""] is a dummy prompt; presumably the model's
                        # forward pass needs text inputs too (TODO confirm).
                        inputs = processor(
                            text=[""],
                            images=images,
                            return_tensors="jax",
                            padding=True,
                        )
                    except Exception as err:
                        # Report the failing batch and stop this model; its
                        # remaining batches are skipped.
                        print(f"preprocessing failed for {image_ids}: {err!r}")
                        break
                    # Moves axis 1 to the end; presumably channels-first
                    # (N, C, H, W) -> channels-last (N, H, W, C) for the Flax
                    # vision tower (TODO confirm against the processor).
                    inputs["pixel_values"] = jnp.transpose(
                        inputs["pixel_values"], axes=[0, 2, 3, 1]
                    )
                    features = model(**inputs).image_embeds
                    writer.writerows(
                        [image_id, ",".join(map(str, feature))]
                        for image_id, feature in zip(image_ids, features)
                    )
                    # Advance by the real batch length so the last, partial
                    # batch does not overshoot the total.
                    pbar.update(len(image_ids))
def parse_args(argv=None):
    """Parse the command-line options for this embedding script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``batch_size`` (int, >= 1), ``image_path``
        and ``out_path`` (str).

    Raises:
        SystemExit: Via ``parser.error`` if ``--batch_size`` is not a
            positive integer.
    """
    parser = argparse.ArgumentParser(
        description="Embed a directory of .jpg images with each configured KoCLIP model."
    )
    # type=int matters: without it a value given on the command line stays a
    # str and breaks range() in main(); only the default happened to be an int.
    parser.add_argument(
        "--batch_size", type=int, default=16, help="images per forward pass"
    )
    parser.add_argument(
        "--image_path", default="images", help="directory containing .jpg images"
    )
    parser.add_argument(
        "--out_path", default="features", help="directory for per-model .tsv files"
    )
    args = parser.parse_args(argv)
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    return args


if __name__ == "__main__":
    main(parse_args())