|
import torch
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
from torch.utils.data import DataLoader
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from resnet_model import ResNet50
|
|
from tqdm import tqdm
|
|
|
|
|
|
# Per-channel normalization statistics; presumably the RGB mean/std of the
# CIFAR-10 training images (TODO confirm source of these constants).
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Preprocessing shared by both dataset splits: convert each image to a tensor,
# then standardize every channel with the statistics above.
# NOTE(review): the training split receives no augmentation (random crop /
# horizontal flip are common for CIFAR-10) — consider a separate train transform.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])
|
|
|
|
|
|
# Training split: stored under ./data (downloaded on first use), reshuffled
# every epoch, loaded by 4 worker processes in batches of 128.
# NOTE(review): these loaders are built at import time; with the "spawn"
# multiprocessing start method each worker likely re-imports this module and
# re-runs this setup — consider moving it under the __main__ guard.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = DataLoader(
    dataset=trainset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

# Test split: same root and preprocessing, fixed order, larger batches of 1000.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = DataLoader(
    dataset=testset,
    batch_size=1000,
    shuffle=False,
    num_workers=4,
)
|
|
|
|
|
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the network and move its parameters to the chosen device
# (Module.to moves in place and returns the module itself).
model = ResNet50()
model = model.to(device)

# Classification loss plus SGD with momentum and L2 weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)
|
|
|
|
|
|
def train(model, device, train_loader, optimizer, criterion, epoch):
    """Run one training epoch and return the training accuracy in percent.

    Args:
        model: Network to optimize; switched to training mode here.
        device: Device each ``(inputs, targets)`` batch is moved to.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, targets)``; assumed to
            return the per-batch mean (the ``CrossEntropyLoss`` default).
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        Percentage of samples whose arg-max prediction matched the target,
        measured on the fly while the weights were being updated.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader)
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Read the scalar once (each .item() forces a device sync) and weight
        # it by batch size so a short final batch does not skew the average.
        batch_size = targets.size(0)
        running_loss += loss.item() * batch_size
        total += batch_size

        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()

        # Show the running average loss over the epoch so far, not just the
        # noisy loss of the latest batch.
        pbar.set_description(desc=f'Epoch {epoch} | Loss: {running_loss/total:.4f} | Accuracy: {100.*correct/total:.2f}%')

    if total == 0:
        raise ValueError("train_loader yielded no samples")
    return 100.*correct/total
|
|
|
|
|
|
def test(model, device, test_loader, criterion):
|
|
model.eval()
|
|
test_loss = 0
|
|
correct = 0
|
|
total = 0
|
|
|
|
with torch.no_grad():
|
|
for inputs, targets in test_loader:
|
|
inputs, targets = inputs.to(device), targets.to(device)
|
|
outputs = model(inputs)
|
|
loss = criterion(outputs, targets)
|
|
|
|
test_loss += loss.item()
|
|
_, predicted = outputs.max(1)
|
|
total += targets.size(0)
|
|
correct += predicted.eq(targets).sum().item()
|
|
|
|
test_accuracy = 100.*correct/total
|
|
print(f'Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {test_accuracy:.2f}%')
|
|
return test_accuracy
|
|
|
|
|
|
if __name__ == '__main__':
    # Epochs are numbered from 1 so the progress bar and summary read naturally.
    num_epochs = 5
    for epoch in range(1, num_epochs + 1):
        # One optimization pass over the training split...
        train_accuracy = train(
            model=model,
            device=device,
            train_loader=trainloader,
            optimizer=optimizer,
            criterion=criterion,
            epoch=epoch,
        )
        # ...then a gradient-free evaluation pass over the test split.
        test_accuracy = test(
            model=model,
            device=device,
            test_loader=testloader,
            criterion=criterion,
        )
        print(f'Epoch {epoch} | Train Accuracy: {train_accuracy:.2f}% | Test Accuracy: {test_accuracy:.2f}%')