import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from PIL import Image
from datasets import load_dataset

# Load the fine-tuned CLIP model and its processor
processor = AutoProcessor.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")
model = AutoModelForZeroShotImageClassification.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")
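
# Optional sketch (assumption, not in the original file): run on GPU when one
# is available. If enabled, the `inputs` built in classify_image below would
# also need moving to the same device, e.g.
# inputs = {k: v.to(device) for k, v in inputs.items()}.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)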

# Load the dataset only to recover the label names
dataset = load_dataset("pcuenq/oxford-pets")
labels = sorted(set(dataset['train']['label']))  # sorted so the label order is stable across runs
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}
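
# Quick sanity check (illustrative, not in the original): the two mappings
# should invert each other for every label name.
assert all(id2label[label2id[name]] == name for name in labels)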

# Classify an image by comparing it against all label texts with CLIP
def classify_image(image):
    # Preprocess: the processor resizes/normalizes the image and tokenizes the labels
    image = Image.fromarray(image)
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    # Run inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image holds one similarity score per label; take the best match
    predicted_label_id = outputs.logits_per_image.argmax(dim=-1).item()
    return id2label[predicted_label_id]
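
# Optional sketch (not in the original app): a variant that returns the top-k
# labels with softmax confidences instead of a single string. It reuses the
# `processor`, `model`, and `labels` defined above; `top_k` is an illustrative
# parameter name, not from the original code.
def classify_image_topk(image, top_k=3):
    image = Image.fromarray(image)
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits_per_image
    probs = logits.softmax(dim=-1).squeeze(0)
    scores, ids = probs.topk(min(top_k, len(labels)))
    return {labels[i]: round(float(s), 4) for i, s in zip(ids.tolist(), scores.tolist())}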
# Gradio interface
iface = gr.Interface(
fn=classify_image,
inputs=gr.Image(label="Upload a picture of an animal"),
outputs=gr.Textbox(label="Predicted Animal"),
title="Animal Classifier",
description="CLIP-based model fine-tuned on Oxford Pets dataset to classify animals.",
)
# Launch the Gradio interface
iface.launch()
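
# Note (assumption, not from the original file): on Hugging Face Spaces the
# plain launch() above is sufficient; for a quick local test you could use
# iface.launch(share=True) to get a temporary public URL.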