|
""" Plateau Scheduler |
|
|
|
Adapts PyTorch's ReduceLROnPlateau scheduler and adds optional LR warmup and noise.
|
|
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
import torch |
|
from typing import List |
|
|
|
from .scheduler import Scheduler |
|
|
|
|
|
class PlateauLRScheduler(Scheduler):
    """Decay the LR by a factor every time the validation metric plateaus.

    Wraps ``torch.optim.lr_scheduler.ReduceLROnPlateau`` and adds optional linear
    warmup plus optional LR noise (noise settings are handled by the base
    ``Scheduler``).

    Args:
        optimizer: Optimizer whose param groups' ``'lr'`` values are scheduled.
        decay_rate: Multiplicative decay factor (forwarded as ``factor``).
        patience_t: Forwarded as ``patience``.
        verbose: Forwarded as ``verbose`` only when truthy.
        threshold: Forwarded as ``threshold``.
        cooldown_t: Forwarded as ``cooldown``.
        warmup_t: Number of warmup epochs; 0 disables warmup.
        warmup_lr_init: LR every group starts warmup from.
        lr_min: Forwarded as ``min_lr``.
        mode: Forwarded as ``mode`` (default ``'max'``, i.e. higher metric is better).
        noise_range_t, noise_type, noise_pct, noise_std, noise_seed, initialize:
            Forwarded unchanged to the base ``Scheduler``.
    """

    def __init__(
            self,
            optimizer,
            decay_rate=0.1,
            patience_t=10,
            verbose=True,
            threshold=1e-4,
            cooldown_t=0,
            warmup_t=0,
            warmup_lr_init=0,
            lr_min=0,
            mode='max',
            noise_range_t=None,
            noise_type='normal',
            noise_pct=0.67,
            noise_std=1.0,
            noise_seed=None,
            initialize=True,
    ):
        super().__init__(
            optimizer,
            'lr',
            noise_range_t=noise_range_t,
            noise_type=noise_type,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        plateau_kwargs = dict(
            patience=patience_t,
            factor=decay_rate,
            threshold=threshold,
            cooldown=cooldown_t,
            mode=mode,
            min_lr=lr_min,
        )
        if verbose:
            # `verbose` is deprecated in recent PyTorch releases; forwarding it only when
            # requested keeps verbose=False off that path (False was the old default anyway).
            plateau_kwargs['verbose'] = verbose
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, **plateau_kwargs)

        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        if self.warmup_t:
            # Per-epoch linear increment taking each group from warmup_lr_init to its base LR.
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            # Placeholder only; step() never uses it when warmup is disabled.
            self.warmup_steps = [1 for _ in self.base_values]
        # Clean (pre-noise) LRs from the last noise perturbation; undone on the next step().
        self.restore_lr = None

    def state_dict(self):
        """Return resumable scheduler state.

        In addition to ``best``/``last_epoch`` this stores the plateau counters so a
        resumed run continues its patience/cooldown count, and any pending pre-noise
        LRs so the next ``step`` can undo noise that was active when saving (relevant
        when the optimizer state, which holds the perturbed LRs, is restored too).
        """
        return {
            'best': self.lr_scheduler.best,
            'last_epoch': self.lr_scheduler.last_epoch,
            'num_bad_epochs': self.lr_scheduler.num_bad_epochs,
            'cooldown_counter': self.lr_scheduler.cooldown_counter,
            'restore_lr': self.restore_lr,
        }

    def load_state_dict(self, state_dict):
        """Restore state produced by ``state_dict``.

        Only ``'best'`` is required; every other key is optional so checkpoints
        written by older versions (``best`` and maybe ``last_epoch``) still load.
        """
        self.lr_scheduler.best = state_dict['best']
        for key in ('last_epoch', 'num_bad_epochs', 'cooldown_counter'):
            if key in state_dict:
                setattr(self.lr_scheduler, key, state_dict[key])
        if 'restore_lr' in state_dict:
            self.restore_lr = state_dict['restore_lr']

    def step(self, epoch, metric=None):
        """Advance the schedule to ``epoch``.

        With warmup enabled and ``epoch <= warmup_t``, every group's LR is set
        linearly between ``warmup_lr_init`` and its base LR. Otherwise:

        1. Any noise applied on the previous step is undone.
        2. The wrapped ``ReduceLROnPlateau`` is stepped with ``metric``. If
           ``metric`` is None, only its epoch counter is synced.
        3. New noise is applied if the base scheduler's noise schedule calls for it.
        """
        if self.warmup_t and epoch <= self.warmup_t:
            lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
            super().update_groups(lrs)
            return

        if self.restore_lr is not None:
            # Put back the clean LR so plateau decay is computed from it, not the noisy one.
            for param_group, clean_lr in zip(self.optimizer.param_groups, self.restore_lr):
                param_group['lr'] = clean_lr
            self.restore_lr = None

        if metric is None:
            # Nothing to compare against (e.g. a resume-time step(start_epoch) call);
            # keep the epoch counter in sync instead of crashing on float(None).
            self.lr_scheduler.last_epoch = epoch
        else:
            # Same effect as step(metric, epoch) (last_epoch ends up == epoch) without
            # PyTorch's deprecated `epoch` argument and its per-call warning.
            self.lr_scheduler.last_epoch = epoch - 1
            self.lr_scheduler.step(metric)

        if self._is_apply_noise(epoch):
            self._apply_noise(epoch)

    def step_update(self, num_updates: int, metric: float = None):
        """No-op: this scheduler only steps per epoch via ``step``."""
        return None

    def _apply_noise(self, epoch):
        """Scale every group's LR by ``(1 + noise)`` and remember the clean LRs."""
        noise = self._calculate_noise(epoch)

        # Record the clean LRs so the next step() can undo this perturbation before
        # the wrapped plateau scheduler acts on them.
        restore_lr = []
        for param_group in self.optimizer.param_groups:
            old_lr = float(param_group['lr'])
            restore_lr.append(old_lr)
            param_group['lr'] = old_lr + old_lr * noise
        self.restore_lr = restore_lr

    def _get_lr(self, t: int) -> List[float]:
        # Explicit raise rather than `assert False` so the guard survives `python -O`.
        raise AssertionError('should not be called as step is overridden')
|
|