import itertools

from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam, AdamW


class AdamWithClip(Adam):
    """Adam that clips the gradient norm of each parameter group before every update."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, max_norm=None, norm_type=2):
        super(AdamWithClip, self).__init__(params, lr, betas, eps, weight_decay, amsgrad)
        self.max_norm = max_norm
        self.norm_type = norm_type

    def step(self, closure=None):
        if self.max_norm is not None:
            # Clip each parameter group separately before the Adam update.
            for group in self.param_groups:
                clip_grad_norm_(group['params'], self.max_norm, self.norm_type)
        super(AdamWithClip, self).step(closure)


class AdamWWithClip(AdamW):
    """AdamW that clips the gradient norm of each parameter group before every update."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, max_norm=None, norm_type=2):
        super(AdamWWithClip, self).__init__(params, lr, betas, eps, weight_decay, amsgrad)
        self.max_norm = max_norm
        self.norm_type = norm_type

    def step(self, closure=None):
        if self.max_norm is not None:
            for group in self.param_groups:
                clip_grad_norm_(group['params'], self.max_norm, self.norm_type)
        super(AdamWWithClip, self).step(closure)


# class AdamWWithClipDev(AdamW):
#     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
#                  weight_decay=1e-2, amsgrad=False, clip_norm=None, norm_type=2):
#         super(AdamWWithClipDev, self).__init__(params, lr, betas, eps, weight_decay, amsgrad)
#         self.clip_norm = clip_norm
#         self.norm_type = norm_type
#
#     def step(self, closure=None):
#         if self.clip_norm is not None:
#             all_params = itertools.chain(*[x["params"] for x in self.param_groups])
#             clip_grad_norm_(all_params, self.clip_norm, self.norm_type)
#
#         super(AdamWWithClipDev, self).step(closure)


class AdamWWithBackboneClipDev(AdamW):
    """AdamW that clips the joint gradient norm of the backbone parameters only.

    Each parameter is expected to carry a boolean ``backbone_specific`` attribute
    set by the training framework.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False, clip_norm=None, norm_type=2):
        super(AdamWWithBackboneClipDev, self).__init__(params, lr, betas, eps, weight_decay, amsgrad)
        self.clip_norm = clip_norm
        self.norm_type = norm_type

    def step(self, closure=None):
        if self.clip_norm is not None:
            # Gather every backbone parameter group and clip them with a single global norm.
            all_params = itertools.chain(
                *[x["params"] for x in self.param_groups if x["params"][0].backbone_specific])
            clip_grad_norm_(all_params, self.clip_norm, self.norm_type)
        super(AdamWWithBackboneClipDev, self).step(closure)


class AdamWWithClipDev(AdamW):
    """AdamW that clips the gradient norm separately for the backbone, neck, decoder,
    and task-specific parameter splits.

    Each parameter group is assumed to be homogeneous, so the flags on its first
    parameter (``backbone_specific``, ``neck_specific``, ``decoder_specific``,
    ``task_specific``) decide which split the whole group belongs to.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False, clip_norm=None, norm_type=2):
        super(AdamWWithClipDev, self).__init__(params, lr, betas, eps, weight_decay, amsgrad)
        self.clip_norm = clip_norm
        self.norm_type = norm_type
        self._split_param_groups = None
        self.reset_split_param_groups()

    def reset_split_param_groups(self):
        if self.clip_norm is not None:
            backbone_param, neck_param, decoder_param, task_param = [], [], [], []
            for x in self.param_groups:
                if x["params"][0].backbone_specific:
                    backbone_param.append(x["params"])
                elif x["params"][0].neck_specific:
                    neck_param.append(x["params"])
                elif x["params"][0].decoder_specific:
                    decoder_param.append(x["params"])
                elif x["params"][0].task_specific:
                    task_param.append(x["params"])
            # Keep only the non-empty splits; step() clips each split independently.
            self._split_param_groups = [
                _g for _g in [backbone_param, neck_param, decoder_param, task_param] if len(_g) > 0]
            print(f">>> reset_split_param_groups, backbone_param: {len(backbone_param)}"
                  f", neck_param: {len(neck_param)}, decoder_param: {len(decoder_param)}"
                  f", task_param: {len(task_param)}")

    def step(self, closure=None):
        if self.clip_norm is not None:
            for _g in self._split_param_groups:
                all_params = itertools.chain(*_g)
                clip_grad_norm_(all_params, self.clip_norm, self.norm_type)
        super(AdamWWithClipDev, self).step(closure)
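

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). AdamWWithClipDev
# assumes every parameter carries the boolean flags backbone_specific,
# neck_specific, decoder_specific and task_specific, which would normally be
# set by the surrounding training framework; the toy model and manual flag
# assignment below are hypothetical and only illustrate that expected format.
if __name__ == "__main__":
    import torch

    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))

    # Tag the first linear layer as "backbone" and the second as "task" parameters.
    for i, p in enumerate(model.parameters()):
        p.backbone_specific = i < 2
        p.neck_specific = False
        p.decoder_specific = False
        p.task_specific = not p.backbone_specific

    # One group per parameter keeps every group homogeneous, as the optimizer expects.
    param_groups = [{"params": [p]} for p in model.parameters()]
    optimizer = AdamWWithClipDev(param_groups, lr=1e-4, clip_norm=1.0)

    loss = model(torch.randn(3, 4)).sum()
    loss.backward()
    optimizer.step()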