Update modeling_baichuan.py
modeling_baichuan.py CHANGED (+2 -2)
@@ -114,8 +114,8 @@ class RotaryEmbedding(torch.nn.Module):
         t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
         freqs = torch.outer(t, self.inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(
-        self.register_buffer("sin_cached", emb.sin().to(
+        self.register_buffer("cos_cached", emb.cos().to(self.inv_freq.device)[None, None, :, :], persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(self.inv_freq.device)[None, None, :, :], persistent=False)

     def forward(self, x, seq_len):
         # x: [bs, num_attention_heads, seq_len, head_size]