TeacherPuffy commited on
Commit
6398b6f
·
verified ·
1 Parent(s): e378735

Create eval_model.py

Browse files
Files changed (1) hide show
  1. eval_model.py +117 -0
eval_model.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import torch.nn as nn
5
+ from datasets import load_dataset
6
+ from torch.utils.data import DataLoader
7
+ from collections import defaultdict
8
+ import numpy as np
9
+ from PIL import Image
10
+
11
+ # Define the MLP model (same as in the training script)
12
+ class MLP(nn.Module):
13
+ def __init__(self, input_size, hidden_sizes, output_size):
14
+ super(MLP, self).__init__()
15
+ layers = []
16
+ sizes = [input_size] + hidden_sizes + [output_size]
17
+ for i in range(len(sizes) - 1):
18
+ layers.append(nn.Linear(sizes[i], sizes[i+1]))
19
+ if i < len(sizes) - 2:
20
+ layers.append(nn.ReLU())
21
+ self.model = nn.Sequential(*layers)
22
+
23
+ def forward(self, x):
24
+ return self.model(x)
25
+
26
+ # Custom Dataset class to handle image preprocessing (same as in the training script)
27
+ class TinyImageNetDataset(Dataset):
28
+ def __init__(self, dataset):
29
+ self.dataset = dataset
30
+
31
+ def __len__(self):
32
+ return len(self.dataset)
33
+
34
+ def __getitem__(self, idx):
35
+ example = self.dataset[idx]
36
+ img = example['image']
37
+ img = np.array(img.convert('L')) # Convert PIL image to grayscale NumPy array
38
+ img = img.reshape(-1) # Flatten the image
39
+ img = torch.from_numpy(img).float() # Convert to tensor
40
+ label = torch.tensor(example['label'])
41
+ return img, label
42
+
43
+ # Function to evaluate the model on the validation set and compute class-wise accuracy
44
+ def evaluate_model(model, val_loader, num_classes):
45
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+ model.to(device)
47
+ model.eval()
48
+
49
+ class_correct = defaultdict(int)
50
+ class_total = defaultdict(int)
51
+
52
+ with torch.no_grad():
53
+ for inputs, labels in val_loader:
54
+ inputs, labels = inputs.to(device), labels.to(device)
55
+
56
+ outputs = model(inputs)
57
+ _, predicted = torch.max(outputs, 1)
58
+
59
+ for label, prediction in zip(labels, predicted):
60
+ if label == prediction:
61
+ class_correct[label.item()] += 1
62
+ class_total[label.item()] += 1
63
+
64
+ class_accuracies = {}
65
+ for class_idx in range(num_classes):
66
+ if class_total[class_idx] > 0:
67
+ class_accuracies[class_idx] = 100 * class_correct[class_idx] / class_total[class_idx]
68
+ else:
69
+ class_accuracies[class_idx] = 0.0
70
+
71
+ return class_accuracies
72
+
73
+ # Main function to load the model and evaluate it
74
+ def main():
75
+ parser = argparse.ArgumentParser(description='Evaluate the MLP model on the zh-plus/tiny-imagenet dataset.')
76
+ parser.add_argument('--checkpoint', type=str, required=True, help='Path to the model checkpoint')
77
+ parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
78
+ parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
79
+ parser.add_argument('--output_file', type=str, default='class_accuracies.txt', help='Output file to save class-wise accuracies')
80
+ args = parser.parse_args()
81
+
82
+ # Load the zh-plus/tiny-imagenet dataset
83
+ dataset = load_dataset('zh-plus/tiny-imagenet')
84
+ val_dataset = dataset['valid'] # Assuming 'validation' is the correct key
85
+
86
+ # Determine the number of classes
87
+ num_classes = len(set(val_dataset['label']))
88
+
89
+ # Determine the fixed resolution of the images
90
+ image_size = 64 # Assuming the images are square
91
+
92
+ # Define the model
93
+ input_size = image_size * image_size # Since images are grayscale
94
+ hidden_sizes = [args.width] * args.layer_count
95
+ output_size = num_classes
96
+
97
+ model = MLP(input_size, hidden_sizes, output_size)
98
+ model.load_state_dict(torch.load(args.checkpoint))
99
+
100
+ # Create DataLoader for validation
101
+ val_loader = DataLoader(TinyImageNetDataset(val_dataset), batch_size=8, shuffle=False)
102
+
103
+ # Evaluate the model
104
+ class_accuracies = evaluate_model(model, val_loader, num_classes)
105
+
106
+ # Print the results
107
+ print("Class-wise accuracies:")
108
+ for class_idx, accuracy in class_accuracies.items():
109
+ print(f"Class {class_idx}: {accuracy:.2f}%")
110
+
111
+ # Save the results to a text file
112
+ with open(args.output_file, 'w') as f:
113
+ for class_idx, accuracy in class_accuracies.items():
114
+ f.write(f"Class {class_idx}: {accuracy:.2f}%\n")
115
+
116
+ if __name__ == '__main__':
117
+ main()