import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def save_checkpoint(model, optimizer, epoch, loss, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)
    print(f"Checkpoint saved at epoch {epoch}")

def load_checkpoint(model, optimizer, path):
    checkpoint = torch.load(path, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded, resuming from epoch {epoch}")
    return model, optimizer, epoch, loss
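
# A minimal usage sketch for the checkpoint helpers above, assuming a toy
# linear model; the model, optimizer settings, and the filename
# 'checkpoint.pth' are illustrative only, not required by this module.
def _example_checkpoint_roundtrip():
    import torch.nn as nn
    model = nn.Linear(10, 2)  # stand-in for a real model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    save_checkpoint(model, optimizer, epoch=5, loss=0.31, path='checkpoint.pth')
    # Later (or in a new process): rebuild the same architecture, then restore.
    model, optimizer, epoch, loss = load_checkpoint(model, optimizer, 'checkpoint.pth')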

def plot_training_curves(epochs, train_acc1, test_acc1, train_acc5, test_acc5, train_losses, test_losses, learning_rates):
    plt.figure(figsize=(12, 8))
    plt.subplot(2, 2, 1)
    plt.plot(epochs, train_acc1, label='Train Top-1 Acc')
    plt.plot(epochs, test_acc1, label='Test Top-1 Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Top-1 Accuracy')

    plt.subplot(2, 2, 2)
    plt.plot(epochs, train_acc5, label='Train Top-5 Acc')
    plt.plot(epochs, test_acc5, label='Test Top-5 Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Top-5 Accuracy')

    plt.subplot(2, 2, 3)
    plt.plot(epochs, train_losses, label='Train Loss')
    plt.plot(epochs, test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss')

    plt.subplot(2, 2, 4)
    plt.plot(epochs, learning_rates, label='Learning Rate')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.legend()
    plt.title('Learning Rate')

    plt.tight_layout()
    plt.show()

def plot_misclassified_samples(misclassified_images, misclassified_labels, misclassified_preds, classes):
    if misclassified_images:
        print("\nDisplaying some misclassified samples:")
        num_shown = min(16, len(misclassified_images))
        # Report the true vs. predicted class for each displayed sample
        for i in range(num_shown):
            print(f"  True: {classes[misclassified_labels[i]]} | Predicted: {classes[misclassified_preds[i]]}")
        misclassified_grid = make_grid(misclassified_images[:num_shown], nrow=4, normalize=True, scale_each=True)
        plt.figure(figsize=(8, 8))
        plt.imshow(misclassified_grid.permute(1, 2, 0))
        plt.title("Misclassified Samples")
        plt.axis('off')
        plt.show()
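
# A sketch of how the misclassified lists fed to plot_misclassified_samples
# might be collected during evaluation. The `model`, `test_loader`, and
# `device` arguments are assumed to come from the caller's training setup.
def _collect_misclassified(model, test_loader, device):
    misclassified_images, misclassified_labels, misclassified_preds = [], [], []
    model.eval()
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)
            wrong = preds != labels  # boolean mask over the batch
            misclassified_images.extend(inputs[wrong].cpu())
            misclassified_labels.extend(labels[wrong].cpu().tolist())
            misclassified_preds.extend(preds[wrong].cpu().tolist())
    return misclassified_images, misclassified_labels, misclassified_preds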

def find_lr(model, criterion, optimizer, train_loader, device, num_epochs=1, start_lr=1e-7, end_lr=10, lr_multiplier=1.1):
    """
    Find a good learning rate with an LR range test (LR Finder).

    Args:
    - model: The model to train
    - criterion: Loss function (e.g., CrossEntropyLoss)
    - optimizer: Optimizer (e.g., SGD)
    - train_loader: DataLoader for training data
    - device: Device to run the test on (e.g., 'cuda' or 'cpu')
    - num_epochs: Number of epochs to run the LR Finder (typically 1-2)
    - start_lr: Starting learning rate for the sweep
    - end_lr: Maximum learning rate; the sweep stops once it is exceeded
    - lr_multiplier: Factor by which the learning rate is increased every batch

    Displays a plot of loss vs. learning rate.
    """
    lrs = []
    losses = []

    lr = start_lr
    for epoch in range(num_epochs):
        model.train()
        avg_loss = 0.0
        batch_count = 0
        for inputs, labels in train_loader:
            if lr > end_lr:  # Stop the sweep once the maximum learning rate is exceeded
                break
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.param_groups[0]['lr'] = lr  # Set the learning rate for this batch
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            avg_loss += loss.item()
            batch_count += 1
            lrs.append(lr)
            losses.append(loss.item())
            
            # Increase the learning rate for next batch
            lr *= lr_multiplier
        
        if batch_count > 0:
            avg_loss /= batch_count
            print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_loss:.4f}")
    
    # Plot the loss vs learning rate
    plt.plot(lrs, losses)
    plt.xscale('log')
    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.title('Learning Rate Finder')
    plt.show()
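
# A minimal end-to-end sketch of the LR finder on synthetic data, assuming a
# tiny fully connected model; the shapes, class count, and hyperparameters
# are illustrative only.
def _example_find_lr():
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32 * 3, 10)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-7)

    # 256 random "images" with 10 fake class labels, just to drive the sweep.
    data = TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 10, (256,)))
    train_loader = DataLoader(data, batch_size=32, shuffle=True)

    find_lr(model, criterion, optimizer, train_loader, device)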