from bisect import bisect_right
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.scheduler.step_lr import StepLRScheduler
from timm.scheduler.scheduler import Scheduler
import torch
import torch.distributed as dist
def build_scheduler(config, optimizer, n_iter_per_epoch):
    """Build a timm LR scheduler from the config; every variant counts time in
    optimizer updates (t_in_epochs=False) rather than in epochs."""
    num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
    warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
    decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
    multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS]

    lr_scheduler = None
    if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_steps,
            cycle_mul=1.,
            lr_min=config.TRAIN.MIN_LR,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            cycle_limit=1,
            t_in_epochs=False,
        )
    elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
        lr_scheduler = LinearLRScheduler(
            optimizer,
            t_initial=num_steps,
            lr_min_rate=0.01,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )
    elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=decay_steps,
            decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )
    elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep':
        lr_scheduler = MultiStepLRScheduler(
            optimizer,
            milestones=multi_steps,
            gamma=config.TRAIN.LR_SCHEDULER.GAMMA,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )
    return lr_scheduler
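# --- Illustrative usage sketch (not part of the original module) ---
# Assumes a yacs/CfgNode-style `config` with the TRAIN.* fields read above;
# `model` and `data_loader` are placeholder names for this example only.
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=config.TRAIN.BASE_LR)
#   lr_scheduler = build_scheduler(config, optimizer, len(data_loader))
#   for epoch in range(config.TRAIN.EPOCHS):
#       for idx, (samples, targets) in enumerate(data_loader):
#           ...  # forward / backward / optimizer.step()
#           # every scheduler above is built with t_in_epochs=False, so it is
#           # advanced once per optimizer update, not once per epoch
#           lr_scheduler.step_update(epoch * len(data_loader) + idx)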
class LinearLRScheduler(Scheduler):
    """Decay the lr linearly from each base lr to base_lr * lr_min_rate over
    t_initial steps, with an optional linear warmup from warmup_lr_init."""

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 lr_min_rate: float,
                 warmup_t=0,
                 warmup_lr_init=0.,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct,
            noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        self.t_initial = t_initial
        self.lr_min_rate = lr_min_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t
                                 for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            t = t - self.warmup_t
            total_t = self.t_initial - self.warmup_t
            lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t))
                   for v in self.base_values]
        return lrs

    def get_epoch_values(self, epoch: int):
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None
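# Illustrative arithmetic for the linear decay above (not part of the original
# file): with a base lr of 1e-3, lr_min_rate=0.01, warmup_t=0 and t_initial=100,
# _get_lr(50) returns 1e-3 - (1e-3 - 1e-5) * (50 / 100) = 5.05e-4, i.e. the lr
# falls linearly from 1e-3 towards 1e-5, which is reached at t = t_initial.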
class MultiStepLRScheduler(Scheduler):
    """Multiply the base lr by `gamma` once per milestone passed, with an
    optional linear warmup that must finish before the first milestone."""

    def __init__(self, optimizer: torch.optim.Optimizer,
                 milestones, gamma=0.1, warmup_t=0,
                 warmup_lr_init=0, t_in_epochs=True) -> None:
        super().__init__(optimizer, param_group_field="lr")

        self.milestones = milestones
        self.gamma = gamma
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t
                                 for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

        assert self.warmup_t <= min(self.milestones)

    def _get_lr(self, t):
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            lrs = [v * (self.gamma ** bisect_right(self.milestones, t))
                   for v in self.base_values]
        return lrs

    def get_epoch_values(self, epoch: int):
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None
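# Illustrative example (not part of the original file): with milestones=[30, 60],
# gamma=0.1 and a base lr of 1e-3, bisect_right yields lr=1e-3 for t < 30,
# 1e-4 for 30 <= t < 60 and 1e-5 for t >= 60; any warmup has to finish before
# the first milestone, which is what the assert in __init__ enforces.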
def setup_scaled_lr(config):
    # linearly scale the learning rate according to the total batch size
    # (per-GPU batch size * world size * accumulation steps); may not be optimal
    batch_size = config.DATA.BATCH_SIZE
    world_size = dist.get_world_size()
    denom_const = 512.0
    accumulation_steps = config.TRAIN.ACCUMULATION_STEPS

    linear_scaled_lr = config.TRAIN.BASE_LR * batch_size * world_size / denom_const
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * batch_size * world_size / denom_const
    linear_scaled_min_lr = config.TRAIN.MIN_LR * batch_size * world_size / denom_const

    # gradient accumulation also needs to scale the learning rate
    if accumulation_steps > 1:
        linear_scaled_lr = linear_scaled_lr * accumulation_steps
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * accumulation_steps
        linear_scaled_min_lr = linear_scaled_min_lr * accumulation_steps
    return linear_scaled_lr, linear_scaled_warmup_lr, linear_scaled_min_lr
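# Worked example with illustrative values (not taken from any actual config):
# BASE_LR=5e-4, DATA.BATCH_SIZE=128 per GPU, 8 GPUs and ACCUMULATION_STEPS=1
# give 5e-4 * 128 * 8 / 512.0 = 1e-3, i.e. the usual linear scaling rule
# normalised to a reference total batch size of 512.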