# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import lightconv_cuda
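# lightconv_cuda is the compiled CUDA extension bundled with fairseq's
# lightconv module; it has to be built separately, so importing this file on a
# CPU-only install will fail.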
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
from torch import nn
from torch.autograd import Function
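

# lightconvFunction wraps the hand-written lightconv CUDA kernels in a
# torch.autograd.Function so that both the forward convolution and its
# gradient run on the GPU.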
class lightconvFunction(Function):
    @staticmethod
    def forward(ctx, x, weights, padding_l):
        ctx.padding_l = padding_l
        outputs = lightconv_cuda.forward(x, weights, padding_l)
        variables = [x, weights]
        ctx.save_for_backward(*variables)
        return outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        outputs = lightconv_cuda.backward(
            grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors
        )
        grad_input, grad_weights = outputs
        return grad_input, grad_weights, None
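

# LightconvLayer is the lightweight-convolution layer used by LightConv
# models: a depthwise convolution whose kernel is shared across the channels
# inside each head. lightconvFunction expects its input laid out as
# (batch, channels, time); forward() below permutes from the usual (T, B, C)
# before calling it.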
@with_incremental_state
class LightconvLayer(nn.Module):
    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding_l=None,
        weight_softmax=False,
        num_heads=1,
        weight_dropout=0.0,
        bias=False,
    ):
        super(LightconvLayer, self).__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.padding_l = padding_l
        self.num_heads = num_heads
        self.weight_softmax = weight_softmax
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
        )
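
        # One kernel of length kernel_size is learned per head and shared
        # across all channels in that head, which is what makes the
        # convolution "lightweight".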
        self.weight = nn.Parameter(torch.Tensor(num_heads, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(input_size))
        else:
            self.bias = None
        self.reset_parameters()
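
    # Checkpoints that stored the weight with an extra singleton middle
    # dimension, i.e. (num_heads, 1, kernel_size), are squeezed down to the
    # current (num_heads, kernel_size) shape when loaded.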
    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        for k, v in state_dict.items():
            if k.endswith(prefix + "weight"):
                if v.dim() == 3 and v.size(1) == 1:
                    state_dict[k] = v.squeeze(1)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)

    def forward(self, x, incremental_state=None):

        # during inference time, incremental BMM is faster
        if incremental_state is not None:
            T, B, C = x.size()
            K, H = self.kernel_size, self.num_heads
            R = C // H
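            # T: time steps in this call (usually 1 during incremental
            # decoding), B: batch size, C: model dim, K: kernel size,
            # H: number of heads, R: channels per head (C must be divisible
            # by H).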
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                input_buffer = x.new()
            x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
            if self.kernel_size > 1:
                self._set_input_buffer(
                    incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :]
                )
            x_unfold = x_unfold.view(T * B * H, R, -1)
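            # x_unfold holds the current step concatenated with the cached
            # previous (at most K - 1) steps, flattened to (T*B*H, R, <=K) so
            # each head's kernel can be applied with one batched matmul below.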

            weight = self.weight
            if self.weight_softmax:
                weight = F.softmax(weight.float(), dim=1).type_as(weight)

            weight = weight[:, -x_unfold.size(2) :]
            K = weight.size(1)

            weight = (
                weight.view(1, H, K)
                .expand(T * B, H, K)
                .contiguous()
                .view(T * B * H, K, 1)
            )

            weight = self.weight_dropout_module(weight)
            output = torch.bmm(x_unfold, weight)  # T*B*H x R x 1
            output = output.view(T, B, C)
            return output

        # during training time, use CUDA kernel
        else:
            x = x.permute(1, 2, 0).contiguous()
            weight = self.weight
            if self.weight_softmax:
                weight = F.softmax(self.weight, -1)
            if self.weight_dropout_module.p:
                weight = self.weight_dropout_module(weight)
            return lightconvFunction.apply(x, weight, self.padding_l).permute(2, 0, 1)
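
    # Realigns the cached input buffer with a reordered batch dimension, e.g.
    # when the decoder reorders hypotheses during beam search.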
    def reorder_incremental_state(self, incremental_state, new_order):
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            input_buffer = input_buffer.index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(self, incremental_state, "input_buffer")

    def _set_input_buffer(self, incremental_state, new_buffer):
        return utils.set_incremental_state(
            self, incremental_state, "input_buffer", new_buffer
        )
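
    # Cast only floating-point tensors to fp16; integer tensors (if any) are
    # left untouched.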
    def half(self):
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)
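

# Minimal usage sketch (not part of the original module). It assumes the
# lightconv_cuda extension has been built and a CUDA device is available; the
# sizes below are arbitrary.
if __name__ == "__main__" and torch.cuda.is_available():
    layer = LightconvLayer(
        input_size=512,
        kernel_size=3,
        padding_l=2,  # causal padding: kernel_size - 1
        weight_softmax=True,
        num_heads=8,
    ).cuda()
    x = torch.randn(20, 4, 512, device="cuda")  # (T, B, C)
    out = layer(x)  # no incremental_state -> training path / CUDA kernel
    print(out.shape)  # torch.Size([20, 4, 512])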