Monyrak commited on
Commit
dafa0bc
·
verified ·
1 Parent(s): 5b015c6

Upload zero_shot_classification.py

Browse files
Files changed (1) hide show
  1. zero_shot_classification.py +90 -0
zero_shot_classification.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ from datasets import load_dataset
3
+ from PIL import Image
4
+ import io
5
+ from tqdm import tqdm
6
+ from sklearn.metrics import accuracy_score, precision_score, recall_score
7
+ import os
8
+
9
+ # Clear the dataset cache
10
+ cache_dir = os.path.expanduser("~/.cache/huggingface/datasets")
11
+ if os.path.exists(cache_dir):
12
+ import shutil
13
+ shutil.rmtree(cache_dir)
14
+
15
+ # Load the CLIP model for zero-shot classification
16
+ print("Loading CLIP model...")
17
+ checkpoint = "openai/clip-vit-large-patch14"
18
+ detector = pipeline(model=checkpoint, task="zero-shot-image-classification")
19
+
20
+ # Load the Oxford Pets dataset
21
+ print("Loading Oxford Pets dataset...")
22
+ try:
23
+ # Only use first 100 images for faster testing
24
+ dataset = load_dataset('pcuenq/oxford-pets', split='train[:100]')
25
+ print(f"Loaded {len(dataset)} images")
26
+ except Exception as e:
27
+ print(f"Error loading dataset: {e}")
28
+ exit(1)
29
+
30
+ # Define the labels for Oxford Pets
31
+ labels_oxford_pets = [
32
+ 'Siamese', 'Birman', 'shiba inu', 'staffordshire bull terrier', 'basset hound', 'Bombay', 'japanese chin',
33
+ 'chihuahua', 'german shorthaired', 'pomeranian', 'beagle', 'english cocker spaniel', 'american pit bull terrier',
34
+ 'Ragdoll', 'Persian', 'Egyptian Mau', 'miniature pinscher', 'Sphynx', 'Maine Coon', 'keeshond', 'yorkshire terrier',
35
+ 'havanese', 'leonberger', 'wheaten terrier', 'american bulldog', 'english setter', 'boxer', 'newfoundland', 'Bengal',
36
+ 'samoyed', 'British Shorthair', 'great pyrenees', 'Abyssinian', 'pug', 'saint bernard', 'Russian Blue', 'scottish terrier'
37
+ ]
38
+
39
+ # Lists to store true and predicted labels
40
+ true_labels = []
41
+ predicted_labels = []
42
+
43
+ print("Processing images...")
44
+ for i in tqdm(range(len(dataset)), desc="Processing images"):
45
+ try:
46
+ # Get the image bytes from the dataset
47
+ image_bytes = dataset[i]['image']['bytes']
48
+
49
+ # Convert the bytes to a PIL image
50
+ image = Image.open(io.BytesIO(image_bytes))
51
+
52
+ # Run the detector on the image with the provided labels
53
+ results = detector(image, candidate_labels=labels_oxford_pets)
54
+ # Sort the results by score in descending order
55
+ sorted_results = sorted(results, key=lambda x: x['score'], reverse=True)
56
+
57
+ # Get the top predicted label
58
+ predicted_label = sorted_results[0]['label']
59
+
60
+ # Append the true and predicted labels to the respective lists
61
+ true_labels.append(dataset[i]['label'])
62
+ predicted_labels.append(predicted_label)
63
+
64
+ # Print progress every 10 images
65
+ if (i + 1) % 10 == 0:
66
+ print(f"Processed {i + 1}/{len(dataset)} images")
67
+
68
+ except Exception as e:
69
+ print(f"Error processing image {i}: {e}")
70
+ continue
71
+
72
+ # Calculate metrics
73
+ accuracy = accuracy_score(true_labels, predicted_labels)
74
+ precision = precision_score(true_labels, predicted_labels, average='weighted', labels=labels_oxford_pets)
75
+ recall = recall_score(true_labels, predicted_labels, average='weighted', labels=labels_oxford_pets)
76
+
77
+ # Print and save results
78
+ results = f"""
79
+ Zero-Shot Classification Results using CLIP (openai/clip-vit-large-patch14)
80
+ ====================================================================
81
+ Accuracy: {accuracy:.4f}
82
+ Precision: {precision:.4f}
83
+ Recall: {recall:.4f}
84
+ """
85
+
86
+ print(results)
87
+
88
+ # Save results to a file
89
+ with open('zero_shot_results.md', 'w') as f:
90
+ f.write(results)