Update modeling_internlm2.py
Browse files- modeling_internlm2.py +2 -2
modeling_internlm2.py
CHANGED
@@ -199,8 +199,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|
199 |
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
200 |
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
201 |
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
202 |
-
cos = cos.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
|
203 |
-
sin = sin.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
|
204 |
if q.size(2) == 1:
|
205 |
q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
|
206 |
else:
|
|
|
199 |
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
200 |
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
201 |
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
202 |
+
cos = cos.unsqueeze(0).unsqueeze(0).expand(-1, -1, -1, -1) #(len(position_ids), -1, -1, -1)
|
203 |
+
sin = sin.unsqueeze(0).unsqueeze(0).expand(-1, -1, -1, -1) #(len(position_ids), -1, -1, -1)
|
204 |
if q.size(2) == 1:
|
205 |
q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
|
206 |
else:
|