File size: 5,612 Bytes
5d21dd2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import math
from collections import Counter
from collections import defaultdict
import torch
from torch.optim.lr_scheduler import _LRScheduler
class MultiStepLR_Restart(_LRScheduler):
    """Multi-step learning-rate decay with optional warm restarts.

    At each epoch listed in ``milestones`` the current learning rate of every
    parameter group is multiplied by ``gamma`` once per occurrence of that
    epoch in ``milestones`` (duplicates compound). At each epoch listed in
    ``restarts`` the learning rate is reset to the group's ``initial_lr``
    multiplied by the matching entry of ``weights``.

    Args:
        optimizer: The wrapped optimizer.
        milestones: Epoch indices at which to decay the learning rate.
        restarts: Epoch indices at which to restart. ``None``/empty means ``[0]``.
        weights: Per-restart multipliers applied to ``initial_lr``.
            ``None``/empty means a weight of 1 for every restart.
        gamma: Multiplicative decay factor applied at each milestone.
        clear_state: If true, ``optimizer.state`` is replaced by an empty
            ``defaultdict(dict)`` at every restart.
        last_epoch: Index of the last epoch; ``-1`` starts a fresh schedule.

    Raises:
        AssertionError: If ``restarts`` and ``weights`` differ in length.
    """

    def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
                 clear_state=False, last_epoch=-1):
        # All attributes must exist before the base initializer runs, because it
        # performs an initial step that calls get_lr().
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.clear_state = clear_state
        self.restarts = restarts if restarts else [0]
        # Without explicit weights, every restart returns to the full initial LR.
        self.restart_weights = weights if weights else [1] * len(self.restarts)
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of each param group for ``self.last_epoch``.

        A restart takes precedence over a milestone falling on the same epoch.
        """
        if self.last_epoch in self.restarts:
            if self.clear_state:
                self.optimizer.state = defaultdict(dict)
            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
        if self.last_epoch not in self.milestones:
            # Not a milestone: keep the current learning rates unchanged.
            return [group['lr'] for group in self.optimizer.param_groups]
        return [
            group['lr'] * self.gamma**self.milestones[self.last_epoch]
            for group in self.optimizer.param_groups
        ]
class CosineAnnealingLR_Restart(_LRScheduler):
    """Cosine annealing of the learning rate with warm restarts.

    Within a cycle the learning rate follows a cosine curve from its value at
    the start of the cycle down towards ``eta_min`` over ``T_max`` epochs,
    using the same recursive update as
    ``torch.optim.lr_scheduler.CosineAnnealingLR``. At each epoch in
    ``restarts`` the learning rate jumps back to ``initial_lr * weight`` and a
    new cycle begins.

    Args:
        optimizer: The wrapped optimizer.
        T_period: Cycle lengths in epochs. ``T_period[0]`` is the first cycle;
            the restart at ``restarts[i]`` switches to ``T_period[i + 1]``.
        restarts: Epoch indices at which to restart. ``None``/empty means ``[0]``.
        weights: Per-restart multipliers applied to ``initial_lr``.
            ``None``/empty means a weight of 1 for every restart.
        eta_min: Lower bound of the annealed learning rate.
        last_epoch: Index of the last epoch; ``-1`` starts a fresh schedule.

    Raises:
        AssertionError: If ``restarts`` and ``weights`` differ in length.
        IndexError: From ``step()``, if a restart is reached for which
            ``T_period`` has no following period.
    """

    def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
        # All attributes must exist before the base initializer runs, because it
        # performs an initial step that calls get_lr().
        self.T_period = T_period
        self.T_max = self.T_period[0]  # length of the current cycle
        self.eta_min = eta_min
        self.restarts = restarts if restarts else [0]
        # Without explicit weights, every restart returns to the full initial LR.
        self.restart_weights = weights if weights else [1] * len(self.restarts)
        self.last_restart = 0  # epoch at which the current cycle began
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of each param group for ``self.last_epoch``.

        Note: reaching a restart epoch updates ``last_restart`` and ``T_max``.
        """
        if self.last_epoch == 0:
            return self.base_lrs
        if self.last_epoch in self.restarts:
            idx = self.restarts.index(self.last_epoch)
            if idx + 1 >= len(self.T_period):
                # Fail before mutating any state, with an actionable message.
                raise IndexError(
                    'T_period has no period for the restart at epoch {}; it needs at '
                    'least {} entries.'.format(self.last_epoch, idx + 2))
            self.last_restart = self.last_epoch
            self.T_max = self.T_period[idx + 1]
            weight = self.restart_weights[idx]
            return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
        t = self.last_epoch - self.last_restart  # epochs elapsed in the current cycle
        if (t - 1 - self.T_max) % (2 * self.T_max) == 0:
            # Here cos(pi * (t - 1) / T_max) == -1, which would zero the denominator
            # of the recursive ratio below, so use the additive form instead.
            return [
                group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # Recursive cosine step: rescale the distance to eta_min by the ratio of
        # consecutive points on the cosine curve.
        return [(1 + math.cos(math.pi * t / self.T_max)) /
                (1 + math.cos(math.pi * (t - 1) / self.T_max)) *
                (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]
if __name__ == "__main__":
    # Demo: build a scheduler, step it through N_iter iterations and plot the
    # resulting learning-rate curve.
    optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
                                 betas=(0.9, 0.99))

    # NOTE: the presets within each section are alternatives. Each assignment
    # overwrites the previous one, so only the last preset of a section is used.

    # MultiStepLR_Restart
    ## Original
    lr_steps = [200000, 400000, 600000, 800000]
    restarts = None
    restart_weights = None
    ## two
    lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
    restarts = [500000]
    restart_weights = [1]
    ## four
    lr_steps = [
        50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
        600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
    ]
    restarts = [250000, 500000, 750000]
    restart_weights = [1, 1, 1]
    scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
                                    clear_state=False)

    # Cosine Annealing Restart. This rebinds `scheduler`, so the multi-step
    # scheduler above is not plotted; remove this section to plot it instead.
    ## two
    T_period = [500000, 500000]
    restarts = [500000]
    restart_weights = [1]
    ## four
    T_period = [250000, 250000, 250000, 250000]
    restarts = [250000, 500000, 750000]
    restart_weights = [1, 1, 1]
    scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
                                          weights=restart_weights)

    # Draw figure
    N_iter = 1000000
    lr_l = []
    for _ in range(N_iter):
        # Record the LR in effect for this iteration, then advance the schedule.
        # Without the step() call the curve would stay flat.
        lr_l.append(optimizer.param_groups[0]['lr'])
        scheduler.step()

    import matplotlib as mpl
    from matplotlib import pyplot as plt
    import matplotlib.ticker as mtick
    mpl.style.use('default')
    import seaborn  # NOTE(review): not referenced below; kept in case its import-time styling is relied on

    plt.title('Title', fontsize=16, color='k')
    plt.plot(range(N_iter), lr_l, linewidth=1.5, label='learning rate scheme')
    plt.legend(loc='upper right', shadow=False)
    ax = plt.gca()
    # Label x ticks in thousands of iterations (e.g. "250K"). A formatter keeps
    # the labels correct for whatever ticks matplotlib picks (also on zoom).
    ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda v, _pos: str(int(v / 1000)) + 'K'))
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Learning rate')
    plt.show()