Spaces:
Sleeping
Sleeping
Upload zero_shot_classification.py
Browse files- zero_shot_classification.py +90 -0
zero_shot_classification.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import pipeline
|
2 |
+
from datasets import load_dataset
|
3 |
+
from PIL import Image
|
4 |
+
import io
|
5 |
+
from tqdm import tqdm
|
6 |
+
from sklearn.metrics import accuracy_score, precision_score, recall_score
|
7 |
+
import os
|
8 |
+
|
9 |
+
# Clear the dataset cache
|
10 |
+
cache_dir = os.path.expanduser("~/.cache/huggingface/datasets")
|
11 |
+
if os.path.exists(cache_dir):
|
12 |
+
import shutil
|
13 |
+
shutil.rmtree(cache_dir)
|
14 |
+
|
15 |
+
# Load the CLIP model for zero-shot classification
|
16 |
+
print("Loading CLIP model...")
|
17 |
+
checkpoint = "openai/clip-vit-large-patch14"
|
18 |
+
detector = pipeline(model=checkpoint, task="zero-shot-image-classification")
|
19 |
+
|
20 |
+
# Load the Oxford Pets dataset
|
21 |
+
print("Loading Oxford Pets dataset...")
|
22 |
+
try:
|
23 |
+
# Only use first 100 images for faster testing
|
24 |
+
dataset = load_dataset('pcuenq/oxford-pets', split='train[:100]')
|
25 |
+
print(f"Loaded {len(dataset)} images")
|
26 |
+
except Exception as e:
|
27 |
+
print(f"Error loading dataset: {e}")
|
28 |
+
exit(1)
|
29 |
+
|
30 |
+
# Define the labels for Oxford Pets
|
31 |
+
labels_oxford_pets = [
|
32 |
+
'Siamese', 'Birman', 'shiba inu', 'staffordshire bull terrier', 'basset hound', 'Bombay', 'japanese chin',
|
33 |
+
'chihuahua', 'german shorthaired', 'pomeranian', 'beagle', 'english cocker spaniel', 'american pit bull terrier',
|
34 |
+
'Ragdoll', 'Persian', 'Egyptian Mau', 'miniature pinscher', 'Sphynx', 'Maine Coon', 'keeshond', 'yorkshire terrier',
|
35 |
+
'havanese', 'leonberger', 'wheaten terrier', 'american bulldog', 'english setter', 'boxer', 'newfoundland', 'Bengal',
|
36 |
+
'samoyed', 'British Shorthair', 'great pyrenees', 'Abyssinian', 'pug', 'saint bernard', 'Russian Blue', 'scottish terrier'
|
37 |
+
]
|
38 |
+
|
39 |
+
# Lists to store true and predicted labels
|
40 |
+
true_labels = []
|
41 |
+
predicted_labels = []
|
42 |
+
|
43 |
+
print("Processing images...")
|
44 |
+
for i in tqdm(range(len(dataset)), desc="Processing images"):
|
45 |
+
try:
|
46 |
+
# Get the image bytes from the dataset
|
47 |
+
image_bytes = dataset[i]['image']['bytes']
|
48 |
+
|
49 |
+
# Convert the bytes to a PIL image
|
50 |
+
image = Image.open(io.BytesIO(image_bytes))
|
51 |
+
|
52 |
+
# Run the detector on the image with the provided labels
|
53 |
+
results = detector(image, candidate_labels=labels_oxford_pets)
|
54 |
+
# Sort the results by score in descending order
|
55 |
+
sorted_results = sorted(results, key=lambda x: x['score'], reverse=True)
|
56 |
+
|
57 |
+
# Get the top predicted label
|
58 |
+
predicted_label = sorted_results[0]['label']
|
59 |
+
|
60 |
+
# Append the true and predicted labels to the respective lists
|
61 |
+
true_labels.append(dataset[i]['label'])
|
62 |
+
predicted_labels.append(predicted_label)
|
63 |
+
|
64 |
+
# Print progress every 10 images
|
65 |
+
if (i + 1) % 10 == 0:
|
66 |
+
print(f"Processed {i + 1}/{len(dataset)} images")
|
67 |
+
|
68 |
+
except Exception as e:
|
69 |
+
print(f"Error processing image {i}: {e}")
|
70 |
+
continue
|
71 |
+
|
72 |
+
# Calculate metrics
|
73 |
+
accuracy = accuracy_score(true_labels, predicted_labels)
|
74 |
+
precision = precision_score(true_labels, predicted_labels, average='weighted', labels=labels_oxford_pets)
|
75 |
+
recall = recall_score(true_labels, predicted_labels, average='weighted', labels=labels_oxford_pets)
|
76 |
+
|
77 |
+
# Print and save results
|
78 |
+
results = f"""
|
79 |
+
Zero-Shot Classification Results using CLIP (openai/clip-vit-large-patch14)
|
80 |
+
====================================================================
|
81 |
+
Accuracy: {accuracy:.4f}
|
82 |
+
Precision: {precision:.4f}
|
83 |
+
Recall: {recall:.4f}
|
84 |
+
"""
|
85 |
+
|
86 |
+
print(results)
|
87 |
+
|
88 |
+
# Save results to a file
|
89 |
+
with open('zero_shot_results.md', 'w') as f:
|
90 |
+
f.write(results)
|