import math

from packaging import version

import torch
from torch.nn import LayerNorm

if version.Version(torch.__version__) >= version.Version('1.0.0'):
    from torch import _softmax_backward_data as _softmax_backward_data
else:
    from torch import softmax_backward_data as _softmax_backward_data

__all__ = ['StableDropout', 'MaskedLayerNorm', 'XSoftmax', 'ACT2FN', 'LayerNorm']


class XSoftmax(torch.autograd.Function):
    """ Masked Softmax which is optimized for saving memory

    Args:
        input (:obj:`torch.tensor`): The input tensor on which softmax will be applied.
        mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicates that the element will be ignored in the softmax calculation.
        dim (int): The dimension along which softmax will be applied.

    Example::

        import torch
        from DeBERTa.deberta import XSoftmax
        # Make a tensor
        x = torch.randn([4,20,100])
        # Create a mask
        mask = (x>0).int()
        y = XSoftmax.apply(x, mask, -1)

    """

    @staticmethod
    def forward(self, input, mask, dim):
        self.dim = dim
        if mask is None:
            mask = torch.ones_like(input)
        if version.Version(torch.__version__) >= version.Version('1.2.0a'):
            rmask = ~(mask.bool())
        else:
            rmask = (1 - mask).byte()

        output = input.masked_fill(rmask, torch.finfo(input.dtype).min)
        output = torch.softmax(output, self.dim)
        output.masked_fill_(rmask, 0)
        self.save_for_backward(output)
        return output

    @staticmethod
    def backward(self, grad_output):
        output, = self.saved_tensors
        if version.Version(torch.__version__) >= version.Version('1.11'):
            # PyTorch >= 1.11 expects the input dtype as the last argument
            inputGrad = _softmax_backward_data(grad_output, output, self.dim, output.dtype)
        else:
            # Older versions expect a tensor as the last argument
            inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
        return inputGrad, None, None

    @staticmethod
    def symbolic(g, self, mask, dim):
        import torch.onnx.symbolic_helper as sym_help
        from torch.onnx.symbolic_opset9 import masked_fill, softmax

        mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx['Long'])
        r_mask = g.op(
            "Cast",
            g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
            to_i=sym_help.cast_pytorch_to_onnx['Byte'])
        output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float('-inf'))))
        output = softmax(g, output, dim)
        return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
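

# Illustrative sketch, not part of the original module: on kept positions XSoftmax should
# match a plain softmax computed after filling the masked positions with the dtype minimum.
# The helper name `_example_xsoftmax` is hypothetical and exists only for demonstration.
def _example_xsoftmax():
    x = torch.randn(2, 4, 6)
    mask = torch.ones(2, 4, 6, dtype=torch.int)
    mask[..., -2:] = 0  # ignore the last two positions along the softmax dimension
    y = XSoftmax.apply(x, mask, -1)
    # Reference: fill masked positions, softmax, then zero out the ignored positions
    ref = torch.softmax(x.masked_fill(~mask.bool(), torch.finfo(x.dtype).min), dim=-1)
    ref = ref.masked_fill(~mask.bool(), 0)
    assert torch.allclose(y, ref, atol=1e-6)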


class DropoutContext(object):
    def __init__(self):
        self.dropout = 0
        self.mask = None
        self.scale = 1
        self.reuse_mask = True


def get_mask(input, local_context):
    # `local_context` is either a plain dropout probability or a DropoutContext
    # that can cache and reuse a previously drawn mask.
    if not isinstance(local_context, DropoutContext):
        dropout = local_context
        mask = None
    else:
        dropout = local_context.dropout
        dropout *= local_context.scale
        mask = local_context.mask if local_context.reuse_mask else None

    if dropout > 0 and mask is None:
        if version.Version(torch.__version__) >= version.Version('1.2.0a'):
            mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
        else:
            mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).byte()

    if isinstance(local_context, DropoutContext):
        if local_context.mask is None:
            local_context.mask = mask

    return mask, dropout
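

# Illustrative sketch, not part of the original module: `get_mask` accepts either a bare
# dropout probability or a DropoutContext; with a context the drawn mask is cached so a
# second call can reuse it. `_example_get_mask` is a hypothetical demo helper.
def _example_get_mask():
    x = torch.randn(3, 5)
    # Plain probability: a fresh mask is drawn, nothing is cached
    mask, dropout = get_mask(x, 0.5)
    assert mask.shape == x.shape and dropout == 0.5
    # DropoutContext: the first call stores the mask, the second call returns the same one
    ctx = DropoutContext()
    ctx.dropout = 0.5
    mask1, _ = get_mask(x, ctx)
    mask2, _ = get_mask(x, ctx)
    assert mask1 is ctx.mask and mask2 is mask1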


class XDropout(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, local_ctx):
        mask, dropout = get_mask(input, local_ctx)
        ctx.scale = 1.0 / (1 - dropout)
        if dropout > 0:
            ctx.save_for_backward(mask)
            # Inverted dropout: zero out masked elements and rescale the rest
            return input.masked_fill(mask, 0) * ctx.scale
        else:
            return input

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.scale > 1:
            mask, = ctx.saved_tensors
            return grad_output.masked_fill(mask, 0) * ctx.scale, None
        else:
            return grad_output, None
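

# Illustrative sketch, not part of the original module: with a plain probability, XDropout
# behaves like inverted dropout, i.e. kept elements are scaled by 1 / (1 - p). The helper
# `_example_xdropout` is hypothetical and only demonstrates that relation.
def _example_xdropout():
    p = 0.25
    x = torch.randn(4, 8)
    y = XDropout.apply(x, p)
    kept = y != 0
    # Kept elements equal the input scaled by 1 / (1 - p)
    assert torch.allclose(y[kept], x[kept] / (1 - p))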


class StableDropout(torch.nn.Module):
    """ Optimized dropout module for stabilizing training

    Args:
        drop_prob (float): the dropout probability

    """

    def __init__(self, drop_prob):
        super().__init__()
        self.drop_prob = drop_prob
        self.count = 0
        self.context_stack = None

    def forward(self, x):
        """ Call the module

        Args:
            x (:obj:`torch.tensor`): The input tensor to apply dropout to

        """
        if self.training and self.drop_prob > 0:
            return XDropout.apply(x, self.get_context())
        return x

    def clear_context(self):
        self.count = 0
        self.context_stack = None

    def init_context(self, reuse_mask=True, scale=1):
        if self.context_stack is None:
            self.context_stack = []
        self.count = 0
        for c in self.context_stack:
            c.reuse_mask = reuse_mask
            c.scale = scale

    def get_context(self):
        if self.context_stack is not None:
            if self.count >= len(self.context_stack):
                self.context_stack.append(DropoutContext())
            ctx = self.context_stack[self.count]
            ctx.dropout = self.drop_prob
            self.count += 1
            return ctx
        else:
            return self.drop_prob
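

# Illustrative sketch, not part of the original module: after `init_context(reuse_mask=True)`
# the module keeps a stack of DropoutContext objects, so repeated passes that call the module
# in the same order reuse the same masks. `_example_stable_dropout` is a hypothetical helper.
def _example_stable_dropout():
    drop = StableDropout(0.1)
    drop.train()
    drop.init_context(reuse_mask=True)
    x = torch.randn(2, 8)
    y1 = drop(x)
    drop.init_context(reuse_mask=True)  # reset the read pointer; cached masks are kept
    y2 = drop(x)
    # Same cached mask, same input -> identical outputs
    assert torch.equal(y1, y2)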


def MaskedLayerNorm(layerNorm, input, mask=None):
    """ Masked LayerNorm which applies a mask to the output of LayerNorm to avoid inaccurate updates to the LayerNorm module.

    Args:
        layerNorm (:obj:`~DeBERTa.deberta.LayerNorm`): LayerNorm module or function
        input (:obj:`torch.tensor`): The input tensor
        mask (:obj:`torch.IntTensor`): The mask applied to the output of LayerNorm where `0` indicates that the output of that element will be ignored, i.e. set to `0`

    Example::

        # Create a tensor b x n x d
        x = torch.randn([1,10,100])
        m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int)
        LayerNorm = DeBERTa.deberta.LayerNorm(100)
        y = MaskedLayerNorm(LayerNorm, x, m)

    """
    output = layerNorm(input).to(input)
    if mask is None:
        return output
    if mask.dim() != input.dim():
        if mask.dim() == 4:
            # Squeeze a broadcastable attention mask of shape [b, 1, 1, n] down to [b, n]
            mask = mask.squeeze(1).squeeze(1)
        mask = mask.unsqueeze(2)
    mask = mask.to(output.dtype)
    return output * mask
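

# Illustrative sketch, not part of the original module: a 4-d attention-style mask of shape
# [batch, 1, 1, seq_len] is squeezed back to [batch, seq_len] before being broadcast over the
# hidden dimension. `_example_masked_layer_norm` is a hypothetical demo helper.
def _example_masked_layer_norm():
    ln = LayerNorm(16)
    x = torch.randn(2, 6, 16)
    mask_2d = torch.tensor([[1, 1, 1, 0, 0, 0],
                            [1, 1, 1, 1, 0, 0]], dtype=torch.int)
    mask_4d = mask_2d[:, None, None, :]  # [batch, 1, 1, seq_len]
    y2 = MaskedLayerNorm(ln, x, mask_2d)
    y4 = MaskedLayerNorm(ln, x, mask_4d)
    assert torch.equal(y2, y4)
    assert (y2[0, 3:] == 0).all()  # masked positions are zeroed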


def gelu(x):
    """Implementation of the gelu activation function.

    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
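

# Illustrative sketch, not part of the original module: the erf-based formula above should agree
# with torch.nn.functional.gelu (exact variant) up to floating-point tolerance.
def _example_gelu():
    x = torch.linspace(-3, 3, steps=7)
    assert torch.allclose(gelu(x), torch.nn.functional.gelu(x), atol=1e-6)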


def swish(x):
    return x * torch.sigmoid(x)


def linear_act(x):
    return x


def sequence_masking(x, mask, value=0, axis=None):
    """Apply a mask along the sequence axis of x.

    mask: a 0/1 matrix of shape (batch_size, seq_len);
    value: the value that masked positions are filled with; may be the string '-inf' or 'inf';
    axis: the axis along which the sequence lies, defaults to 1.
    """
    if mask is None:
        return x
    else:
        x_dtype = x.dtype
        if x_dtype == torch.bool:
            x = x.to(torch.int32)

        if value == '-inf':
            value = -float('inf')
        elif value == 'inf':
            value = float('inf')
        if axis is None:
            axis = 1
        elif axis < 0:
            axis = x.dim() + axis
        assert axis > 0, 'axis must be greater than 0'
        if mask.dim() != x.dim():
            # Broadcast the (batch_size, seq_len) mask up to the rank of x
            mask = align(mask, [0, axis], x.dim())

        x = x.masked_fill_(~mask.bool(), value)
        if x_dtype == torch.bool:
            x = x.to(torch.bool)
        return x
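

# Illustrative sketch, not part of the original module: mask out padded timesteps of a
# (batch, seq_len, dim) tensor along axis 1. `_example_sequence_masking` is a hypothetical helper.
def _example_sequence_masking():
    x = torch.randn(2, 5, 3)
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 0, 0, 0]])
    y = sequence_masking(x, mask, value=0, axis=1)  # note: fills x in place
    assert (y[0, 3:] == 0).all() and (y[1, 2:] == 0).all()  # padded steps are zeroed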


def align(tensor, axes, ndim=None):
    """Realign a tensor by inserting singleton dimensions (a batched version of expand_dims).

    axes: the original i-th dimension is placed at dimension axes[i] of the new tensor;
    ndim: the number of dimensions of the new tensor.
    """
    assert len(axes) == tensor.dim()
    assert ndim or min(axes) >= 0
    ndim = ndim or max(axes) + 1
    indices = [None] * ndim
    for i in axes:
        indices[i] = slice(None)
    return tensor[tuple(indices)]
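

# Illustrative sketch, not part of the original module: align a (batch, seq_len) mask to a
# 4-d tensor so that dimension 0 stays at axis 0 and dimension 1 moves to axis 2.
# `_example_align` is a hypothetical demo helper.
def _example_align():
    mask = torch.ones(2, 5)
    aligned = align(mask, [0, 2], ndim=4)
    assert aligned.shape == (2, 1, 5, 1)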


ACT2FN = {
    "gelu": torch.nn.functional.gelu,
    "relu": torch.nn.functional.relu,
    "swish": swish,
    "tanh": torch.tanh,
    "linear": linear_act,
    "sigmoid": torch.sigmoid,
    "silu": torch.nn.functional.silu,
}
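

# Illustrative sketch, not part of the original module: activation functions are looked up
# by name, e.g. from a model config. `_example_act2fn` is a hypothetical demo helper.
def _example_act2fn():
    x = torch.randn(4)
    assert torch.equal(ACT2FN["relu"](x), torch.clamp(x, min=0))
    assert torch.equal(ACT2FN["linear"](x), x)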