File size: 1,776 Bytes
3ef7bdd b1f6a67 67c9814 352c190 f690417 b1f6a67 3ef7bdd ae3733a b1f6a67 67c9814 f690417 b1f6a67 1f91ca7 cebbd1f b1f6a67 7785c8d b1f6a67 ad2f296 b1f6a67 f690417 ad2f296 b1f6a67 3ef7bdd b1f6a67 67c9814 b1f6a67 f690417 b1f6a67 ad2f296 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import gradio as gr
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
from datasets import load_dataset
# Load the fine-tuned CLIP checkpoint and its matching processor.
MODEL_ID = "DGurgurov/clip-vit-base-patch32-oxford-pets"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForZeroShotImageClassification.from_pretrained(MODEL_ID)
# Load the dataset only to recover the list of class labels.
dataset = load_dataset("pcuenq/oxford-pets")  # Adjust dataset loading as per your setup
# Sort the unique labels so their order (and therefore the label <-> id
# mapping below) is identical on every run. Iterating a plain `set` is not
# stable across processes when the labels are strings (presumably breed
# names here), because Python randomizes string hashing per process.
labels = sorted(dataset["train"].unique("label"))
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}
# Function to classify image using CLIP model
def classify_image(image):
    """Return the label whose text prompt best matches the uploaded image.

    Every entry of the module-level ``labels`` list is used as a text prompt.
    CLIP scores the image against each prompt, and the highest-scoring label
    is returned.

    Args:
        image: The value delivered by the ``gr.Image`` input component. It is
            converted with ``Image.fromarray``, so presumably a numpy array
            (Gradio's default ``type="numpy"``). TODO confirm if the
            component's ``type`` is ever changed. ``None`` when the user
            submits without uploading anything.

    Returns:
        The predicted label (an element of ``labels``).

    Raises:
        gr.Error: If no image was supplied, so Gradio shows a readable
            message instead of a generic failure.
    """
    import torch  # local import: only needed for the no-grad context below

    if image is None:
        raise gr.Error("Please upload an image first.")

    pil_image = Image.fromarray(image)
    inputs = processor(text=labels, images=pil_image, return_tensors="pt", padding=True)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image is [num_images, num_labels]; we pass a single image.
    # Softmax is monotonic, so argmax over raw logits picks the same label.
    predicted_label_id = outputs.logits_per_image[0].argmax().item()
    return id2label[predicted_label_id]
# Gradio interface: one image upload in, the predicted label as text out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(label="Upload a picture of an animal"),
    outputs=gr.Textbox(label="Predicted Animal"),
    title="Animal Classifier",
    description="CLIP-based model fine-tuned on Oxford Pets dataset to classify animals.",
)
# Launch the Gradio interface (kept at module level, as the app is run as a script).
iface.launch()