|
import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
|
|
|
|
|
# Load the CLIP processor and the zero-shot classification head fine-tuned
# on the Oxford Pets dataset.
processor = AutoProcessor.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")
model = AutoModelForZeroShotImageClassification.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")

# The dataset is loaded only to recover the label vocabulary.
dataset = load_dataset("pcuenq/oxford-pets")

# Sort the de-duplicated labels: set iteration order is not stable across
# interpreter runs (hash randomization), so without sorting label2id/id2label
# would differ from run to run.
labels = sorted(set(dataset['train']['label']))
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}
|
|
|
|
|
def classify_image(image):
    """Classify an uploaded image against the Oxford Pets label set.

    Args:
        image: numpy array (H, W, C) supplied by ``gr.Image`` — presumably
            RGB uint8; converted to RGB explicitly to also accept
            RGBA/grayscale uploads. May be ``None`` if the user submits
            without an image.

    Returns:
        str: the predicted label name, or a short message when no image
        was provided.
    """
    # Gradio passes None when the image input is empty — fail gracefully
    # instead of crashing inside Image.fromarray.
    if image is None:
        return "Please upload an image."

    pil_image = Image.fromarray(image).convert("RGB")

    # CLIP-style zero-shot scoring: every label is encoded as text and
    # compared against the single image.
    inputs = processor(text=labels, images=pil_image, return_tensors="pt", padding=True)

    # Inference only — disable autograd to avoid building a graph and
    # holding activations in memory.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image: (num_images=1, num_labels); softmax over labels.
    probs = outputs.logits_per_image[0].softmax(dim=0)
    predicted_label_id = probs.argmax().item()

    return id2label[predicted_label_id]
|
|
|
|
|
|
|
# Web UI: single image in, predicted label name out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(label="Upload a picture of an animal"),
    outputs=gr.Textbox(label="Predicted Animal"),
    title="Animal Classifier",
    description="CLIP-based model fine-tuned on Oxford Pets dataset to classify animals.",
)

# Guard the server start so the module can be imported (e.g. for tests or
# reuse of classify_image) without launching the app.
if __name__ == "__main__":
    iface.launch()