Commit
·
8ea9075
1
Parent(s):
d7176c6
Upload HyenaDNAForCausalLM
Browse files- modeling_hyena.py +2 -4
modeling_hyena.py
CHANGED
@@ -60,10 +60,8 @@ class HyenaPositionalEmbedding(nn.Module):
|
|
60 |
w = 2 * math.pi * t_rescaled / self.seq_len # 1, L, 1
|
61 |
|
62 |
f = torch.linspace(1e-4, bands - 1, bands)[None, None]
|
63 |
-
|
64 |
-
|
65 |
-
z = torch.exp(-1j * f * w)
|
66 |
-
z = torch.cat([t, z.real, z.imag], dim=-1)
|
67 |
# TODO Set z's LR to lr_pos_emb
|
68 |
self.z = nn.Parameter(z, requires_grad=True)
|
69 |
self.register_buffer("t", t)
|
|
|
60 |
w = 2 * math.pi * t_rescaled / self.seq_len # 1, L, 1
|
61 |
|
62 |
f = torch.linspace(1e-4, bands - 1, bands)[None, None]
|
63 |
+
|
64 |
+
z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)
|
|
|
|
|
65 |
# TODO Set z's LR to lr_pos_emb
|
66 |
self.z = nn.Parameter(z, requires_grad=True)
|
67 |
self.register_buffer("t", t)
|