|
import gradio as gr |
|
import torch |
|
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification |
|
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize |
|
from PIL import Image |
|
import requests |
|
from datasets import load_dataset |
|
|
|
|
|
# Processor (tokenizer + image preprocessing) and CLIP model fine-tuned
# on the Oxford Pets dataset, downloaded from the Hugging Face Hub.
processor = AutoProcessor.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")

model = AutoModelForZeroShotImageClassification.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")




# The dataset is loaded only to recover the label vocabulary below;
# no training or evaluation happens in this script.
dataset = load_dataset("pcuenq/oxford-pets")


# Unique class names. NOTE(review): set() iteration order is not
# deterministic across interpreter runs, so label ids are stable only
# within a single run — acceptable here because labels/label2id/id2label
# are all derived together in the same run.
labels = list(set(dataset['train']['label']))

# label string -> integer id (position in `labels`)
label2id = {label: i for i, label in enumerate(labels)}

# integer id -> label string (inverse of label2id)
id2label = {i: label for label, i in label2id.items()}
|
|
|
|
|
# Torchvision preprocessing pipeline using the standard ImageNet
# normalization statistics.
# NOTE(review): this transform is never used — classify_image relies on
# `processor` for all preprocessing. Candidate for removal; kept here in
# case another part of the project imports it.
transform = Compose([

    Resize((224, 224)),

    CenterCrop(224),

    ToTensor(),

    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

])
|
|
|
|
|
def classify_image(image):
    """Classify an uploaded image against the Oxford Pets label set.

    Args:
        image: numpy array (H, W, C) as delivered by ``gr.Image``.

    Returns:
        The predicted label string from ``labels``.
    """
    image = Image.fromarray(image)

    # Tokenize every candidate label and preprocess the image in one call.
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # BUG FIX: model(**inputs) returns a model-output object, not a tensor,
    # so torch.argmax(outputs, dim=1) raised a TypeError. The image-text
    # similarity scores live in `logits_per_image` (shape [1, len(labels)]);
    # argmax over dim=1 picks the best-matching label index.
    logits = outputs.logits_per_image
    predicted_label_id = torch.argmax(logits, dim=1).item()

    return id2label[predicted_label_id]
|
|
|
|
|
# Gradio front-end: upload an image, get back the predicted animal label.
# The UI configuration is gathered in one dict and splatted into Interface.
_ui_config = dict(
    fn=classify_image,
    inputs=gr.Image(label="Upload a picture of an animal"),
    outputs=gr.Textbox(label="Predicted Animal"),
    title="Animal Classifier",
    description="CLIP-based model fine-tuned on Oxford Pets dataset to classify animals.",
)

iface = gr.Interface(**_ui_config)

# Start the local Gradio server (blocks until the app is stopped).
iface.launch()
|
|