File size: 2,744 Bytes
345ee20 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import time
import numpy as np
import torch
from easydict import EasyDict as edict
import warnings
from torch._six import inf
from core.utils import sync_print
def clip_grad_norm_(parameters, max_norm=1000000, norm_type=2, auto_clipper=None, loss_scale=1.0):
    """Clip gradients in place by their global norm, with inf/nan detection.

    Args:
        parameters: a single tensor, or an iterable whose items are either
            ``(name, param)`` pairs (e.g. from ``module.named_parameters()``)
            or bare tensors. Items whose ``.grad`` is None are ignored.
        max_norm: maximum allowed gradient norm (measured after dividing by
            ``loss_scale``). Replaced by ``auto_clipper.update(...)`` when an
            ``auto_clipper`` is given.
        norm_type: order of the norm; ``float('inf')`` selects the max-abs norm.
        auto_clipper: optional object with an ``update(total_norm)`` method
            returning the max_norm to use for this step (e.g. ``ClipMeter``).
        loss_scale: factor the gradients are currently multiplied by; the norm
            is divided by it so the threshold applies to the unscaled ("real")
            gradient norm.

    Returns:
        None if no parameter has a gradient. Otherwise the total norm as a
        Python float:
          * if the norm is inf/nan on any rank (flag summed with
            ``all_reduce`` when a process group is initialized), every
            gradient is filled with nan, no clipping is done, and the raw
            (not loss-scale-divided) norm is returned;
          * otherwise the norm divided by ``loss_scale`` is returned, after
            the gradients were scaled down if it exceeded ``max_norm``.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    # Normalise every item to a (name, param) pair; bare tensors get name None.
    named = []
    for item in parameters:
        name, p = (None, item) if isinstance(item, torch.Tensor) else item
        if p.grad is not None:
            named.append((name, p))
    if not named:
        return None

    max_norm = float(max_norm)
    norm_type = float(norm_type)
    # Accumulate on the host as Python floats so np.isinf/np.isnan work below.
    if norm_type == float('inf'):
        total_norm = max(p.grad.data.abs().max().item() for _, p in named)
    else:
        total_norm = sum(p.grad.data.norm(norm_type).item() ** norm_type for _, p in named)
        total_norm = total_norm ** (1. / norm_type)

    # Overflow flag lives on the gradients' device so NCCL can reduce it too.
    overflow_num = torch.zeros(1, device=named[0][1].grad.device)
    if np.isinf(total_norm) or np.isnan(total_norm):
        overflow_num[0] = 1
    # Sum the flag over ranks so every rank takes the same branch below.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.all_reduce(overflow_num)
    if overflow_num.item() > 0:
        # NOTE(review): grads are poisoned with nan, presumably so a downstream
        # loss scaler / optimizer step notices the overflow — confirm.
        for _, p in named:
            p.grad.data.fill_(float('nan'))
        sync_print('total_norm is inf({})/nan({}), skip clipping!!!'.format(np.isinf(total_norm), np.isnan(total_norm)))
        return total_norm

    # Work with the unscaled ("real") norm from here on.
    total_norm /= loss_scale
    if auto_clipper is not None:
        max_norm = auto_clipper.update(total_norm)
    # The grads are still loss-scaled, but multiplying by max_norm / real_norm
    # brings their unscaled norm down to max_norm; 1e-6 guards against /0.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for _, p in named:
            p.grad.data.mul_(clip_coef)
    return total_norm
class ClipMeter(object):
    """Running (exponential moving average) norm tracker that yields a clip threshold.

    ``val`` starts at 1.0; with ``init=True`` the first observation passed to
    ``update`` replaces it. Each ``update(x)`` clamps ``x`` to at most
    ``val * (1 + thresh)`` (and, when ``min_max`` is set, at least
    ``val * (1 - thresh)``), then blends it in as
    ``val = mom * val + (1 - mom) * x``.

    The clip value is ``val`` itself when ``mean`` is true, otherwise
    ``val * (1 + thresh)``. ``mom`` and ``thresh`` default to None, but
    ``update`` needs both to be numbers.
    """

    def __init__(self, mom=None, thresh=None, min_max=False, mean=False, init=False):
        # Weight kept on the previous running value at each update.
        self.mom = mom
        # Relative half-width of the band x is clamped into.
        self.thresh = thresh
        # Also clamp from below (not only from above).
        self.min_max = min_max
        # Report the running value itself instead of the band's upper edge.
        self.mean = mean
        # Seed the running value with the first observation.
        self.init = init
        self.val = 1.0

    def get_mean(self):
        """Return the current running value."""
        return self.val

    def get_clip_val(self):
        """Return the clipping threshold derived from the running value."""
        running = self.get_mean()
        return running if self.mean else running * (1 + self.thresh)

    def update(self, x):
        """Fold observation ``x`` into the running value; return the new clip value."""
        if self.init:
            self.val, self.init = x, False
        center = self.get_mean()
        upper = center * (1 + self.thresh)
        if not self.min_max:
            x = min(x, upper)
        else:
            lower = center * (1 - self.thresh)
            x = max(min(x, upper), lower)
        self.val = self.mom * self.val + (1 - self.mom) * x
        return self.get_clip_val()
|