import torch
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


class BaseScheduler(object):
    """Base class for the step-wise scheduler logic.

    Args:
        optimizer (Optimizer): Optimizer instance to apply lr schedule on.

    Subclass this and overwrite ``_get_lr`` to write your own step-wise scheduler.
    """

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.step_num = 0

    def zero_grad(self):
        self.optimizer.zero_grad()

    def _get_lr(self):
        raise NotImplementedError

    def _set_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def step(self, metrics=None, epoch=None):
        """Update step-wise learning rate before optimizer.step."""
        self.step_num += 1
        lr = self._get_lr()
        self._set_lr(lr)

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def state_dict(self):
        return {key: value for key, value in self.__dict__.items() if key != "optimizer"}

    def as_tensor(self, start=0, stop=100_000):
        """Returns the scheduler values from start to stop."""
        lr_list = []
        for _ in range(start, stop):
            self.step_num += 1
            lr_list.append(self._get_lr())
        self.step_num = 0
        return torch.tensor(lr_list)

    def plot(self, start=0, stop=100_000):  # noqa
        """Plot the scheduler values from start to stop."""
        import matplotlib.pyplot as plt

        all_lr = self.as_tensor(start=start, stop=stop)
        plt.plot(all_lr.numpy())
        plt.show()


class DPTNetScheduler(BaseScheduler):
    """Dual Path Transformer Scheduler used in [1].

    Args:
        optimizer (Optimizer): Optimizer instance to apply lr schedule on.
        steps_per_epoch (int): Number of steps per epoch.
        d_model (int): The number of units in the layer output.
        warmup_steps (int): The number of steps in the warmup stage of training.
        noam_scale (float): Linear increase rate in the first phase.
        exp_max (float): Max learning rate in the second phase.
        exp_base (float): Exp learning rate base in the second phase.

    Schedule:
        This scheduler increases the learning rate linearly for the first
        ``warmup_steps``, then decays it by ``exp_base`` (0.98 by default)
        every two epochs.

    References
        [1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct
        Context-Aware Modeling for End-to-End Monaural Speech Separation",
        Interspeech 2020.
    """

    def __init__(
        self,
        optimizer,
        steps_per_epoch,
        d_model,
        warmup_steps=4000,
        noam_scale=1.0,
        exp_max=0.0004,
        exp_base=0.98,
    ):
        super().__init__(optimizer)
        self.noam_scale = noam_scale
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.exp_max = exp_max
        self.exp_base = exp_base
        self.steps_per_epoch = steps_per_epoch
        self.epoch = 0

    def _get_lr(self):
        if self.step_num % self.steps_per_epoch == 0:
            self.epoch += 1

        if self.step_num > self.warmup_steps:
            # Exponential decay phase: decay ``exp_max`` by ``exp_base`` every two epochs.
            lr = self.exp_max * (self.exp_base ** ((self.epoch - 1) // 2))
        else:
            # Noam warmup phase: linear increase scaled by the inverse square root of d_model.
            lr = (
                self.noam_scale
                * self.d_model ** (-0.5)
                * min(self.step_num ** (-0.5), self.step_num * self.warmup_steps ** (-1.5))
            )
        return lr


class CustomExponentialLR(_LRScheduler):
    """Epoch-wise step decay: multiply the learning rate by ``gamma`` every ``step_size`` epochs."""

    def __init__(self, optimizer, gamma, step_size, last_epoch=-1):
        self.gamma = gamma
        self.step_size = step_size
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Keep the current lr except on epochs that hit the ``step_size`` boundary.
        if self.last_epoch == 0 or (self.last_epoch + 1) % self.step_size != 0:
            return [group["lr"] for group in self.optimizer.param_groups]
        # Decay the *current* lr (not the base lr) so successive decays compound.
        return [group["lr"] * self.gamma for group in self.optimizer.param_groups]


# Backward compat
_BaseScheduler = BaseScheduler
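

# A minimal usage sketch (illustrative only, not part of the library): the model,
# data and hyper-parameters below are hypothetical placeholders. It shows how a
# step-wise scheduler like DPTNetScheduler is driven alongside the optimizer,
# calling ``scheduler.step()`` before ``optimizer.step()`` on every batch.
if __name__ == "__main__":
    model = torch.nn.Linear(16, 16)  # placeholder model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = DPTNetScheduler(
        optimizer, steps_per_epoch=100, d_model=64, warmup_steps=4000
    )
    for _ in range(10):
        scheduler.zero_grad()  # forwards to optimizer.zero_grad()
        loss = model(torch.randn(8, 16)).pow(2).mean()  # dummy loss
        loss.backward()
        scheduler.step()  # update the step-wise learning rate
        optimizer.step()
    print(scheduler.state_dict())  # serializable state, excluding the optimizer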