"""Train and evaluate a ResNet-50 on CIFAR-10.

Running this file as a script trains for ``NUM_EPOCHS`` epochs and prints the
train/test accuracy after each epoch.
"""
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from resnet_model import ResNet50
from tqdm import tqdm

# NOTE(review): everything from here up to the function definitions runs at
# import time (dataset download, model construction). With the "spawn"
# multiprocessing start method, DataLoader workers (num_workers=4) presumably
# re-import this module and repeat that work. Consider moving the setup under
# the __main__ guard if no other module imports these globals.

# Number of training epochs run by the __main__ block.
NUM_EPOCHS = 5

# Define transformations: convert images to tensors, then normalize each
# channel with the mean/std below (presumably CIFAR-10 channel statistics).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

# Load CIFAR-10 dataset (downloaded into ./data if not already present).
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=1000, shuffle=False, num_workers=4)

# Initialize model, loss function, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet50().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)


def train(model, device, train_loader, optimizer, criterion, epoch):
    """Run one training epoch and return the training accuracy in percent.

    Args:
        model: network to optimize; switched to training mode here.
        device: device that each batch's inputs and targets are moved to.
        train_loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss function called as ``criterion(outputs, targets)``.
        epoch: epoch number, used only in the progress-bar label.

    Returns:
        Percentage of samples seen this epoch whose arg-max prediction
        matched the target.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    correct = 0
    total = 0
    pbar = tqdm(train_loader)
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest logit along dim 1.
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        pbar.set_description(
            desc=f'Epoch {epoch} | Loss: {loss.item():.4f} | '
                 f'Accuracy: {100.*correct/total:.2f}%'
        )

    if total == 0:
        raise ValueError("train_loader yielded no samples")
    return 100. * correct / total


def test(model, device, test_loader, criterion):
    """Evaluate ``model`` on ``test_loader``, print loss/accuracy, return accuracy.

    The printed loss is a per-sample average: each batch's loss is weighted by
    its batch size, so a smaller final batch does not skew the result. This
    assumes ``criterion`` returns the batch *mean*, which is true for the
    module-level ``nn.CrossEntropyLoss()`` (default ``reduction='mean'``).

    Args:
        model: network to evaluate; switched to eval mode here.
        device: device that each batch's inputs and targets are moved to.
        test_loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function called as ``criterion(outputs, targets)``.

    Returns:
        Test accuracy in percent.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            batch_size = targets.size(0)
            # Undo the batch-mean reduction so the final division by `total`
            # gives a per-sample average rather than a per-batch one.
            test_loss += criterion(outputs, targets).item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(targets).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples")
    test_accuracy = 100. * correct / total
    print(f'Test Loss: {test_loss/total:.4f}, Accuracy: {test_accuracy:.2f}%')
    return test_accuracy


# Main execution
if __name__ == '__main__':
    for epoch in range(1, NUM_EPOCHS + 1):
        train_accuracy = train(model, device, trainloader, optimizer, criterion, epoch)
        test_accuracy = test(model, device, testloader, criterion)
        print(f'Epoch {epoch} | Train Accuracy: {train_accuracy:.2f}% | Test Accuracy: {test_accuracy:.2f}%')