Crystalcareai
committed on
Update modeling_gemmoe.py
Browse files- modeling_gemmoe.py +3 -2
modeling_gemmoe.py
CHANGED
@@ -243,8 +243,9 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
243 |
Returns:
|
244 |
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
245 |
"""
|
246 |
-
|
247 |
-
|
|
|
248 |
q_embed = (q * cos) + (rotate_half(q) * sin)
|
249 |
k_embed = (k * cos) + (rotate_half(k) * sin)
|
250 |
return q_embed, k_embed
|
|
|
243 |
Returns:
|
244 |
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
245 |
"""
|
246 |
+
seq_len, dim = q.shape[-2], q.shape[-1]
|
247 |
+
cos = cos[:seq_len].view(1, 1, seq_len, dim)
|
248 |
+
sin = sin[:seq_len].view(1, 1, seq_len, dim)
|
249 |
q_embed = (q * cos) + (rotate_half(q) * sin)
|
250 |
k_embed = (k * cos) + (rotate_half(k) * sin)
|
251 |
return q_embed, k_embed
|