Spaces:
Sleeping
Sleeping
class LossScaler:
    """Static loss scaler: applies one fixed scale and never reports overflow.

    The overflow-related methods are no-ops that mirror the method names of
    ``DynamicLossScaler`` so either class can be used by the same caller.
    """

    def __init__(self, scale=1):
        """Store ``scale`` as the constant multiplier returned by ``loss_scale``."""
        self.cur_scale = scale

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        """Always return False; ``params`` is not inspected."""
        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        """Always return False; ``x`` is not inspected."""
        return False

    # `overflow` is boolean indicating whether we overflowed in gradient
    def update_scale(self, overflow):
        """Do nothing: the scale stays fixed regardless of ``overflow``."""
        pass

    def loss_scale(self):
        """Return the current scale factor."""
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        """Return a tuple with every entry of ``grad_in`` multiplied by the scale.

        ``module`` and ``grad_out`` are unused.
        NOTE(review): the signature looks like a torch module backward hook --
        confirm against the caller.
        """
        # `loss_scale` is a method, so it must be called; multiplying by the
        # bound method itself raises TypeError.
        scale = self.loss_scale()
        return tuple(scale * g for g in grad_in)

    def backward(self, loss):
        """Multiply ``loss`` by the scale and call ``backward()`` on the result."""
        scaled_loss = loss * self.loss_scale()
        scaled_loss.backward()
class DynamicLossScaler:
    """Loss scaler whose scale shrinks on overflow and grows after quiet periods.

    ``update_scale`` divides the scale by ``scale_factor`` (never going below
    1) when an overflow is reported, and multiplies it by ``scale_factor``
    whenever the number of iterations since the last overflow is a multiple
    of ``scale_window``.
    """

    def __init__(self,
                 init_scale=2**32,
                 scale_factor=2.,
                 scale_window=1000):
        """Set the starting scale and the adjustment parameters.

        Args:
            init_scale: initial value of the scale.
            scale_factor: divisor on overflow, multiplier on growth.
            scale_window: iteration period (measured from the last overflow)
                at which the scale is increased.
        """
        self.cur_scale = init_scale
        self.cur_iter = 0
        # -1 means no overflow has been recorded yet.
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        """Return True if any parameter's gradient sums to inf or NaN.

        Parameters whose ``grad`` is None are skipped. Iteration stops at the
        first offending gradient.
        """
        return any(
            p.grad is not None
            and DynamicLossScaler._has_inf_or_nan(p.grad.data)
            for p in params
        )

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        """Return True if the float sum of ``x`` is +inf, -inf or NaN."""
        cpu_sum = float(x.float().sum())
        # NaN is the only float value that compares unequal to itself.
        return (
            cpu_sum == float('inf')
            or cpu_sum == -float('inf')
            or cpu_sum != cpu_sum
        )

    # `overflow` is boolean indicating whether we overflowed in gradient
    def update_scale(self, overflow):
        """Adjust the scale for this iteration and advance the iteration counter.

        Args:
            overflow: True if the gradients overflowed on this iteration.
        """
        if overflow:
            # Back off, but never let the scale drop below 1.
            self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
            self.last_overflow_iter = self.cur_iter
        elif (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
            self.cur_scale *= self.scale_factor
        self.cur_iter += 1

    def loss_scale(self):
        """Return the current scale factor."""
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        """Return a tuple with every entry of ``grad_in`` multiplied by the scale.

        ``module`` and ``grad_out`` are unused.
        NOTE(review): the signature looks like a torch module backward hook --
        confirm against the caller.
        """
        # `loss_scale` is a method, so it must be called; multiplying by the
        # bound method itself raises TypeError.
        scale = self.loss_scale()
        return tuple(scale * g for g in grad_in)

    def backward(self, loss):
        """Multiply ``loss`` by the scale and call ``backward()`` on the result."""
        scaled_loss = loss * self.loss_scale()
        scaled_loss.backward()