# pllava-7b-demo / utils / optimizer.py
# cathyxl
# added
# f239efc
""" Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2020 Ross Wightman
"""
import re
import torch
from torch import optim as optim
from utils.distributed import is_main_process
import logging
logger = logging.getLogger(__name__)
try:
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
has_apex = True
except ImportError:
has_apex = False
def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
    """Pair every trainable parameter of ``model`` with its weight decay.

    Parameters whose ``requires_grad`` is false are left out.  A parameter is
    given a decay of ``0`` when ``filter_bias_and_bn`` is set and it is either
    one-dimensional or named ``*.bias``, or when its name appears in
    ``no_decay_list``.  Every other parameter gets ``weight_decay``.

    Args:
        model: object exposing ``named_parameters()``.
        weight_decay: decay assigned to parameters that are not exempt.
        no_decay_list: container of parameter names exempt from decay.
        filter_bias_and_bn: whether 1-D and ``.bias`` parameters are exempt.

    Returns:
        List([name, param, weight_decay]), in ``named_parameters()`` order.
    """

    def decay_for(name, param):
        # Same precedence as the rule described above: bias/1-D filter first,
        # then the explicit exemption list, then the default decay.
        if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")):
            return 0
        if name in no_decay_list:
            return 0
        return weight_decay

    return [
        [name, param, decay_for(name, param)]
        for name, param in model.named_parameters()
        if param.requires_grad  # skip frozen weights
    ]
def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr):
    """Attach a learning rate to every ``[name, param, weight_decay]`` entry.

    A parameter whose name is matched (via ``re.search``) by any pattern in
    ``diff_lr_names`` is assigned ``diff_lr``; all others get ``default_lr``.

    Args:
        named_param_tuples_or_model: List([name, param, weight_decay]).
            NOTE(review): the parameter name suggests an nn.Module is also
            accepted, but the body only unpacks 3-item entries -- confirm
            against callers.
        diff_lr_names: List(str), each used as a regular-expression pattern.
        diff_lr: float
        default_lr: float

    Returns:
        named_param_tuples_with_lr: List([name, param, weight_decay, lr])
    """
    logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")

    named_param_tuples_with_lr = []
    for name, param, wd in named_param_tuples_or_model:
        # any() stops at the first matching pattern, so a match is logged once.
        matched = any(
            re.search(pattern, name) is not None for pattern in diff_lr_names
        )
        if matched:
            logger.info(f"param {name} use different_lr: {diff_lr}")
        lr = diff_lr if matched else default_lr
        named_param_tuples_with_lr.append([name, param, wd, lr])

    # The per-parameter summary is only written by the main process.
    if is_main_process():
        for name, _, wd, lr in named_param_tuples_with_lr:
            logger.info(f"param {name}: wd: {wd}, lr: {lr}")

    return named_param_tuples_with_lr
def create_optimizer_params_group(named_param_tuples_with_lr):
    """Bucket parameters into optimizer param groups by weight decay and lr.

    Args:
        named_param_tuples_with_lr: List([name, param, weight_decay, lr])

    Returns:
        List of dicts with keys ``params``, ``weight_decay`` and ``lr``, one
        per distinct (weight_decay, lr) combination.  Groups are ordered by
        the first appearance of each weight decay, then by the first
        appearance of each lr within that weight decay.
    """
    # weight_decay -> lr -> [param, ...]; nesting keeps the group order stable.
    buckets = {}
    for _name, param, wd, lr in named_param_tuples_with_lr:
        buckets.setdefault(wd, {}).setdefault(lr, []).append(param)

    optimizer_params_group = []
    for wd, params_by_lr in buckets.items():
        for lr, members in params_by_lr.items():
            optimizer_params_group.append(
                {"params": members, "weight_decay": wd, "lr": lr}
            )
            logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(members)}")
    return optimizer_params_group
def create_optimizer(args, model, filter_bias_and_bn=True):
    """Build a ``torch.optim`` optimizer for ``model`` from ``args``.

    The trainable parameters are split into groups that share the same
    weight decay and learning rate (see ``add_weight_decay``,
    ``add_different_lr`` and ``create_optimizer_params_group``), and the
    optimizer named by ``args.opt`` is created over those groups.

    Args:
        args: config object.  Reads ``opt``, ``weight_decay`` and ``lr``;
            ``momentum`` for the SGD variants; and, when present and not
            None, ``opt_eps``, ``opt_betas`` and ``opt_args`` (a mapping of
            extra keyword arguments).  If ``different_lr`` is present and
            ``different_lr.enable`` is truthy, ``different_lr.module_names``
            and ``different_lr.lr`` select parameters that get their own lr.
        model: model whose parameters are optimized.  If it defines
            ``no_weight_decay()``, the returned names are exempt from decay.
        filter_bias_and_bn: forwarded to ``add_weight_decay``.

    Returns:
        The constructed optimizer (``optim.SGD``, ``optim.Adam`` or
        ``optim.AdamW``).

    Raises:
        ValueError: if the part of ``args.opt`` after the last ``_`` is not
            one of ``sgd``, ``nesterov``, ``momentum``, ``adam``, ``adamw``.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay

    # Only the token after the last "_" selects the optimizer class.  Validate
    # it up front with a real exception: an `assert` would be stripped under
    # `python -O` and carried no message here.
    supported = ('sgd', 'nesterov', 'momentum', 'adam', 'adamw')
    opt_name = opt_lower.split('_')[-1]
    if opt_name not in supported:
        raise ValueError(
            f"Invalid optimizer: {args.opt!r} (expected one of {supported})"
        )

    # check for modules that requires different lr
    if hasattr(args, "different_lr") and args.different_lr.enable:
        diff_lr_module_names = args.different_lr.module_names
        diff_lr = args.different_lr.lr
    else:
        diff_lr_module_names = []
        diff_lr = None

    no_decay = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay = model.no_weight_decay()

    named_param_tuples = add_weight_decay(
        model, weight_decay, no_decay, filter_bias_and_bn)
    named_param_tuples = add_different_lr(
        named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
    parameters = create_optimizer_params_group(named_param_tuples)

    # NOTE: this only checks the environment; every branch below builds a
    # plain torch.optim optimizer, not an apex fused one.
    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    # Every param group already carries its own lr / weight_decay; the values
    # here are the optimizer-level defaults.
    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    for key, attr in (('eps', 'opt_eps'), ('betas', 'opt_betas')):
        value = getattr(args, attr, None)
        if value is not None:
            opt_args[key] = value
    extra_args = getattr(args, 'opt_args', None)
    if extra_args is not None:
        opt_args.update(extra_args)

    if opt_name in ('sgd', 'nesterov', 'momentum'):
        # optim.SGD accepts neither `eps` nor `betas`; drop them so that
        # Adam-style settings left in the config do not raise a TypeError.
        opt_args.pop('eps', None)
        opt_args.pop('betas', None)
        # 'sgd' and 'nesterov' both enable Nesterov momentum; 'momentum' does not.
        return optim.SGD(
            parameters,
            momentum=args.momentum,
            nesterov=(opt_name != 'momentum'),
            **opt_args,
        )

    optimizer_cls = optim.Adam if opt_name == 'adam' else optim.AdamW
    return optimizer_cls(parameters, **opt_args)