Spaces:
Sleeping
Sleeping
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) | |
# | |
# See ../../../../LICENSE for clarification regarding multiple authors | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import logging | |
import math | |
import random | |
from typing import Optional | |
from typing import Tuple | |
from typing import Union | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
class DoubleSwishFunction(torch.autograd.Function):
    """
    double_swish(x) = x * torch.sigmoid(x-1)

    This is a definition, originally motivated by its close numerical
    similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).

    Memory-efficient derivative computation:
     double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
     double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
     Now, s'(x) = s(x) * (1-s(x)).
     double_swish'(x) = x * s'(x) + s(x)
                      = x * s(x) * (1-s(x)) + s(x)
                      = double_swish(x) * (1-s(x)) + s(x)
     ... so we just need to remember s(x) but not x itself.

    In practice the derivative itself is saved, quantized to uint8 with
    stochastic rounding over the range [_DERIV_FLOOR, _DERIV_CEIL], so the
    backward pass needs one byte of saved state per element.
    """

    # Bounds of d/dx (x * sigmoid(x - 1)); see
    # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
    # The true minimum is ~= -0.043638, so the floor must be <= that value to
    # be a real lower bound; -0.044 is used.  The maximum is ~= 1.1990, so 1.2
    # is an upper bound.  forward() and backward() must use the same bounds.
    _DERIV_FLOOR = -0.044
    _DERIV_CEIL = 1.2

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        """Compute x * sigmoid(x - 1); save a uint8-quantized derivative if needed.

        Args:
          ctx: autograd context.
          x: input tensor. float16 input is computed in float32 internally.

        Returns:
          x * sigmoid(x - 1). Cast to float16 when the input was float16 or
          autocast is enabled; otherwise in the computation dtype.
        """
        requires_grad = x.requires_grad
        x_dtype = x.dtype
        if x_dtype == torch.float16:
            x = x.to(torch.float32)

        s = torch.sigmoid(x - 1.0)
        y = x * s

        if requires_grad:
            deriv = y * (1 - s) + s
            floor = DoubleSwishFunction._DERIV_FLOOR
            ceil = DoubleSwishFunction._DERIV_CEIL
            # Map [floor, ceil] onto [0, 255] and add uniform noise in [0, 1).
            # The values are non-negative, so the truncating cast to uint8
            # floors them, which makes the quantization expectation-preserving.
            d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                deriv
            )
            if __name__ == "__main__":
                # for self-testing only.
                assert d_scaled.min() >= 0.0
                assert d_scaled.max() < 256.0
            d_int = d_scaled.to(torch.uint8)
            ctx.save_for_backward(d_int)

        # `x` may have been upcast above, so test the ORIGINAL input dtype here
        # (testing `x.dtype` would never match float16).
        if x_dtype == torch.float16 or torch.is_autocast_enabled():
            y = y.to(torch.float16)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        """Dequantize the saved derivative and apply the chain rule."""
        (d,) = ctx.saved_tensors
        # The same constants as used in the forward pass.
        floor = DoubleSwishFunction._DERIV_FLOOR
        ceil = DoubleSwishFunction._DERIV_CEIL
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d
class DoubleSwish(torch.nn.Module): | |
def forward(self, x: Tensor) -> Tensor: | |
"""Return double-swish activation function which is an approximation to Swish(Swish(x)), | |
that we approximate closely with x * sigmoid(x-1). | |
""" | |
if torch.jit.is_scripting() or torch.jit.is_tracing(): | |
return x * torch.sigmoid(x - 1.0) | |
return DoubleSwishFunction.apply(x) | |
class ActivationBalancerFunction(torch.autograd.Function):
    """Identity in the forward pass; adjusts the gradient per channel in backward.

    In backward, with g the incoming gradient:
        factor = scale_factor * (+0.5 where x > 0 else -0.5)  [+ sign_factor]
        grad_x = g - |g| * factor
    where a positive factor lowers the gradient, so a descent step moves x
    upward, and a negative factor does the opposite.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        scale_factor: Tensor,
        sign_factor: Optional[Tensor],
        channel_dim: int,
    ) -> Tensor:
        """Return `x` unchanged, saving what backward() needs.

        Args:
          ctx: autograd context.
          x: input tensor.
          scale_factor: per-channel factor (one entry per index along
            `channel_dim`), e.g. from _compute_scale_factor.
          sign_factor: optional per-channel factor, e.g. from
            _compute_sign_factor; None to disable the sign constraint.
          channel_dim: channel axis of `x`; negative values count from the end.
        """
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        xgt0 = x > 0
        if sign_factor is None:
            ctx.save_for_backward(xgt0, scale_factor)
        else:
            ctx.save_for_backward(xgt0, scale_factor, sign_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        """Return the adjusted gradient for `x` and None for the other inputs."""
        saved = ctx.saved_tensors
        xgt0, scale_factor = saved[0], saved[1]
        sign_factor = saved[2] if len(saved) == 3 else None

        # Append one trailing singleton dim per axis after channel_dim so the
        # per-channel factors broadcast against x_grad.
        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
            scale_factor = scale_factor.unsqueeze(-1)
            if sign_factor is not None:
                sign_factor = sign_factor.unsqueeze(-1)

        factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        if sign_factor is not None:
            factor = sign_factor + factor

        neg_delta_grad = x_grad.abs() * factor
        return x_grad - neg_delta_grad, None, None, None
def _compute_scale_factor( | |
x: Tensor, | |
channel_dim: int, | |
min_abs: float, | |
max_abs: float, | |
gain_factor: float, | |
max_factor: float, | |
) -> Tensor: | |
if channel_dim < 0: | |
channel_dim += x.ndim | |
sum_dims = [d for d in range(x.ndim) if d != channel_dim] | |
x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) | |
if min_abs == 0.0: | |
below_threshold = 0.0 | |
else: | |
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if | |
# x_abs)_mean , min_abs. | |
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( | |
min=0, max=max_factor | |
) | |
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( | |
min=0, max=max_factor | |
) | |
return below_threshold - above_threshold | |
def _compute_sign_factor( | |
x: Tensor, | |
channel_dim: int, | |
min_positive: float, | |
max_positive: float, | |
gain_factor: float, | |
max_factor: float, | |
) -> Tensor: | |
if channel_dim < 0: | |
channel_dim += x.ndim | |
sum_dims = [d for d in range(x.ndim) if d != channel_dim] | |
proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) | |
if min_positive == 0.0: | |
factor1 = 0.0 | |
else: | |
# 0 if proportion_positive >= min_positive, else can be | |
# as large as max_factor. | |
factor1 = ( | |
(min_positive - proportion_positive) * (gain_factor / min_positive) | |
).clamp_(min=0, max=max_factor) | |
if max_positive == 1.0: | |
factor2 = 0.0 | |
else: | |
# 0 if self.proportion_positive <= max_positive, else can be | |
# as large as -max_factor. | |
factor2 = ( | |
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) | |
).clamp_(min=0, max=max_factor) | |
sign_factor = factor1 - factor2 | |
# require min_positive != 0 or max_positive != 1: | |
assert not isinstance(sign_factor, float) | |
return sign_factor | |
class ActivationBalancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage,
    for each channel, that it is positive at least a proportion
    `min_positive` (and at most a proportion `max_positive`) of the time,
    and that its average absolute value stays within [min_abs, max_abs].
    For the sign constraint it does this by multiplying negative derivative
    values by up to (1+max_factor), and positive derivative values by up to
    (1-max_factor), interpolated from 1 at the threshold to those extremal
    values when none of the inputs are positive.

    The modification is applied only on a random subset of forward() calls
    (see `min_prob`). The gain factors are divided by that probability,
    so the expected adjustment does not depend on it.

    Args:
           num_channels: the number of channels
           channel_dim: the dimension/axis corresponding to the channel, e.g.
               -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
           min_positive: the minimum, per channel, of the proportion of the time
               that (x > 0), below which we start to modify the derivatives.
           max_positive: the maximum, per channel, of the proportion of the time
               that (x > 0), above which we start to modify the derivatives.
           max_factor: the maximum factor by which we modify the derivatives for
              either the sign constraint or the magnitude constraint;
              e.g. with max_factor=0.02, the derivatives would be multiplied by
              values in the range [0.98..1.02].
           sign_gain_factor: determines the 'gain' with which we increase the
              change in gradient once the constraints on min_positive and max_positive
              are violated.
           scale_gain_factor: determines the 'gain' with which we increase the
              change in gradient once the constraints on min_abs and max_abs
              are violated.
           min_abs: the minimum per-channel average absolute value of x that
               we allow before we start to modify the derivatives to increase it.
           max_abs: the maximum per-channel average absolute value of x that
               we allow before we start to modify the derivatives to decrease it.
           min_prob: determines the minimum probability with which we modify the
             gradients for the {min,max}_positive and {min,max}_abs constraints,
             on each forward().  This is done randomly to prevent all layers
             from doing it at the same time.  Early in training we may use
             higher probabilities than this; it will decay to this value.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.04,
        sign_gain_factor: float = 0.01,
        scale_gain_factor: float = 0.02,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
        min_prob: float = 0.1,
    ):
        super(ActivationBalancer, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.min_prob = min_prob
        self.sign_gain_factor = sign_gain_factor
        self.scale_gain_factor = scale_gain_factor

        # cpu_count measures how many times forward() has been called.  It is
        # occasionally synced to the buffer `count`, which exists so the value
        # is saved and restored along with the model's state.
        self.cpu_count = 0
        self.register_buffer("count", torch.tensor(0, dtype=torch.int64))

    def forward(self, x: Tensor) -> Tensor:
        """Pass `x` through, sometimes attaching gradient balancing.

        The result is `_no_op(x)` (defined elsewhere in this file) in the
        following cases: when scripting or tracing, when `x` does not require
        grad, or when this call is not randomly selected. Otherwise `x` is
        routed through ActivationBalancerFunction, which returns it unchanged
        in forward and adjusts its gradient in backward.
        """
        if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
            return _no_op(x)

        count = self.cpu_count
        self.cpu_count += 1

        if random.random() < 0.01:
            # Occasionally sync self.cpu_count with self.count.
            # count affects the decay of 'prob'.  Don't do this on every iter,
            # because syncing with the GPU is slow.
            self.cpu_count = max(self.cpu_count, self.count.item())
            self.count.fill_(self.cpu_count)

        # The probability of doing some work decreases exponentially from 0.5
        # until it hits a floor at min_prob (==0.1 by default).
        prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))

        if random.random() >= prob:
            return _no_op(x)

        if self.min_positive != 0.0 or self.max_positive != 1.0:
            sign_factor = _compute_sign_factor(
                x,
                self.channel_dim,
                self.min_positive,
                self.max_positive,
                # Divided by prob to compensate for only acting on a fraction
                # of calls.
                gain_factor=self.sign_gain_factor / prob,
                max_factor=self.max_factor,
            )
        else:
            sign_factor = None

        scale_factor = _compute_scale_factor(
            x.detach(),
            self.channel_dim,
            min_abs=self.min_abs,
            max_abs=self.max_abs,
            gain_factor=self.scale_gain_factor / prob,
            max_factor=self.max_factor,
        )
        return ActivationBalancerFunction.apply(
            x,
            scale_factor,
            sign_factor,
            self.channel_dim,
        )
def BalancedDoubleSwish(
    d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
    """Build ``ActivationBalancer -> DoubleSwish`` as an ``nn.Sequential``.

    Args:
      d_model: number of channels; passed as ActivationBalancer's num_channels.
      channel_dim: channel axis of the input; passed to ActivationBalancer.
      max_abs: passed to ActivationBalancer.
      min_prob: passed to ActivationBalancer.

    Returns:
      nn.Sequential whose first stage is the balancer and second is DoubleSwish.
    """
    stages = [
        ActivationBalancer(
            d_model,
            channel_dim=channel_dim,
            max_abs=max_abs,
            min_prob=min_prob,
        ),
        DoubleSwish(),
    ]
    return nn.Sequential(*stages)