import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from PIL import Image
from datasets import load_dataset

# Load your fine-tuned model and dataset
processor = AutoProcessor.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")
model = AutoModelForZeroShotImageClassification.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")

# Load dataset to get labels
dataset = load_dataset("pcuenq/oxford-pets")  # Adjust dataset loading as per your setup

# Collect the unique label names (sorted so the ordering is deterministic across runs)
labels = sorted(set(dataset['train']['label']))
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}

# Function to classify image using CLIP model
def classify_image(image):
    # Preprocess the image
    image = Image.fromarray(image)
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)

    # Run inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)

    # Extract logits and apply softmax
    logits_per_image = outputs.logits_per_image  # tensor of shape [1, num_labels]
    probs = logits_per_image[0].softmax(dim=0)  # softmax across the labels

    # Get the predicted label id and map it back to its label name
    predicted_label_id = probs.argmax().item()
    predicted_label = id2label[predicted_label_id]

    return predicted_label
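
# Optional sketch (not part of the original Space): returning the full
# probability distribution instead of a single string lets the UI show
# ranked confidences. This reuses the processor/model/labels defined above;
# the function name `classify_image_with_scores` is illustrative.
def classify_image_with_scores(image):
    image = Image.fromarray(image)
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image[0].softmax(dim=0)
    # gr.Label expects a {label: confidence} mapping
    return {label: float(p) for label, p in zip(labels, probs)}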


# Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(label="Upload a picture of an animal"),
    outputs=gr.Textbox(label="Predicted Animal"),
    title="Animal Classifier",
    description="CLIP-based model fine-tuned on Oxford Pets dataset to classify animals.",
)
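
# To surface the ranked confidences from classify_image_with_scores above,
# an alternative wiring (illustrative, not the original) would be:
# iface = gr.Interface(
#     fn=classify_image_with_scores,
#     inputs=gr.Image(label="Upload a picture of an animal"),
#     outputs=gr.Label(num_top_classes=5, label="Predicted Animal"),
# )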

# Launch the Gradio interface
iface.launch()