File size: 4,370 Bytes
6398b6f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import argparse
import os
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from collections import defaultdict
import numpy as np
from PIL import Image
# Define the MLP model (same as in the training script)
class MLP(nn.Module):
    """Fully connected network: Linear layers with a ReLU between consecutive ones.

    The layer layout -- and therefore the ``state_dict`` key names
    (``model.0.weight``, ``model.2.weight``, ...) -- must stay identical to the
    training script so saved checkpoints load without key mismatches.

    Args:
        input_size: Number of input features per sample.
        hidden_sizes: Widths of the hidden layers, in order. Any iterable of
            ints is accepted (list, tuple, ...); an empty iterable yields a
            single Linear layer from input to output.
        output_size: Number of output features (one score per class here).
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        # Unpacking accepts any iterable, not just a list (list + tuple raises).
        sizes = [input_size, *hidden_sizes, output_size]
        layers = []
        for i, (fan_in, fan_out) in enumerate(zip(sizes, sizes[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            # No activation after the last Linear: its output is returned as-is.
            if i < len(sizes) - 2:
                layers.append(nn.ReLU())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the last layer's output."""
        return self.model(x)
# Custom Dataset class to handle image preprocessing (same as in the training script)
class TinyImageNetDataset(torch.utils.data.Dataset):
    """Map-style dataset turning ``{'image', 'label'}`` records into flat tensors.

    Each item is ``(pixels, label)``:
      * ``pixels`` -- the image converted to single-channel grayscale ('L'),
        flattened to 1-D and cast to float32. Raw 0-255 values are kept (no
        normalisation) to match the training-time preprocessing.
      * ``label`` -- ``torch.tensor`` built from the record's ``'label'`` value.

    Args:
        dataset: Indexable, sized collection of dicts with an ``'image'`` entry
            (a PIL image, or anything with a compatible ``convert`` method) and
            a ``'label'`` entry -- e.g. a Hugging Face ``datasets`` split.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]
        # convert('L') yields one channel whether the source is RGB or grayscale.
        pixels = np.array(example['image'].convert('L')).reshape(-1)
        img = torch.from_numpy(pixels).float()
        label = torch.tensor(example['label'])
        return img, label
# Function to evaluate the model on the validation set and compute class-wise accuracy
def evaluate_model(model, val_loader, num_classes):
    """Evaluate ``model`` on ``val_loader`` and return per-class accuracy.

    The model is moved to CUDA when available (CPU otherwise) and put in eval
    mode; both changes persist on the passed-in model object.

    Args:
        model: Module mapping a batch of inputs to per-class scores; the
            arg-max over dim 1 is taken as the predicted class.
        val_loader: Iterable of ``(inputs, labels)`` tensor pairs.
        num_classes: Number of classes to report on.

    Returns:
        dict mapping every class index in ``range(num_classes)`` to its
        accuracy in percent (0-100). Classes with no samples map to 0.0;
        labels outside ``[0, num_classes)`` do not appear in the result.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs.to(device))
            _, predicted = torch.max(outputs, 1)
            # Count with vectorised bincount on the CPU: one transfer per batch
            # instead of a host/device sync per sample via .item().
            labels = labels.cpu()
            predicted = predicted.cpu()
            # bincount rejects negative values; out-of-range labels were never
            # reported anyway, so drop them before counting.
            in_range = (labels >= 0) & (labels < num_classes)
            labels, predicted = labels[in_range], predicted[in_range]
            class_total += torch.bincount(labels, minlength=num_classes)
            class_correct += torch.bincount(labels[predicted == labels], minlength=num_classes)
    class_accuracies = {}
    for class_idx, (correct, total) in enumerate(zip(class_correct.tolist(), class_total.tolist())):
        class_accuracies[class_idx] = 100 * correct / total if total > 0 else 0.0
    return class_accuracies
# Main function to load the model and evaluate it
def main():
    """Evaluate a saved MLP checkpoint on tiny-imagenet and report per-class accuracy.

    Parses CLI arguments, loads the ``valid`` split of zh-plus/tiny-imagenet,
    rebuilds the MLP with the requested depth/width, loads the checkpoint,
    and prints one ``Class <idx>: <acc>%`` line per class. The same lines are
    written to ``--output_file``.
    """
    parser = argparse.ArgumentParser(description='Evaluate the MLP model on the zh-plus/tiny-imagenet dataset.')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to the model checkpoint')
    parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
    parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
    parser.add_argument('--output_file', type=str, default='class_accuracies.txt', help='Output file to save class-wise accuracies')
    parser.add_argument('--batch_size', type=int, default=8, help='Evaluation batch size (default: 8)')
    args = parser.parse_args()

    # The validation split of zh-plus/tiny-imagenet is keyed 'valid'.
    dataset = load_dataset('zh-plus/tiny-imagenet')
    val_dataset = dataset['valid']

    # Prefer the declared label vocabulary (a ClassLabel feature exposes
    # num_classes) so the output layer matches the trained checkpoint even if
    # a class is missing from this split; otherwise count the distinct labels.
    num_classes = getattr(val_dataset.features.get('label'), 'num_classes', None)
    if num_classes is None:
        num_classes = len(set(val_dataset['label']))

    # Images are converted to grayscale and flattened, so the input layer takes
    # image_size * image_size features. Assumes square 64x64 images.
    image_size = 64
    input_size = image_size * image_size
    hidden_sizes = [args.width] * args.layer_count
    model = MLP(input_size, hidden_sizes, num_classes)
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host;
    # evaluate_model moves the model to the available device afterwards.
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))

    val_loader = DataLoader(TinyImageNetDataset(val_dataset), batch_size=args.batch_size, shuffle=False)
    class_accuracies = evaluate_model(model, val_loader, num_classes)

    # Format once so stdout and the output file always agree.
    lines = [f"Class {class_idx}: {accuracy:.2f}%" for class_idx, accuracy in class_accuracies.items()]
    print("Class-wise accuracies:")
    for line in lines:
        print(line)
    with open(args.output_file, 'w', encoding='utf-8') as f:
        f.writelines(f"{line}\n" for line in lines)
# Run the evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()