import argparse
import csv
import os
import sys

import jax.numpy as jnp
import numpy as np
from jax import jit
from PIL import Image
from tqdm import tqdm

from utils import load_model


def main(args):
    """Extract KoCLIP image embeddings for every ``.jpg`` file in a directory.

    For each of the ``koclip-base`` and ``koclip-large`` checkpoints, the images
    in ``args.image_path`` are fed to the model in batches. One row per image
    is appended to ``<args.out_path>/<model_name>.tsv``: the file name, a tab,
    then the embedding values joined with commas.

    Args:
        args: Namespace with ``image_path`` (input directory), ``out_path``
            (output directory, created if missing) and ``batch_size``
            (images per forward pass; strings such as ``"32"`` are accepted).

    Raises:
        ValueError: If ``batch_size`` is not a positive integer, or if the
            input directory contains an entry whose name does not end in
            ``.jpg``.
    """
    root = args.image_path
    # int() also accepts the raw string argparse yields without type=int.
    batch_size = int(args.batch_size)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Sorted so the processing and output order is reproducible.
    files = sorted(os.listdir(root))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    not_jpg = [name for name in files if not name.endswith(".jpg")]
    if not_jpg:
        raise ValueError(f"non-.jpg entries in {root!r}: {not_jpg}")

    os.makedirs(args.out_path, exist_ok=True)

    for model_name in ("koclip-base", "koclip-large"):
        model, processor = load_model(f"koclip/{model_name}")
        out_file = os.path.join(args.out_path, f"{model_name}.tsv")
        # Open the output once per model instead of once per batch. Append
        # mode is kept, so re-running the script adds duplicate rows.
        # newline="" is what the csv module requires for files it writes.
        with open(out_file, "a", encoding="utf-8", newline="") as fh, tqdm(
            total=len(files), desc=model_name
        ) as pbar:
            writer = csv.writer(fh, delimiter="\t")
            for start in range(0, len(files), batch_size):
                image_ids = files[start : start + batch_size]
                images = []
                for file_ in image_ids:
                    # Close the source file; convert() returns a new image.
                    with Image.open(os.path.join(root, file_)) as img:
                        images.append(img.convert("RGB"))

                try:
                    # Only image embeddings are used below; a dummy text input
                    # is passed, presumably because the processor/model call
                    # expects text as well -- TODO confirm.
                    inputs = processor(
                        text=[""], images=images, return_tensors="jax", padding=True
                    )
                except Exception as err:
                    # Report the failing batch and stop this model (as before),
                    # but no longer swallow KeyboardInterrupt/SystemExit.
                    print(
                        f"[{model_name}] preprocessing failed for {image_ids}: {err!r}",
                        file=sys.stderr,
                    )
                    break

                # Reorder axes to (0, 2, 3, 1); presumably NCHW from the
                # processor to the channels-last layout the Flax model
                # expects -- TODO confirm.
                inputs["pixel_values"] = jnp.transpose(
                    inputs["pixel_values"], axes=[0, 2, 3, 1]
                )
                # One host transfer per batch instead of one per element;
                # str() of the resulting scalars matches str() of the
                # 0-d JAX arrays the original iterated over.
                features = np.asarray(model(**inputs).image_embeds)
                writer.writerows(
                    [image_id, ",".join(map(str, feature))]
                    for image_id, feature in zip(image_ids, features)
                )
                # Advance by the real batch size; the last batch may be short.
                pbar.update(len(image_ids))


def build_arg_parser():
    """Return the command-line parser for the feature-extraction script.

    ``--batch_size`` is parsed as ``int``; without ``type=int`` any value given
    on the command line would arrive as a string.
    """
    parser = argparse.ArgumentParser(
        description="Extract KoCLIP image embeddings into per-model TSV files."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="number of images per forward pass (default: 16)",
    )
    parser.add_argument(
        "--image_path",
        default="images",
        help="directory containing the .jpg images (default: images)",
    )
    parser.add_argument(
        "--out_path",
        default="features",
        help="directory for the <model_name>.tsv outputs (default: features)",
    )
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())