Spaces:
Running
on
T4
Running
on
T4
"""
Various positional encodings for the transformer.
"""
import math | |
import torch | |
from torch import nn | |
def PE1d_sincos(seq_length, dim, base=10000.0):
    """
    Build a fixed 1-D sinusoidal positional-encoding table.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, where ``w_k = base ** (-2k / dim)`` for
    ``k = 0 .. dim // 2 - 1`` and ``pos = 0 .. seq_length - 1``.

    :param seq_length: number of positions (rows) to encode.
    :param dim: feature size of each encoding; must be even so that every
        frequency gets one sin column and one cos column.
    :param base: base of the geometric frequency progression. The default
        10000.0 is the value previously hard-coded here.
    :return: float32 tensor of shape ``(seq_length, 1, dim)``, i.e. laid out
        sequence-first with a singleton axis in the middle.
    :raises ValueError: if ``dim`` is odd.
    """
    if dim % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(dim))
    # Column vector of positions, shape (seq_length, 1), already float.
    position = torch.arange(seq_length, dtype=torch.float).unsqueeze(1)
    # One frequency per sin/cos pair: exp(2k * -ln(base) / dim) == base^(-2k/dim).
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float)
                         * -(math.log(base) / dim))
    # Shared angle grid, shape (seq_length, dim // 2).
    angles = position * div_term
    pe = torch.zeros(seq_length, dim)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe.unsqueeze(1)
class PositionEmbedding(nn.Module):
    """
    Absolute 1-D positional embedding added to a batch-first input.

    The table is initialised with the fixed sinusoidal encoding produced by
    ``PE1d_sincos``. It is frozen by default (``grad=False``); pass
    ``grad=True`` to make it a learned, trainable parameter.
    """

    def __init__(self, seq_length, dim, dropout, grad=False):
        """
        :param seq_length: maximum sequence length the table supports.
        :param dim: feature dimension; must be even (``PE1d_sincos`` raises
            ValueError otherwise).
        :param dropout: dropout probability applied after adding the embedding.
        :param grad: if True the embedding table receives gradients;
            otherwise it stays at its sinusoidal initialisation.
        """
        super().__init__()
        # Stored sequence-first as (seq_length, 1, dim).
        self.embed = nn.Parameter(data=PE1d_sincos(seq_length, dim), requires_grad=grad)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """
        Add positional embeddings to ``x`` and apply dropout.

        :param x: tensor of shape (batch, seq_len, dim).
        :return: tensor with the same shape as ``x``.
        :raises ValueError: if ``seq_len`` exceeds the embedding table length.
        """
        seq_len = x.shape[1]
        max_len = self.embed.shape[0]
        if seq_len > max_len:
            # Slicing would silently truncate the table; fail with a clear message.
            raise ValueError(
                "Input sequence length {:d} exceeds the positional embedding "
                "length {:d}".format(seq_len, max_len))
        # (seq_len, 1, dim) -> (1, seq_len, dim), then expanded over the batch
        # axis; expand_as keeps shape mismatches a hard error.
        pos = self.embed[:seq_len].transpose(0, 1).expand_as(x)
        return self.dropout(x + pos)