|
import json |
|
from clip_for_image_classification import FlaxCLIPForImageClassification |
|
from PIL import Image |
|
import jax |
|
import numpy as np |
|
from transformers import CLIPImageProcessor |
|
import os |
|
|
|
# Per the JAX GPU memory documentation, "false" disables XLA's up-front
# preallocation of accelerator memory, so memory is allocated on demand.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Classification model and the image preprocessor it is paired with here.
model = FlaxCLIPForImageClassification.from_pretrained(
    "Thouph/clip-vit-l-224-patch14-datacomp-image-classification"
)
image_processor = CLIPImageProcessor.from_pretrained(
    "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
)

# Replace the placeholder path with the image to classify.
image = Image.open("/your/image/here.jpg")
inputs = image_processor(images=image, return_tensors="jax")

# Forward pass; the logits are squashed element-wise with a sigmoid.
outputs = model(**inputs)
logits = outputs.logits

# Convert to NumPy and copy so the result is a standalone array that later
# NumPy code owns.
probabilities = np.asarray(jax.nn.sigmoid(logits)).copy()
|
|
|
|
|
def topk_by_sort(input, k, axis=None, ascending=False):
    """Return the indices and values of the ``k`` extreme entries of ``input``.

    The ranking is produced by a full ``np.argsort`` whose first ``k``
    positions along ``axis`` are kept.

    Args:
        input: Array-like of numeric values. It is never modified, so
            read-only arrays and plain sequences are accepted. (The name
            shadows the builtin but is kept for keyword-argument callers.)
        k: Number of entries to keep along ``axis``.
        axis: Axis to rank along. ``None`` (the default) ranks the flattened
            array, so the returned indices are flat indices.
        ascending: If true, select the ``k`` smallest entries in increasing
            order; otherwise select the ``k`` largest in decreasing order.

    Returns:
        A tuple ``(ind, val)`` where ``ind`` holds the selected indices and
        ``val`` the corresponding values taken from ``input``.

    Raises:
        IndexError: If ``k`` exceeds the number of entries along ``axis``.
    """
    data = np.asarray(input)

    # Rank on a negated temporary for descending order instead of flipping
    # the sign of the caller's array in place: in-place negation fails on
    # read-only arrays and corrupts non-ndarray inputs such as lists.
    keys = data if ascending else -data

    order = np.argsort(keys, axis=axis)
    ind = np.take(order, np.arange(k), axis=axis)
    val = np.take_along_axis(data, ind, axis=axis)
    return ind, val
|
|
|
|
|
# Keep the 100 highest-probability entries (topk_by_sort defaults to
# descending order and, with axis=None, returns flat indices).
indices, values = topk_by_sort(probabilities, 100)

# JSON is UTF-8 by specification; pass the encoding explicitly so tag names
# are decoded the same way regardless of the platform's default locale.
with open("7748tags.json", "r", encoding="utf-8") as file:
    allowed_tags = json.load(file)

# strict=True makes a length mismatch between indices and values an error
# rather than silently truncating the output.
for index, value in zip(indices, values, strict=True):
    print(allowed_tags[index], value)
|
|