File size: 1,167 Bytes
8ec10cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import math
from bisect import bisect_right
from torch.optim.lr_scheduler import _LRScheduler
# MultiStep learning rate scheduler with a linear warm-up phase
class WarmMultiStepLR(_LRScheduler):
    """Multi-step learning-rate decay preceded by a linear warm-up.

    While ``last_epoch < warmup_epochs`` the learning rate of each param
    group ramps linearly from ``base_lr / scale`` (at epoch 0) towards
    ``base_lr`` (reached at epoch ``warmup_epochs``). Afterwards it is
    ``base_lr * gamma ** k`` where ``k`` is the number of milestones that
    are less than or equal to ``last_epoch``.

    Args:
        optimizer: Wrapped optimizer, forwarded to ``_LRScheduler``.
        milestones: Sorted (non-decreasing) sequence of epoch indices at
            which the learning rate is multiplied by ``gamma``.
        gamma: Multiplicative decay factor applied at every milestone.
        last_epoch: Index of the last epoch, forwarded to ``_LRScheduler``.
        scale: Warm-up divisor; the first epoch runs at ``base_lr / scale``.
            With the default of 1 the warm-up phase is a no-op.
        warmup_epochs: Number of warm-up epochs. ``0`` disables warm-up.

    Raises:
        ValueError: If ``milestones`` is not sorted or ``warmup_epochs`` is
            negative.
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1,
                 scale=1, warmup_epochs=5):
        if list(milestones) != sorted(milestones):
            # Format the message explicitly; passing ``milestones`` as a
            # second exception argument would leave the ``{}`` unfilled.
            raise ValueError(
                'Milestones should be a list of increasing integers. '
                'Got {}'.format(milestones)
            )
        if warmup_epochs < 0:
            raise ValueError(
                'warmup_epochs should be non-negative. '
                'Got {}'.format(warmup_epochs)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.scale = scale
        self.warmup_epochs = warmup_epochs
        # Per-epoch increment of the warm-up multiplier. With no warm-up
        # epochs the value is never used, so avoid dividing by zero.
        if self.warmup_epochs:
            self.gradual = (self.scale - 1) / self.warmup_epochs
        else:
            self.gradual = 0.0
        # The base-class constructor performs an initial step that calls
        # get_lr(), so every attribute above must be set before this call.
        super(WarmMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for ``last_epoch``."""
        if self.last_epoch < self.warmup_epochs:
            # Linear ramp: multiplier grows from 1 to ``scale`` and is then
            # divided by ``scale``, i.e. from base_lr / scale up to base_lr.
            return [
                base_lr * (1 + self.last_epoch * self.gradual) / self.scale
                for base_lr in self.base_lrs
            ]
        # bisect_right counts the milestones <= last_epoch, which is the
        # number of decay steps already applied.
        decay = self.gamma ** bisect_right(self.milestones, self.last_epoch)
        return [base_lr * decay for base_lr in self.base_lrs]
|