harry committed
Commit aaea685 · 1 Parent(s): f774571

feat: update training loop with learning rate scheduler and progress bar enhancements

mnist_classifier/train.py CHANGED
@@ -10,6 +10,7 @@ import os
 import random
 import numpy as np
 from tqdm import tqdm
+from torch.optim.lr_scheduler import StepLR
 
 def set_seed(seed):
     torch.manual_seed(seed)
@@ -23,7 +24,7 @@ def set_seed(seed):
 def train():
     # Training loop
     learning_rate = 0.001
-    batch_size = 128
+    batch_size = 64
     epochs = 10
 
     # Set seed for reproducibility
@@ -34,7 +35,7 @@ def train():
     print(f"Using device: {device}")
 
     # Initialize tensorboard
-    log_dir = 'runs/mnist_experiment_' + datetime.now().strftime('%Y%m%d-%H%M%S')
+    log_dir = 'runs/mnist_experiment_' + f"lr{learning_rate}_bs{batch_size}_ep{epochs}_" + datetime.now().strftime('%Y%m%d-%H%M%S')
     writer = SummaryWriter(log_dir)
 
     # Setup data
@@ -44,6 +45,7 @@ def train():
     # Initialize model, optimizer, and loss function
     model = MNISTModel().to(device)
     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+    scheduler = StepLR(optimizer, step_size=2, gamma=0.5)  # Decay LR by a factor of 0.5 every 2 epochs
     criterion = nn.CrossEntropyLoss()
 
 
@@ -53,6 +55,7 @@ def train():
         running_loss = 0.0
         correct = 0
         total = 0
+        current_lr = optimizer.param_groups[0]['lr']  # Get current learning rate
 
         with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as pbar:
             for batch_idx, batch in enumerate(train_loader):
@@ -76,11 +79,13 @@ def train():
                 total += labels.size(0)
                 correct += predicted.eq(labels).sum().item()
 
+
                 # Update tqdm progress bar
                 pbar.set_postfix({
                     'loss': running_loss / (batch_idx + 1),
                     'accuracy': 100. * correct / total,
-                    'step': batch_idx + 1
+                    'step': batch_idx + 1,
+                    'lr': current_lr,
                 })
                 pbar.update(1)
 
@@ -93,6 +98,10 @@ def train():
                         epoch * len(train_loader) + batch_idx)
                     running_loss = 0.0
 
+        writer.add_scalar('learning rate', current_lr, epoch)
+
+        scheduler.step()  # Update the learning rate
+
         # Validation phase
         model.eval()
         test_loss = 0
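
For reference, StepLR(optimizer, step_size=2, gamma=0.5) halves the learning rate every two epochs, so the 10-epoch run above steps through 0.001, 0.0005, 0.00025, 0.000125, and 0.0000625. The following standalone sketch, which is not part of the commit and uses a dummy parameter in place of MNISTModel, prints that trajectory:

# Standalone sketch (not part of this commit): print the learning-rate
# trajectory produced by the scheduler added in train.py. A dummy parameter
# stands in for MNISTModel's parameters.
import torch
from torch import optim
from torch.optim.lr_scheduler import StepLR

dummy_param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([dummy_param], lr=0.001)
scheduler = StepLR(optimizer, step_size=2, gamma=0.5)  # halve the LR every 2 epochs

for epoch in range(10):
    print(f"epoch {epoch + 1}: lr = {optimizer.param_groups[0]['lr']:.7f}")
    # ... one epoch of training and validation would run here ...
    scheduler.step()  # stepped once per epoch, as in the diff above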
models/mnist_model_lr0.001_bs32_ep10.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d77abc86d2b18d3a3b4b7a560186ac1695d0c2e8e708028dd2b65211cefde6ae
+size 4803144
models/mnist_model_lr0.001_bs64_ep10.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d1b474acf8a447dea4e3aaaf0371346ee7a7055d1c716fb371c059b9a1799bab
+oid sha256:b7ed0e38b2b8663379c3655a73c1e6f7e4165bf7d9f792491c7cc9fa99e1e97f
 size 4803144
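
The two checkpoint entries above are Git LFS pointers, so the diff records only the object id and size (about 4.8 MB each); the actual weights live in LFS storage. A minimal loading sketch, assuming the .pth file holds a state_dict and that MNISTModel is importable as it is in train.py (the exact module path below is an assumption):

# Standalone sketch (not part of this commit): load the committed checkpoint
# for inference. Assumes the .pth file stores a state_dict and that MNISTModel
# can be imported from mnist_classifier.model (the module path is an assumption).
import torch
from mnist_classifier.model import MNISTModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MNISTModel().to(device)
state_dict = torch.load('models/mnist_model_lr0.001_bs64_ep10.pth', map_location=device)
model.load_state_dict(state_dict)
model.eval()  # disable dropout/batch-norm updates before running predictions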