# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import lightconv_cuda
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
from torch import nn
from torch.autograd import Function
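

# Autograd wrapper around the compiled lightconv_cuda extension: forward runs
# the fused CUDA kernel, and backward reuses the saved (x, weights) tensors to
# compute gradients for both the input and the convolution weights.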
class lightconvFunction(Function):
    @staticmethod
    def forward(ctx, x, weights, padding_l):
        ctx.padding_l = padding_l
        outputs = lightconv_cuda.forward(x, weights, padding_l)
        variables = [x, weights]
        ctx.save_for_backward(*variables)
        return outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        outputs = lightconv_cuda.backward(
            grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors
        )
        grad_input, grad_weights = outputs
        return grad_input, grad_weights, None
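

# Lightweight convolution (Wu et al., 2019, "Pay Less Attention with
# Lightweight and Dynamic Convolutions"): a depthwise convolution whose kernel
# is shared across the channels of each head and optionally softmax-normalized
# along the kernel dimension.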
@with_incremental_state
class LightconvLayer(nn.Module):
    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding_l=None,
        weight_softmax=False,
        num_heads=1,
        weight_dropout=0.0,
        bias=False,
    ):
        super(LightconvLayer, self).__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.padding_l = padding_l
        self.num_heads = num_heads
        self.weight_softmax = weight_softmax
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
        )

        self.weight = nn.Parameter(torch.Tensor(num_heads, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(input_size))
        else:
            self.bias = None
        self.reset_parameters()
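
    # Older checkpoints stored the kernel as (num_heads, 1, kernel_size);
    # squeeze the singleton dim so it matches the current 2-D weight parameter.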
    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        for k, v in state_dict.items():
            if k.endswith(prefix + "weight"):
                if v.dim() == 3 and v.size(1) == 1:
                    state_dict[k] = v.squeeze(1)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)
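
    # Input and output use fairseq's (T, B, C) layout: (seq_len, batch, input_size).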
    def forward(self, x, incremental_state=None):
        # during inference time, incremental BMM is faster
        if incremental_state is not None:
            T, B, C = x.size()
            K, H = self.kernel_size, self.num_heads
            R = C // H
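            # One decoding step: append the newest timestep to a rolling buffer
            # of the previous kernel_size - 1 steps, then convolve it with the
            # per-head kernel via a single batched matmul.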
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                input_buffer = x.new()
            x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
            if self.kernel_size > 1:
                self._set_input_buffer(
                    incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :]
                )
            x_unfold = x_unfold.view(T * B * H, R, -1)

            weight = self.weight
            if self.weight_softmax:
                weight = F.softmax(weight.float(), dim=1).type_as(weight)

            weight = weight[:, -x_unfold.size(2) :]
            K = weight.size(1)

            weight = (
                weight.view(1, H, K)
                .expand(T * B, H, K)
                .contiguous()
                .view(T * B * H, K, 1)
            )
            weight = self.weight_dropout_module(weight)
            output = torch.bmm(x_unfold, weight)  # T*B*H x R x 1
            output = output.view(T, B, C)
            return output
        # during training time, use CUDA kernel
        else:
            x = x.permute(1, 2, 0).contiguous()
            weight = self.weight
            if self.weight_softmax:
                weight = F.softmax(self.weight, -1)
            if self.weight_dropout_module.p:
                weight = self.weight_dropout_module(weight)
            return lightconvFunction.apply(x, weight, self.padding_l).permute(2, 0, 1)

    def reorder_incremental_state(self, incremental_state, new_order):
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            input_buffer = input_buffer.index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(self, incremental_state, "input_buffer")

    def _set_input_buffer(self, incremental_state, new_buffer):
        return utils.set_incremental_state(
            self, incremental_state, "input_buffer", new_buffer
        )

    def half(self):
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)
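

# --- Usage sketch (illustrative, not part of the upstream module) -----------
# A minimal example of driving the layer, assuming the lightconv_cuda
# extension has been built and a CUDA device is available. The shapes and
# hyperparameters below are made up for illustration.
if __name__ == "__main__":
    if torch.cuda.is_available():
        layer = LightconvLayer(
            input_size=512,
            kernel_size=3,
            padding_l=2,  # causal padding: kernel_size - 1
            weight_softmax=True,
            num_heads=8,
            weight_dropout=0.1,
        ).cuda()
        x = torch.randn(10, 2, 512, device="cuda")  # (T, B, C)
        y = layer(x)  # no incremental_state: dispatches to the fused CUDA kernel
        print(y.shape)  # expected: torch.Size([10, 2, 512])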