# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Zhou Bo
# Date: 01/15/2020
#
import math
from packaging import version
import torch
from torch.nn import LayerNorm
if version.Version(torch.__version__) >= version.Version('1.0.0'):
from torch import _softmax_backward_data as _softmax_backward_data
else:
from torch import softmax_backward_data as _softmax_backward_data
__all__ = ['StableDropout', 'MaskedLayerNorm', 'XSoftmax', 'ACT2FN', 'LayerNorm']
class XSoftmax(torch.autograd.Function):
""" Masked Softmax which is optimized for saving memory
Args:
    input (:obj:`torch.tensor`): The input tensor on which softmax will be applied.
    mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicates that the element will be ignored in the softmax calculation.
    dim (int): The dimension along which softmax will be applied.
Example::
import torch
from DeBERTa.deberta import XSoftmax
# Make a tensor
x = torch.randn([4,20,100])
# Create a mask
mask = (x>0).int()
y = XSoftmax.apply(x, mask, dim=-1)
"""
@staticmethod
def forward(self, input, mask, dim):
"""
"""
self.dim = dim
if mask is None:
mask = torch.ones_like(input)
if version.Version(torch.__version__) >= version.Version('1.2.0a'):
rmask = ~(mask.bool())
else:
rmask = (1-mask).byte() # This line is not supported by Onnx tracing.
output = input.masked_fill(rmask, torch.finfo(input.dtype).min) # float('-inf')
output = torch.softmax(output, self.dim)
output.masked_fill_(rmask, 0)
self.save_for_backward(output)
return output
@staticmethod
def backward(self, grad_output):
"""
"""
output, = self.saved_tensors
    if version.Version(torch.__version__) >= version.Version('1.11'):
      # torch >= 1.11 expects the dtype of the softmax input as the last argument
      inputGrad = _softmax_backward_data(grad_output, output, self.dim, output.dtype)
    else:
      # older versions expect the softmax output tensor itself
      inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
return inputGrad, None, None
@staticmethod
def symbolic(g, self, mask, dim):
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_opset9 import masked_fill, softmax
mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx['Long'])
r_mask = g.op("Cast", g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), to_i=sym_help.cast_pytorch_to_onnx['Byte'])
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float('-inf'))))
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
class DropoutContext(object):
def __init__(self):
self.dropout = 0
self.mask = None
self.scale = 1
self.reuse_mask = True
def get_mask(input, local_context):
if not isinstance(local_context, DropoutContext):
dropout = local_context
mask = None
else:
dropout = local_context.dropout
dropout *= local_context.scale
mask = local_context.mask if local_context.reuse_mask else None
if dropout>0 and mask is None:
if version.Version(torch.__version__) >= version.Version('1.2.0a'):
mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).bool()
else:
mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).byte()
if isinstance(local_context, DropoutContext):
if local_context.mask is None:
local_context.mask = mask
return mask, dropout
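
# A minimal sketch of how `get_mask` behaves (values are illustrative):
#
#   import torch
#   ctx = DropoutContext()
#   ctx.dropout = 0.1
#   x = torch.randn([4, 20, 100])
#   mask, dropout = get_mask(x, ctx)   # draws a Bernoulli drop mask and caches it on ctx
#   mask2, _ = get_mask(x, ctx)        # with reuse_mask=True (the default) the cached mask is reused
#
# Passing a plain float dropout probability instead of a DropoutContext produces a fresh mask on every call.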
class XDropout(torch.autograd.Function):
@staticmethod
def forward(ctx, input, local_ctx):
mask, dropout = get_mask(input, local_ctx)
ctx.scale=1.0/(1-dropout)
if dropout>0:
ctx.save_for_backward(mask)
return input.masked_fill(mask, 0)*ctx.scale
else:
return input
@staticmethod
def backward(ctx, grad_output):
if ctx.scale > 1:
mask, = ctx.saved_tensors
return grad_output.masked_fill(mask, 0)*ctx.scale, None
else:
return grad_output, None
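
# A minimal sketch of calling XDropout directly (it is normally driven by StableDropout below):
#
#   x = torch.randn([4, 20, 100], requires_grad=True)
#   y = XDropout.apply(x, 0.1)   # local_ctx may be a float drop probability or a DropoutContext
#   y.sum().backward()           # gradients at dropped positions are zeroed and the rest rescaled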
class StableDropout(torch.nn.Module):
""" Optimized dropout module for stabilizing the training
Args:
    drop_prob (float): the dropout probability
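
  A minimal usage sketch (the module is assumed importable as ``DeBERTa.deberta``, as in the XSoftmax example above; shapes are illustrative)::

    import torch
    from DeBERTa.deberta import StableDropout
    dropout = StableDropout(0.1)     # 10% drop probability
    dropout.train()                  # dropout is only applied in training mode
    x = torch.randn([4, 20, 100])
    y = dropout(x)                   # same shape as x; elements are randomly zeroed and rescaled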
"""
def __init__(self, drop_prob):
super().__init__()
self.drop_prob = drop_prob
self.count = 0
self.context_stack = None
def forward(self, x):
""" Call the module
Args:
      x (:obj:`torch.tensor`): The input tensor on which to apply dropout
"""
if self.training and self.drop_prob>0:
return XDropout.apply(x, self.get_context())
return x
def clear_context(self):
self.count = 0
self.context_stack = None
def init_context(self, reuse_mask=True, scale = 1):
if self.context_stack is None:
self.context_stack = []
self.count = 0
for c in self.context_stack:
c.reuse_mask = reuse_mask
c.scale = scale
def get_context(self):
if self.context_stack is not None:
if self.count >= len(self.context_stack):
self.context_stack.append(DropoutContext())
ctx = self.context_stack[self.count]
ctx.dropout = self.drop_prob
self.count += 1
return ctx
else:
return self.drop_prob
def MaskedLayerNorm(layerNorm, input, mask = None):
""" Masked LayerNorm which will apply mask over the output of LayerNorm to avoid inaccurate updatings to the LayerNorm module.
Args:
layernorm (:obj:`~DeBERTa.deberta.LayerNorm`): LayerNorm module or function
input (:obj:`torch.tensor`): The input tensor
mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicate the output of that element will be ignored, i.e. set to `0`
Example::
# Create a tensor b x n x d
x = torch.randn([1,10,100])
m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int)
LayerNorm = DeBERTa.deberta.LayerNorm(100)
y = MaskedLayerNorm(LayerNorm, x, m)
"""
output = layerNorm(input).to(input)
if mask is None:
return output
if mask.dim()!=input.dim():
if mask.dim()==4:
mask=mask.squeeze(1).squeeze(1)
mask = mask.unsqueeze(2)
mask = mask.to(output.dtype)
return output*mask
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
def linear_act(x):
return x
def sequence_masking(x, mask, value=0, axis=None):
"""为序列条件mask的函数
mask: 形如(batch_size, seq_len)的0-1矩阵;
value: mask部分要被替换成的值,可以是'-inf'或'inf';
axis: 序列所在轴,默认为1;
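
  A minimal usage sketch (shapes and values are illustrative)::

    import torch
    x = torch.randn([2, 5, 8])
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 0, 0, 0]])
    y = sequence_masking(x, mask, value=0, axis=1)  # positions with mask==0 along axis 1 are set to 0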
"""
if mask is None:
return x
else:
x_dtype = x.dtype
if x_dtype == torch.bool:
x = x.to(torch.int32)
# if mask.dtype != x.dtype:
# mask = mask.to(x.dtype)
if value == '-inf':
value = -float('inf')
elif value == 'inf':
value = float('inf')
if axis is None:
axis = 1
elif axis < 0:
axis = x.dim() + axis
assert axis > 0, 'axis must be greater than 0'
if mask.dim() != x.dim():
mask = align(mask, [0, axis], x.dim())
# value = value.to(x.dtype)
x = x.masked_fill_(~mask.bool(), value) # * mask + mask.fill_(value)
if x_dtype == torch.bool:
x = x.to(torch.bool)
return x
def align(tensor, axes, ndim=None):
"""重新对齐tensor(批量版expand_dims)
axes:原来的第i维对齐新tensor的第axes[i]维;
ndim:新tensor的维度。
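
  A minimal usage sketch (shapes are illustrative)::

    import torch
    mask = torch.ones([2, 5])          # (batch_size, seq_len)
    expanded = align(mask, [0, 1], 3)  # shape (2, 5, 1), broadcastable against (2, 5, hidden)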
"""
assert len(axes) == tensor.dim()
assert ndim or min(axes) >= 0
ndim = ndim or max(axes) + 1
indices = [None] * ndim
for i in axes:
indices[i] = slice(None)
  return tensor[tuple(indices)]  # index with a tuple; indexing with a list of slices is deprecated in newer PyTorch
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish, "tanh": torch.tanh, "linear": linear_act, 'sigmoid': torch.sigmoid, 'silu': torch.nn.functional.silu}
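
# A minimal usage sketch for ACT2FN (the key is illustrative):
#
#   import torch
#   act_fn = ACT2FN['gelu']
#   y = act_fn(torch.randn([4, 20, 100]))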