katuni4ka committed
Commit 0d670f1 · verified · 1 Parent(s): 2209f5d

Update modeling_baichuan.py

Files changed (1)
  1. modeling_baichuan.py +2 -2
modeling_baichuan.py CHANGED
@@ -114,8 +114,8 @@ class RotaryEmbedding(torch.nn.Module):
         t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
         freqs = torch.outer(t, self.inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype)[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype)[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(self.inv_freq.device)[None, None, :, :], persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(self.inv_freq.device)[None, None, :, :], persistent=False)
 
     def forward(self, x, seq_len):
         # x: [bs, num_attention_heads, seq_len, head_size]
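
For context, a minimal, self-contained sketch of a LLaMA-style RotaryEmbedding class around the patched lines is shown below. Only the two register_buffer calls correspond to the patch above; the __init__ signature, the inv_freq construction, and the forward return are assumptions based on the common transformers-style pattern, not copied from Baichuan's actual file.

import torch


class RotaryEmbedding(torch.nn.Module):
    # Assumed LLaMA-style constructor; dim/base/max_position_embeddings are illustrative.
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # Inverse frequencies for each pair of embedding dimensions.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Precompute the cos/sin caches up to max_position_embeddings positions.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # As in the patched lines: keep the caches on the same device as inv_freq
        # rather than casting them with .to(dtype).
        self.register_buffer("cos_cached", emb.cos().to(self.inv_freq.device)[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(self.inv_freq.device)[None, None, :, :], persistent=False)

    def forward(self, x, seq_len):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Assumed return: slice the caches to seq_len and cast to the input dtype.
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )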