Crystalcareai commited on
Commit
dbe3df7
·
verified ·
1 Parent(s): c7c5a3d

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +3 -2
modeling_gemmoe.py CHANGED
@@ -243,8 +243,9 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
243
  Returns:
244
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
245
  """
246
- cos = cos.unsqueeze(unsqueeze_dim)
247
- sin = sin.unsqueeze(unsqueeze_dim)
 
248
  q_embed = (q * cos) + (rotate_half(q) * sin)
249
  k_embed = (k * cos) + (rotate_half(k) * sin)
250
  return q_embed, k_embed
 
243
  Returns:
244
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
245
  """
246
+ seq_len, dim = q.shape[-2], q.shape[-1]
247
+ cos = cos[:seq_len].view(1, 1, seq_len, dim)
248
+ sin = sin[:seq_len].view(1, 1, seq_len, dim)
249
  q_embed = (q * cos) + (rotate_half(q) * sin)
250
  k_embed = (k * cos) + (rotate_half(k) * sin)
251
  return q_embed, k_embed