|
import argparse
import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from PIL import Image
from torch.utils.data import DataLoader, Dataset
|
|
|
|
|
class MLP(nn.Module):
    """Plain fully connected network: Linear layers with ReLU between
    hidden layers and raw logits (no activation) on the output layer."""

    def __init__(self, input_size, hidden_sizes, output_size):
        super(MLP, self).__init__()
        dims = [input_size, *hidden_sizes, output_size]
        modules = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            modules.append(nn.Linear(in_dim, out_dim))
            modules.append(nn.ReLU())
        # Drop the trailing ReLU so the final Linear emits raw logits.
        # Child indices match the classic build (Linear at 0, 2, 4, ...),
        # keeping state_dict keys compatible with existing checkpoints.
        self.model = nn.Sequential(*modules[:-1])

    def forward(self, x):
        return self.model(x)
|
|
|
|
|
class TinyImageNetDataset(Dataset):
    """Adapts a HuggingFace image dataset split to a torch ``Dataset``.

    Each item is converted to single-channel grayscale and flattened to a
    1-D float32 tensor of length height * width, matching the MLP input.
    NOTE: ``Dataset`` must be imported from ``torch.utils.data`` — the
    original file referenced it without importing it (NameError).
    """

    def __init__(self, dataset):
        # dataset: an indexable split whose items are dicts with
        # 'image' (PIL-like, has .convert) and 'label' (int) keys.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]
        # Grayscale ('L') + flatten; convert to float32 once up front
        # instead of uint8 -> tensor -> .float().
        pixels = np.array(example['image'].convert('L'), dtype=np.float32)
        img = torch.from_numpy(pixels.reshape(-1))
        label = torch.tensor(example['label'])
        return img, label
|
|
|
|
|
def evaluate_model(model, val_loader, num_classes):
    """Compute per-class top-1 accuracy of ``model`` over ``val_loader``.

    Returns a dict mapping every class index in ``range(num_classes)`` to
    its accuracy in percent; classes absent from the loader map to 0.0.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    hits = defaultdict(int)
    seen = defaultdict(int)

    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            preds = model(batch_inputs).argmax(dim=1)

            for truth, guess in zip(batch_labels.tolist(), preds.tolist()):
                seen[truth] += 1
                if truth == guess:
                    hits[truth] += 1

    return {
        cls: (100 * hits[cls] / seen[cls]) if seen[cls] > 0 else 0.0
        for cls in range(num_classes)
    }
|
|
|
|
|
def main():
    """CLI entry point: evaluate a saved MLP checkpoint on the Tiny
    ImageNet validation split, print per-class accuracy, and save it."""
    parser = argparse.ArgumentParser(description='Evaluate the MLP model on the zh-plus/tiny-imagenet dataset.')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to the model checkpoint')
    parser.add_argument('--layer_count', type=int, default=2, help='Number of hidden layers (default: 2)')
    parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
    parser.add_argument('--output_file', type=str, default='class_accuracies.txt', help='Output file to save class-wise accuracies')
    args = parser.parse_args()

    # Validation split of Tiny ImageNet; class count is derived from the
    # labels actually present in the split.
    val_split = load_dataset('zh-plus/tiny-imagenet')['valid']
    num_classes = len(set(val_split['label']))

    # Images are 64x64; the dataset wrapper flattens grayscale pixels,
    # so the MLP input size is height * width.
    image_size = 64
    net = MLP(image_size * image_size, [args.width] * args.layer_count, num_classes)
    net.load_state_dict(torch.load(args.checkpoint))

    loader = DataLoader(TinyImageNetDataset(val_split), batch_size=8, shuffle=False)
    per_class = evaluate_model(net, loader, num_classes)

    report_lines = [f"Class {class_idx}: {accuracy:.2f}%"
                    for class_idx, accuracy in per_class.items()]

    print("Class-wise accuracies:")
    for line in report_lines:
        print(line)

    with open(args.output_file, 'w') as f:
        f.writelines(line + "\n" for line in report_lines)


if __name__ == '__main__':
    main()