# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import lightconv_cuda
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
from torch import nn
from torch.autograd import Function


class lightconvFunction(Function):
    """Autograd wrapper around the ``lightconv_cuda`` extension.

    ``forward`` and ``backward`` delegate to ``lightconv_cuda.forward`` and
    ``lightconv_cuda.backward`` respectively.
    """

    @staticmethod
    def forward(ctx, x, weights, padding_l):
        """Call the extension's forward and return the first of its outputs.

        ``x`` and ``weights`` are saved on ``ctx`` for the backward pass;
        ``padding_l`` is stashed as a plain attribute and forwarded to the
        extension unchanged.
        """
        ctx.padding_l = padding_l
        results = lightconv_cuda.forward(x, weights, padding_l)
        ctx.save_for_backward(x, weights)
        return results[0]

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, weights, padding_l)``.

        ``padding_l`` receives ``None`` as its gradient.
        """
        # The incoming gradient is made contiguous before it is handed over.
        grad_input, grad_weights = lightconv_cuda.backward(
            grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors
        )
        return grad_input, grad_weights, None


@with_incremental_state
class LightconvLayer(nn.Module):
    """Convolution layer with one ``kernel_size``-wide weight row per head.

    With an ``incremental_state`` the layer runs a pure-PyTorch batched
    matmul over a buffer of previous inputs; without one it dispatches to
    the ``lightconv_cuda`` kernel through ``lightconvFunction``.

    Args:
        input_size: stored on the module; also the length of the optional
            bias parameter.
        kernel_size: number of weight columns per head.
        padding_l: passed through to the CUDA kernel unchanged
            (presumably the left padding -- confirm against the extension).
        weight_softmax: if true, softmax the weights along the kernel
            dimension before use.
        num_heads: number of weight rows; the channel dimension is split
            into this many groups on the incremental path.
        weight_dropout: probability handed to the weight dropout module.
        bias: if true, allocate a bias parameter of length ``input_size``.
    """

    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding_l=None,
        weight_softmax=False,
        num_heads=1,
        weight_dropout=0.0,
        bias=False,
    ):
        super().__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.padding_l = padding_l
        self.num_heads = num_heads
        self.weight_softmax = weight_softmax
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
        )

        # Parameters are allocated uninitialised here and filled in by
        # reset_parameters() below.
        self.weight = nn.Parameter(torch.Tensor(num_heads, kernel_size))
        self.bias = nn.Parameter(torch.Tensor(input_size)) if bias else None

        self.reset_parameters()

    def upgrade_state_dict_named(self, state_dict, name):
        """Squeeze 3-D weights with a singleton middle dimension down to 2-D.

        Every entry whose key ends with ``<name>.weight`` (or just
        ``weight`` when ``name`` is empty) is checked and rewritten in place.
        """
        prefix = name + "." if name != "" else ""
        weight_suffix = prefix + "weight"
        for key, tensor in state_dict.items():
            if not key.endswith(weight_suffix):
                continue
            if tensor.dim() == 3 and tensor.size(1) == 1:
                state_dict[key] = tensor.squeeze(1)

    def reset_parameters(self):
        """Xavier-uniform initialise the weight and zero the bias, if any."""
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)

    def forward(self, x, incremental_state=None):
        """Apply the convolution to ``x``.

        ``x`` is a 3-D tensor; the incremental path unpacks its size as
        ``(T, B, C)`` and returns a tensor viewed to that same shape.
        """
        if incremental_state is None:
            # during training time, use CUDA kernel
            permuted = x.permute(1, 2, 0).contiguous()
            kernel = F.softmax(self.weight, -1) if self.weight_softmax else self.weight
            if self.weight_dropout_module.p:
                kernel = self.weight_dropout_module(kernel)
            result = lightconvFunction.apply(permuted, kernel, self.padding_l)
            return result.permute(2, 0, 1)

        # during inference time, incremental BMM is faster
        seq_len, bsz, channels = x.size()
        heads = self.num_heads
        per_head = channels // heads

        history = self._get_input_buffer(incremental_state)
        if history is None:
            history = x.new()
        window = torch.cat([history, x.unsqueeze(3)], dim=3)
        if self.kernel_size > 1:
            # Keep only the most recent kernel_size - 1 steps for next time.
            self._set_input_buffer(
                incremental_state, window[:, :, :, 1 - self.kernel_size :]
            )
        window = window.view(seq_len * bsz * heads, per_head, -1)

        kernel = self.weight
        if self.weight_softmax:
            # Softmax is computed in float32 and cast back to the weight dtype.
            kernel = F.softmax(kernel.float(), dim=1).type_as(kernel)
        # Use only as many trailing weight columns as there are buffered steps.
        kernel = kernel[:, -window.size(2) :]
        taps = kernel.size(1)
        kernel = kernel.view(1, heads, taps)
        kernel = kernel.expand(seq_len * bsz, heads, taps).contiguous()
        kernel = kernel.view(seq_len * bsz * heads, taps, 1)
        kernel = self.weight_dropout_module(kernel)

        output = torch.bmm(window, kernel)  # T*B*H x R x 1
        return output.view(seq_len, bsz, channels)

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reindex the cached input buffer along dim 1 according to ``new_order``."""
        cached = self._get_input_buffer(incremental_state)
        if cached is None:
            return
        self._set_input_buffer(incremental_state, cached.index_select(1, new_order))

    def _get_input_buffer(self, incremental_state):
        """Fetch this module's ``"input_buffer"`` entry from ``incremental_state``."""
        return utils.get_incremental_state(self, incremental_state, "input_buffer")

    def _set_input_buffer(self, incremental_state, new_buffer):
        """Store ``new_buffer`` as this module's ``"input_buffer"`` entry."""
        return utils.set_incremental_state(
            self, incremental_state, "input_buffer", new_buffer
        )

    def half(self):
        """Cast floating-point tensors to half precision; leave others as is."""

        def _to_half(t):
            return t.half() if t.is_floating_point() else t

        return self._apply(_to_half)