File size: 5,724 Bytes
32287b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import math
from pprint import pformat
from typing import Tuple, List, Dict, Union
import torch.nn
import infinity.utils.dist as dist
def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):
    """Set the learning rate and weight decay of every param group for iteration ``cur_it``.

    The learning rate is computed as a factor (relative to ``peak_lr``):

    * warmup (``cur_it < wp_it``): linear ramp from ``wp0`` to 1;
    * afterwards: decay towards ``wpe`` following ``sche_type``:
        - ``'cos'``   half-cycle cosine from 1 to ``wpe``;
        - ``'lin'``   hold 1 for the first 15% of the decay phase, then linear to ``wpe``;
        - ``'lin0'``  hold 1 for the first 5% of the decay phase, then linear to ``wpe``;
        - ``'lin00'`` linear from 1 to ``wpe`` with no plateau;
        - ``'lin<T>'`` (e.g. ``'lin0.5'``) two linear segments joined at progress ``T``;
        - ``'exp'``   hold 1 for the first 15%, then exponential decay to ``wpe``.

    The weight decay follows a half-cycle cosine from ``wd`` to ``wd_end`` over
    the whole run (warmup included).

    Args:
        sche_type: name of the post-warmup schedule (see above).
        optimizer: any object exposing ``param_groups`` (a list of dicts). Each
            group may carry an ``'lr_sc'`` / ``'wd_sc'`` multiplier (default 1).
        peak_lr: learning rate reached at the end of warmup.
        wd: initial weight decay.
        wd_end: final weight decay.
        cur_it: current iteration (0-based).
        wp_it: number of warmup iterations (rounded to an int).
        max_it: total number of iterations.
        wp0: LR factor at iteration 0.
        wpe: LR factor at the last iteration.

    Returns:
        ``(min_lr, max_lr, min_wd, max_wd)`` over all param groups. ``min_wd``
        only considers strictly positive weight decays. A value is ``-1`` when
        there was nothing to aggregate (e.g. no param groups).

    Raises:
        NotImplementedError: if warmup is over and ``sche_type`` is unknown.
    """
    wp_it = round(wp_it)

    if cur_it < wp_it:
        cur_lr = wp0 + (1 - wp0) * cur_it / wp_it
    else:
        # Progress through the decay phase, in [0, 1]. When warmup covers the
        # whole run there is no decay phase: stay at the peak (progress 0),
        # which is continuous with the end of the warmup ramp.
        decay_span = max_it - 1 - wp_it
        pasd = (cur_it - wp_it) / decay_span if decay_span != 0 else 0.0
        rest = 1 - pasd  # remaining fraction, in [1, 0]

        if sche_type == 'cos':
            cur_lr = wpe + (1 - wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd))
        elif sche_type == 'lin':
            T = 0.15
            max_rest = 1 - T
            if pasd < T:
                cur_lr = 1
            else:
                cur_lr = wpe + (1 - wpe) * rest / max_rest  # 1 to wpe
        elif sche_type == 'lin0':
            T = 0.05
            max_rest = 1 - T
            if pasd < T:
                cur_lr = 1
            else:
                cur_lr = wpe + (1 - wpe) * rest / max_rest
        elif sche_type == 'lin00':
            cur_lr = wpe + (1 - wpe) * rest
        elif sche_type.startswith('lin'):
            # 'lin<T>': the knee sits at progress T; its height is halfway
            # between 1 and the value a single straight line would have there.
            T = float(sche_type[3:])
            max_rest = 1 - T
            wpe_mid = wpe + (1 - wpe) * max_rest
            wpe_mid = (1 + wpe_mid) / 2
            if pasd < T:
                cur_lr = 1 + (wpe_mid - 1) * pasd / T
            else:
                cur_lr = wpe + (wpe_mid - wpe) * rest / max_rest
        elif sche_type == 'exp':
            T = 0.15
            max_rest = 1 - T
            if pasd < T:
                cur_lr = 1
            else:
                # exp(0) == 1 at the end of the plateau, exp(log(wpe)) == wpe at the end.
                expo = (pasd - T) / max_rest * math.log(wpe)
                cur_lr = math.exp(expo)
        else:
            raise NotImplementedError(f'unknown sche_type {sche_type}')

    cur_lr *= peak_lr

    # Weight decay anneals over the entire run; a single-iteration run keeps `wd`.
    total_span = max_it - 1
    wd_progress = cur_it / total_span if total_span != 0 else 0.0
    cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * wd_progress))

    # A true infinity as the "nothing seen yet" marker, so arbitrarily large
    # learning rates are still reported correctly.
    min_lr = min_wd = math.inf
    max_lr = max_wd = -1
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr * param_group.get('lr_sc', 1)  # 'lr_sc' could be assigned
        max_lr = max(max_lr, param_group['lr'])
        min_lr = min(min_lr, param_group['lr'])

        param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1)
        max_wd = max(max_wd, param_group['weight_decay'])
        if param_group['weight_decay'] > 0:
            min_wd = min(min_wd, param_group['weight_decay'])

    if min_lr == math.inf:
        min_lr = -1
    if min_wd == math.inf:
        min_wd = -1
    return min_lr, max_lr, min_wd, max_wd
def filter_params(model, ndim_dict, nowd_keys=(), lr_scale=0.0) -> Tuple[
    List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]
]:
    """Split the trainable parameters of ``model`` into optimizer param groups.

    Parameters are grouped by whether weight decay applies and, when layer-wise
    LR scaling is active, by the layer id reported by the model.

    * No weight decay (group ``'ND'``, ``wd_sc = 0``): parameters whose entry in
      ``ndim_dict`` is 1, whose name ends with ``'bias'``, or whose name
      contains any of ``nowd_keys``.
    * Weight decay (group ``'D'``, ``wd_sc = 1``): everything else.
    * Layer-wise LR scaling is enabled only if the model defines
      ``get_layer_id_and_scale_exp`` and ``0 < lr_scale <= 1``; then
      ``lr_sc = lr_scale ** scale_exp`` and the group name is prefixed with
      ``layer{layer_id}_``. Otherwise ``lr_sc = 1``.

    Args:
        model: object providing ``named_parameters()``.
        ndim_dict: mapping from parameter name to a dimensionality; names that
            are absent are treated as 2 (i.e. decayed unless another rule hits).
        nowd_keys: substrings marking parameters that must not be decayed.
        lr_scale: base of the layer-wise LR scale.

    Returns:
        ``(names, paras, groups)``: trainable parameter names, the matching
        parameters, and a list of group dicts with keys ``'params'``,
        ``'wd_sc'`` and ``'lr_sc'``.

    Raises:
        AssertionError: if any parameter has ``requires_grad == False``.
    """
    with_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0 < lr_scale <= 1
    print(f'[get_param_groups][lr decay] with_lr_scale={with_lr_scale}, lr_scale={lr_scale}')

    para_groups, para_groups_dbg = {}, {}
    names, paras = [], []
    names_no_grad = []
    count, numel = 0, 0
    for name, para in model.named_parameters():
        # Strip the FSDP wrapper prefix so names match the keys of `ndim_dict`.
        name = name.replace('_fsdp_wrapped_module.', '')
        if not para.requires_grad:
            names_no_grad.append(name)  # collected so the assert below can list them
            continue  # frozen weights
        count += 1
        numel += para.numel()
        names.append(name)
        paras.append(para)

        if ndim_dict.get(name, 2) == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
            cur_wd_sc, group_name = 0., 'ND'
        else:
            cur_wd_sc, group_name = 1., 'D'

        if with_lr_scale:
            layer_id, scale_exp = model.get_layer_id_and_scale_exp(name)
            group_name = f'layer{layer_id}_' + group_name
            cur_lr_sc = lr_scale ** scale_exp
            dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]'
        else:
            cur_lr_sc = 1.
            dbg = '[no scale]'

        if group_name not in para_groups:
            para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
            para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': dbg}
        para_groups[group_name]['params'].append(para)
        para_groups_dbg[group_name]['params'].append(name)

    # The debug view lists parameter names (not tensors) so it stays printable.
    for g in para_groups_dbg.values():
        g['params'] = pformat(', '.join(g['params']), width=200)
    print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n')

    # NOTE(review): `force=` is not a keyword of the builtin print(); presumably
    # the project installs a rank-aware print replacement — confirm.
    for rk in range(dist.get_world_size()):
        if dist.get_rank() == rk:
            print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)

    assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n'
    return names, paras, list(para_groups.values())
def plot():
    """Plot the learning-rate curve produced by ``lr_wd_annealing`` for a few schedules.

    A throwaway SGD optimizer over a tiny linear layer is used only as a
    carrier for ``param_groups``; the returned minimum LR of each iteration is
    recorded and drawn against the iteration index.
    """
    import matplotlib.pyplot as plt
    import torch.nn as nn
    from torch.optim import SGD

    iters = 500
    wp_it, max_it = 1 * iters, 10 * iters

    # Other schedules worth comparing: 'lin', 'lin00', 'lin0.5', 'lin0.75'.
    for sche in ('lin0', ):
        op = SGD(nn.Linear(3, 4).parameters(), lr=1e-3)
        it, lr = [], []
        for cur_it in range(max_it):
            it.append(cur_it)  # x-axis must have one entry per recorded LR
            lr.append(lr_wd_annealing(sche, op, 0.1, 1e-5, 1e-5, cur_it, wp_it, max_it, wpe=0.3)[0])
        plt.plot(it, lr, 'b', label=sche)

    plt.xlabel('it')
    plt.ylabel('lr')
    plt.legend()  # the curves are labelled, so show the labels
    plt.show()
if __name__ == '__main__':