# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/learned_positional_embedding.py
    1. Add clamping if the input length exceeds the max-source-tokens
"""

from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor


class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.

    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.

    Position ids that this module computes itself are clamped to
    ``padding_idx + max_positions``, so inputs longer than the table reuse
    the embedding at that index instead of indexing out of range.
    """

    def __init__(
        self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int]
    ) -> None:
        """Build the embedding table and derive ``max_positions``.

        Args:
            num_embeddings: Total number of rows in the embedding table.
            embedding_dim: Size of each embedding vector.
            padding_idx: Index reserved for padding, or ``None`` when the
                caller will always pass explicit ``positions`` to ``forward``.
        """
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.onnx_trace = False
        if self.padding_idx is not None:
            # Rows 0..padding_idx are not used as real positions, so only the
            # rows after padding_idx count towards the usable maximum.
            self.max_positions = self.num_embeddings - self.padding_idx - 1
        else:
            self.max_positions = self.num_embeddings

    def forward(
        self,
        input: Tensor,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        positions: Optional[Tensor] = None,
    ) -> Tensor:
        """Input is expected to be of size [bsz x seqlen].

        Args:
            input: Token ids. Only used to derive position ids when
                ``positions`` is not given.
            incremental_state: When not ``None``, a single position id derived
                from ``input.size(1)`` is used for every token (single-step
                decoding). Its contents are not read here.
            positions: Pre-computed position ids. Only allowed when
                ``padding_idx`` is ``None``; they are looked up as-is, without
                clamping.

        Returns:
            The embeddings looked up for the position ids.

        Raises:
            AssertionError: If ``positions`` is given while ``padding_idx`` is
                set, or if ``positions`` is omitted while ``padding_idx`` is
                ``None``.
        """
        assert (positions is None) or (
            self.padding_idx is None
        ), "If positions is pre-computed then padding_idx should not be set."

        if positions is None:
            padding_idx = self.padding_idx
            # Deriving positions needs padding_idx as the offset. Without this
            # check the arithmetic below fails with an opaque
            # "NoneType + int" TypeError.
            assert (
                padding_idx is not None
            ), "padding_idx must be set when positions are not pre-computed."

            if incremental_state is not None:
                # positions is the same for every token when decoding a single step
                # Without the int() cast, it doesn't work in some cases when exporting to ONNX
                positions = torch.zeros(
                    (1, 1), device=input.device, dtype=input.dtype
                ).fill_(int(padding_idx + input.size(1)))
            else:
                # NOTE(review): presumably numbers the non-padding tokens
                # starting at padding_idx + 1 -- confirm against fairseq.
                positions = utils.make_positions(
                    input, padding_idx, onnx_trace=self.onnx_trace
                )
            # Clamp over-long inputs to the largest supported position id
            # rather than letting the lookup index out of range.
            positions = torch.clamp(positions, max=padding_idx + self.max_positions)

        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )