|
import torch |
|
import numpy as np |
|
import pandas as pd |
|
import gradio as gr |
|
from io import BytesIO |
|
from PIL import Image as PILIMAGE |
|
from IPython.display import Image |
|
from IPython.core.display import HTML |
|
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer |
|
import os |
|
|
|
|
|
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Classifier weights come from the "vesteinn/clip-nabirds" checkpoint (presumably
# CLIP fine-tuned on NABirds); image preprocessing and tokenization are taken
# from the base "openai/clip-vit-base-patch32" checkpoint.
_BASE_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained("vesteinn/clip-nabirds").to(device)
processor = CLIPProcessor.from_pretrained(_BASE_CLIP_CHECKPOINT)
tokenizer = CLIPTokenizer.from_pretrained(_BASE_CLIP_CHECKPOINT)
|
|
|
|
|
def load_class_names(dataset_path=''):
    """Read the class-id -> class-name mapping from ``classes.txt``.

    Each non-blank line of ``<dataset_path>/classes.txt`` is split on
    whitespace: the first token is the class id and the remaining tokens,
    re-joined with single spaces, form the class name.

    Args:
        dataset_path: Directory containing ``classes.txt``. The default ``''``
            resolves the file relative to the current working directory.

    Returns:
        dict[str, str]: Class id (kept as a string) mapped to its name, in the
        order the ids appear in the file.
    """
    names = {}
    with open(os.path.join(dataset_path, 'classes.txt'), encoding='utf-8') as f:
        for line in f:
            pieces = line.split()
            # Skip blank lines (e.g. a trailing newline at end of file) rather
            # than failing with IndexError on pieces[0].
            if not pieces:
                continue
            class_id, *name_parts = pieces
            names[class_id] = ' '.join(name_parts)
    return names
|
|
|
|
|
def get_labels(dataset_path="."):
    """Build one CLIP text prompt per class listed in ``classes.txt``.

    Args:
        dataset_path: Directory containing ``classes.txt``; defaults to the
            current directory, which is what this function always used before.

    Returns:
        list[str]: Prompts of the form ``"This is a photo of <name>."``, in the
        same order as the classes appear in the file, so list indices line up
        with rows of any feature matrix built from this list.
    """
    class_names = load_class_names(dataset_path)
    # Only the names are needed; the dict preserves file order.
    return [f"This is a photo of {name}." for name in class_names.values()]
|
|
|
|
|
def encode_text(text):
    """Embed a single text prompt with the CLIP text tower.

    Args:
        text: The prompt string to encode.

    Returns:
        numpy.ndarray: Un-normalised text features for the prompt (one row,
        since a single prompt is tokenized), as a CPU array.
    """
    with torch.no_grad():
        inputs = tokenizer([text], padding=True, return_tensors="pt")
        # The tokenizer produces CPU tensors; move them next to the model so
        # this also works when the model was placed on CUDA.
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        text_features = model.get_text_features(**inputs)
    # Tensor.numpy() only works on CPU tensors, so copy back first.
    return text_features.cpu().numpy()
|
|
|
|
|
ALL_LABELS = get_labels()

# Text embeddings for every prompt are cached on disk so start-up does not
# re-run the text encoder. np.save is given an open file handle so the file
# name is used verbatim (a bare path would get ".npy" appended and never match
# the np.load path below).
_LABEL_CACHE_PATH = "label_features.np"


def _compute_label_features(labels):
    """Encode every prompt in ``labels`` and stack the rows into one matrix."""
    return np.vstack([encode_text(label) for label in labels])


try:
    LABEL_FEATURES = np.load(_LABEL_CACHE_PATH)
except (OSError, ValueError, EOFError):
    # Missing, unreadable, or corrupt cache: fall through and recompute.
    LABEL_FEATURES = None

# Recompute when there is no cache or when it no longer matches the class list
# (e.g. classes.txt changed), otherwise indices would map to the wrong labels.
if (
    LABEL_FEATURES is None
    or LABEL_FEATURES.ndim != 2
    or LABEL_FEATURES.shape[0] != len(ALL_LABELS)
):
    LABEL_FEATURES = _compute_label_features(ALL_LABELS)
    with open(_LABEL_CACHE_PATH, "wb") as cache_file:
        np.save(cache_file, LABEL_FEATURES)

# encode_image returns unit-norm vectors; normalise the text side too so the
# dot product used for ranking is a true cosine similarity. Idempotent, so it
# is safe for caches written either before or after this change.
LABEL_FEATURES = LABEL_FEATURES / np.linalg.norm(LABEL_FEATURES, axis=-1, keepdims=True)
|
|
|
|
|
def encode_image(image):
    """Embed an image with the CLIP vision tower and L2-normalise the result.

    Args:
        image: Array-like image as handed over by the Gradio ``Image`` input,
            presumably an ``(H, W, 3)`` RGB array. ``(H, W)`` greyscale and
            ``(H, W, 4)`` RGBA arrays are converted to RGB as well.

    Returns:
        numpy.ndarray: CPU feature array with one row per input image (a single
        row here), each row scaled to unit L2 norm.
    """
    # Let Pillow infer the mode from the array shape (passing ``mode`` to
    # fromarray is deprecated), then force 3-channel RGB for the processor.
    pil_image = PILIMAGE.fromarray(np.asarray(image).astype('uint8')).convert('RGB')

    with torch.no_grad():
        pixel_values = processor(
            text=None, images=pil_image, return_tensors="pt", padding=True
        )["pixel_values"]
        features = model.get_image_features(pixel_values.to(device))
        # Unit-normalise so a dot product with normalised text features is a
        # cosine similarity.
        features = features / features.norm(dim=-1, keepdim=True)

    return features.cpu().numpy()
|
|
|
|
|
def similarity(feature, label_features):
    """Score a single query feature row against every label feature row.

    Args:
        feature: Query matrix whose leading axis must have size 1.
        label_features: Matrix with one row per label.

    Returns:
        list: One dot-product score per label row, in row order.
    """
    scores = np.matmul(feature, label_features.T)
    per_label = np.squeeze(scores, axis=0)
    return [score for score in per_label]
|
|
|
|
|
def find_best_matches(image):
    """Return the prompt of the highest-scoring bird class for ``image``.

    Args:
        image: Image array from the Gradio input, or ``None`` when the
            (optional) input was left empty.

    Returns:
        str: The matching entry of ``ALL_LABELS``, or a short notice when no
        image was supplied.
    """
    # The Gradio input is declared optional, so an empty submission arrives
    # as None; encode_image would crash on it.
    if image is None:
        return "No image provided."

    image_features = encode_image(image)
    scores = similarity(image_features, LABEL_FEATURES)
    # np.argmax returns the first index among ties, the same winner a stable
    # descending sort would put first, without sorting every score.
    best_idx = int(np.argmax(scores))
    return ALL_LABELS[best_idx]
|
|
|
|
|
# Single-image classifier UI: one optional image in, one label out.
image_input = gr.inputs.Image(label="Image to classify", optional=True)
label_output = gr.outputs.Label()

demo = gr.Interface(
    fn=find_best_matches,
    inputs=[image_input],
    outputs=label_output,
    theme="grass",
    enable_queue=True,
    title="North American Bird Classifier",
    description="This application can classify North American Birds.",
)
demo.launch()
|
|
|
|
|
|
|
|
|
|