tgn-playground / model /time_encoding.py
ashu316's picture
Upload 14 files
41aae2b verified
raw
history blame contribute delete
775 Bytes
import torch
import numpy as np
class TimeEncode(torch.nn.Module):
    """Time encoding proposed by TGAT.

    Maps each scalar time value ``t`` to a ``dimension``-dimensional vector
    ``cos(w * t + b)``.

    The frequencies ``w`` are initialised to
    ``1 / 10 ** linspace(0, 9, dimension)``, i.e. geometrically spaced from
    1 down to 1e-9. The phases ``b`` are initialised to zero. Both are stored
    as ``torch.nn.Parameter`` objects on the internal ``torch.nn.Linear``
    layer ``self.w``.

    Args:
        dimension: Size of the encoding produced for each time value.
    """

    def __init__(self, dimension: int):
        super().__init__()
        self.dimension = dimension
        self.w = torch.nn.Linear(1, dimension)
        # Replace the default random init with deterministic frequencies.
        # nn.Linear(1, dimension) stores its weight with shape (dimension, 1).
        # An explicit 1 (rather than -1) also keeps dimension == 0 reshapeable.
        frequencies = 1 / 10 ** np.linspace(0, 9, dimension)
        self.w.weight = torch.nn.Parameter(
            torch.from_numpy(frequencies).float().reshape(dimension, 1)
        )
        self.w.bias = torch.nn.Parameter(torch.zeros(dimension).float())

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Encode a tensor of time values.

        Args:
            t: Time values of any shape. The original code documents
                ``[batch_size, seq_len]`` as the typical shape.

        Returns:
            Tensor of shape ``[*t.shape, dimension]`` holding
            ``cos(w * t + b)``.
        """
        # Append a trailing feature axis so the Linear layer maps each scalar
        # time to `dimension` values: [...] -> [..., 1]. Using -1 instead of
        # a hard-coded 2 accepts any number of leading dims; for 2-D input
        # the two are identical.
        # Cast to the parameter dtype so integer or float64 timestamps do not
        # raise a dtype-mismatch error inside the matmul. This is a no-op for
        # input that already matches.
        t = t.unsqueeze(dim=-1).to(self.w.weight.dtype)
        return torch.cos(self.w(t))