# Author: DGurgurov
# Last commit: ad2f296 ("Update app.py")
import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
# Load your fine-tuned model and its processor from one repo id so they cannot
# drift apart.
MODEL_ID = "DGurgurov/clip-vit-base-patch32-oxford-pets"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForZeroShotImageClassification.from_pretrained(MODEL_ID)

# Load dataset to get labels
dataset = load_dataset("pcuenq/oxford-pets")  # Adjust dataset loading as per your setup
# Sort the unique labels: set() iteration order can differ between interpreter
# runs (e.g. for string labels under hash randomization), which would make the
# label ids below change from one launch to the next.
labels = sorted(set(dataset["train"]["label"]))
# Mapping between a label and its position in `labels` (id2label[i] == labels[i]).
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}
# Function to classify image using CLIP model
def classify_image(image):
    """Return the label whose text prompt the CLIP model scores highest for *image*.

    Every entry of the module-level ``labels`` list is sent to the model as a
    text prompt alongside the image; the label with the largest image-text
    logit wins.

    Args:
        image: The uploaded picture. Either an array accepted by
            ``PIL.Image.fromarray`` (presumably the NumPy array that
            ``gr.Image`` delivers with its default settings) or an
            already-decoded ``PIL.Image.Image``.

    Returns:
        The predicted label, looked up via ``id2label``.

    Raises:
        gr.Error: If no image was provided (``image is None``).
    """
    if image is None:
        # Without this guard Image.fromarray(None) fails with an opaque error;
        # gr.Error shows a readable message in the UI instead.
        raise gr.Error("Please upload an image.")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Preprocess the image together with all candidate label prompts.
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)

    # Inference only: disable autograd so no computation graph is kept per request.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image is [1, num_labels] here because a single image was passed.
    logits = outputs.logits_per_image[0]
    # softmax is monotonic, so taking argmax over the raw logits selects the
    # same label as taking it over probabilities.
    predicted_label_id = logits.argmax().item()
    return id2label[predicted_label_id]
# Gradio UI: one image upload in, one text prediction out, wired to classify_image.
# Components are built first (input before output, as before) and then handed
# to the Interface.
_image_input = gr.Image(label="Upload a picture of an animal")
_label_output = gr.Textbox(label="Predicted Animal")
iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_label_output,
    title="Animal Classifier",
    description="CLIP-based model fine-tuned on Oxford Pets dataset to classify animals.",
)
# Launch the Gradio interface only when this file is executed as a script
# (e.g. `python app.py`), so that merely importing the module does not start
# a web server.
if __name__ == "__main__":
    iface.launch()