Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
# | |
import torch | |
class ScalarBias(torch.autograd.Function):
    """
    Adds a vector of scalars, used in self-attention mechanism to allow
    the model to optionally attend to this vector instead of the past.

    Forward prepends one slot filled with ``bias_init`` at index 0 along
    ``dim`` and copies ``input`` into the remaining positions. Backward
    discards the gradient of that prepended slot, so only ``input`` receives
    a gradient (``dim`` and ``bias_init`` get ``None``).
    """

    @staticmethod
    def forward(ctx, input, dim, bias_init):
        """Return ``input`` with one ``bias_init``-filled slot prepended along ``dim``.

        Args:
            ctx: autograd context; ``ctx.dim`` is stored for ``backward``.
            input: tensor to extend.
            dim: dimension along which the extra slot is inserted (at index 0).
                Negative values index from the end, as in ``Tensor.narrow``.
            bias_init: fill value for the inserted slot.

        Returns:
            A new tensor with ``input``'s dtype and device whose size along
            ``dim`` is one larger than ``input``'s.
        """
        size = list(input.size())
        size[dim] += 1
        # new_full allocates with input's dtype/device and fills in one call;
        # equivalent to the legacy ``input.new(*size).fill_(bias_init)``.
        output = input.new_full(size, bias_init)
        # Position 0 keeps bias_init; positions 1.. receive a copy of input.
        output.narrow(dim, 1, size[dim] - 1).copy_(input)
        ctx.dim = dim
        return output

    @staticmethod
    def backward(ctx, grad):
        """Strip the prepended slot from ``grad``; ``dim``/``bias_init`` get no gradient."""
        return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None
def scalar_bias(input, dim, bias_init=0):
    """Prepend one ``bias_init``-filled slot to ``input`` along ``dim``.

    Functional front end for the ``ScalarBias`` autograd function; arguments
    are forwarded unchanged to ``ScalarBias.apply``.
    """
    apply_fn = ScalarBias.apply
    return apply_fn(input, dim, bias_init)