# ForkedHulk2/core/optimizers/adam_clip.py
import itertools
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam, AdamW


class AdamWithClip(Adam):
    """Adam that clips the gradient norm of each parameter group before every step."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, max_norm=None, norm_type=2):
        super().__init__(params, lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay, amsgrad=amsgrad)
        self.max_norm = max_norm
        self.norm_type = norm_type

    def step(self, closure=None):
        # Clip every parameter group before delegating to the stock Adam update.
        if self.max_norm is not None:
            for group in self.param_groups:
                clip_grad_norm_(group['params'], self.max_norm, self.norm_type)
        return super().step(closure)


class AdamWWithClip(AdamW):
    """AdamW that clips the gradient norm of each parameter group before every step."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, max_norm=None, norm_type=2):
        super().__init__(params, lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay, amsgrad=amsgrad)
        self.max_norm = max_norm
        self.norm_type = norm_type

    def step(self, closure=None):
        # Clip every parameter group before delegating to the stock AdamW update.
        if self.max_norm is not None:
            for group in self.param_groups:
                clip_grad_norm_(group['params'], self.max_norm, self.norm_type)
        return super().step(closure)
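

# Usage sketch (illustrative, hedged): both wrappers above behave like the
# stock Adam / AdamW optimizers and additionally clip the gradient norm of
# every parameter group before each update when ``max_norm`` is given. The
# names ``model``, ``criterion``, ``inputs`` and ``targets`` are hypothetical
# placeholders:
#
#     optimizer = AdamWWithClip(model.parameters(), lr=1e-4,
#                               weight_decay=1e-2, max_norm=1.0, norm_type=2)
#     loss = criterion(model(inputs), targets)
#     loss.backward()
#     optimizer.step()
#     optimizer.zero_grad()
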
# class AdamWWithClipDev(AdamW):
# def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
# weight_decay=1e-2, amsgrad=False, clip_norm=None, norm_type=2):
# super(AdamWWithClipDev, self).__init__(params, lr, betas, eps, weight_decay, amsgrad)
# self.clip_norm = clip_norm
# self.norm_type = norm_type
#
# def step(self, closure=None):
# if self.clip_norm is not None:
# all_params = itertools.chain(*[x["params"] for x in self.param_groups])
# clip_grad_norm_(all_params, self.clip_norm, self.norm_type)
#
# super(AdamWWithClipDev, self).step(closure)


class AdamWWithBackboneClipDev(AdamW):
    """AdamW that clips the gradient norm of the backbone parameters only; every
    parameter is expected to carry a boolean ``backbone_specific`` attribute."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False, clip_norm=None, norm_type=2):
        super().__init__(params, lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay, amsgrad=amsgrad)
        self.clip_norm = clip_norm
        self.norm_type = norm_type

    def step(self, closure=None):
        if self.clip_norm is not None:
            # Clip all backbone-specific parameter groups jointly.
            all_params = itertools.chain(
                *[g["params"] for g in self.param_groups if g["params"][0].backbone_specific])
            clip_grad_norm_(all_params, self.clip_norm, self.norm_type)
        return super().step(closure)


class AdamWWithClipDev(AdamW):
    """AdamW that clips the gradient norm of each model component separately.

    Every parameter is expected to carry boolean attributes
    ``backbone_specific``, ``neck_specific``, ``decoder_specific`` and
    ``task_specific``; gradients are clipped per component before each step.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False, clip_norm=None, norm_type=2):
        super().__init__(params, lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay, amsgrad=amsgrad)
        self.clip_norm = clip_norm
        self.norm_type = norm_type
        self._split_param_groups = None
        self.reset_split_param_groups()

    def reset_split_param_groups(self):
        """Bucket the parameter groups by component, keyed off the flags of each
        group's first parameter."""
        if self.clip_norm is not None:
            backbone_param, neck_param, decoder_param, task_param = [], [], [], []
            for group in self.param_groups:
                if group["params"][0].backbone_specific:
                    backbone_param.append(group["params"])
                elif group["params"][0].neck_specific:
                    neck_param.append(group["params"])
                elif group["params"][0].decoder_specific:
                    decoder_param.append(group["params"])
                elif group["params"][0].task_specific:
                    task_param.append(group["params"])
            self._split_param_groups = [g for g in (backbone_param, neck_param,
                                                    decoder_param, task_param) if len(g) > 0]
            print(f">>> reset_split_param_groups, backbone_param: {len(backbone_param)}"
                  f", neck_param: {len(neck_param)}, decoder_param: {len(decoder_param)}"
                  f", task_param: {len(task_param)}")

    def step(self, closure=None):
        if self.clip_norm is not None:
            # Clip each component's parameters jointly, but independently of the others.
            for component_groups in self._split_param_groups:
                clip_grad_norm_(itertools.chain(*component_groups),
                                self.clip_norm, self.norm_type)
        return super().step(closure)
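

if __name__ == "__main__":
    # Hedged usage sketch: the tiny model and component split below are
    # illustrative placeholders only. AdamWWithClipDev requires every parameter
    # to expose the boolean flags backbone_specific / neck_specific /
    # decoder_specific / task_specific, and each param group to hold parameters
    # of a single component, because only group["params"][0] is inspected.
    import torch
    import torch.nn as nn

    model = nn.ModuleDict({
        "backbone": nn.Linear(8, 8),
        "decoder": nn.Linear(8, 2),
    })
    for name, p in model.named_parameters():
        p.backbone_specific = name.startswith("backbone")
        p.neck_specific = False
        p.decoder_specific = name.startswith("decoder")
        p.task_specific = False

    # One parameter per group keeps each group homogeneous.
    param_groups = [{"params": [p]} for p in model.parameters()]
    optimizer = AdamWWithClipDev(param_groups, lr=1e-3, clip_norm=1.0)

    loss = model["decoder"](model["backbone"](torch.randn(4, 8))).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()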