"""Gradio demo: classify pet images with a CLIP model fine-tuned on Oxford Pets."""

import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
import requests
from datasets import load_dataset

MODEL_ID = "DGurgurov/clip-vit-base-patch32-oxford-pets"

# Load your fine-tuned model and its matching processor.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForZeroShotImageClassification.from_pretrained(MODEL_ID)
model.eval()  # inference only: disable training-time behaviour such as dropout

# Load dataset to get labels; the label strings are used as CLIP text prompts.
dataset = load_dataset("pcuenq/oxford-pets")  # Adjust dataset loading as per your setup
# sorted() makes the label order (and therefore the label ids) reproducible across
# runs; iteration order of a set of strings depends on Python's hash randomization.
labels = sorted(set(dataset["train"]["label"]))
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}

# NOTE: `transform` is not used by classify_image. The CLIP processor performs its
# own resizing and normalization (with CLIP's statistics, not the ImageNet mean/std
# below). It is kept only so existing references to it keep working.
transform = Compose([
    Resize((224, 224)),
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def classify_image(image):
    """Return the predicted Oxford Pets label for an uploaded image.

    Args:
        image: The value supplied by ``gr.Image``. This is a NumPy array with
            Gradio's default ``type="numpy"``, or ``None`` when nothing was uploaded.

    Returns:
        The label string (from ``labels``) whose text prompt has the highest
        image-text similarity score.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # CLIP expects 3-channel input; convert defensively in case of RGBA/greyscale arrays.
    pil_image = Image.fromarray(image).convert("RGB")
    inputs = processor(text=labels, images=pil_image, return_tensors="pt", padding=True)

    # No gradients are needed for inference; this avoids building an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # `outputs` is a model-output object, not a tensor. `logits_per_image` holds one
    # row per image and one column per text prompt, so argmax over dim=1 selects the
    # best-matching label for our single image.
    predicted_label_id = outputs.logits_per_image.argmax(dim=1).item()
    return id2label[predicted_label_id]


# Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(label="Upload a picture of an animal"),
    outputs=gr.Textbox(label="Predicted Animal"),
    title="Animal Classifier",
    description="CLIP-based model fine-tuned on Oxford Pets dataset to classify animals.",
)

# Launch the Gradio interface
iface.launch()