Spaces:
Build error
Build error
import argparse | |
import csv | |
import os | |
from PIL import Image | |
from utils import load_model | |
def _batched(items, size):
    """Yield consecutive slices of *items* of length *size*; the last one may be shorter."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


def _load_rgb(path):
    """Open the image at *path*, fully decode it as RGB, and close the file handle."""
    with Image.open(path) as img:
        # Image.open is lazy and keeps the file open; convert() decodes into a new
        # in-memory image, so the handle can be released here instead of leaking
        # one descriptor per image. RGB also normalizes grayscale/RGBA inputs to
        # the 3-channel layout the processor is presumably configured for.
        return img.convert("RGB")


def main(args):
    """Embed every file in ``args.image_path`` with each KoCLIP model and save TSVs.

    For each model, writes ``<args.out_path>/<model_name>.tsv`` containing one
    tab-separated row per image: the file name, then the image embedding as a
    comma-separated list of floats.

    Args:
        args: namespace with ``image_path`` (directory of images), ``out_path``
            (output directory; created if missing) and ``batch_size`` (positive
            integer; a numeric string is also accepted).

    Raises:
        ValueError: if ``batch_size`` is not a positive integer.
    """
    root = args.image_path
    batch_size = int(args.batch_size)
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Only regular files are embedded (subdirectories would crash Image.open);
    # sorting makes the row order deterministic across filesystems.
    files = sorted(
        name for name in os.listdir(root) if os.path.isfile(os.path.join(root, name))
    )
    os.makedirs(args.out_path, exist_ok=True)

    # Bare names, each prefixed with the "koclip/" org below. The original second
    # entry already carried the prefix ("koclip/koclip/koclip-large") and would
    # also have written into a nested "koclip/" output directory.
    # NOTE(review): confirm these match the model ids accepted by utils.load_model.
    for model_name in ("koclip", "koclip-large"):
        model, processor = load_model(f"koclip/{model_name}")
        out_file = os.path.join(args.out_path, f"{model_name}.tsv")
        # Open once per model: reopening in a truncating mode per batch would keep
        # only the last batch. newline="" is what the csv module requires.
        with open(out_file, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, delimiter="\t")
            for batch_ids in _batched(files, batch_size):
                images = [_load_rgb(os.path.join(root, name)) for name in batch_ids]
                # A single empty prompt is passed as text, presumably because the
                # joint model's forward expects text inputs; only the image
                # embeddings are read from the output.
                inputs = processor(
                    text=[""], images=images, return_tensors="jax", padding=True
                )
                features = model(**inputs).image_embeds
                for image_id, feature in zip(batch_ids, features):
                    writer.writerow(
                        [image_id, ",".join(map(str, feature.tolist()))]
                    )
def parse_args(argv=None):
    """Parse command-line options for feature generation.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``batch_size`` (int), ``image_path`` and
        ``out_path`` (str).
    """
    parser = argparse.ArgumentParser(
        description="Dump KoCLIP image embeddings for a directory of images."
    )
    # type=int is required: without it any CLI-supplied value arrives as a str
    # and the batching arithmetic fails.
    parser.add_argument("--batch_size", type=int, default=16,
                        help="number of images embedded per forward pass")
    parser.add_argument("--image_path", default="images",
                        help="directory containing the input images")
    parser.add_argument("--out_path", default="features",
                        help="directory where the per-model TSV files are written")
    args = parser.parse_args(argv)
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    return args


if __name__ == "__main__":
    main(parse_args())