npvinHnivqn committed
Commit 18207f1
Parent(s): 224053d
bugs

Files changed: modeling_stablelm_epoch.py (+16 -16)

modeling_stablelm_epoch.py CHANGED
@@ -529,23 +529,23 @@ class DecoderLayer(nn.Module):
         )
         hidden_states = residual + hidden_states

-        # Cross Attention
-        residual = hidden_states
+        # # Cross Attention
+        # residual = hidden_states

-        bsz, q_len, _ = hidden_states.size()
-        _, kv_len, _ = cross_states.size()
-
-        cross_attn_mask = torch.zeros((bsz, 1, kv_len, q_len), device=hidden_states.device)
-        hidden_states, cross_attn_weights, _ = self.cross_attn(
-            hidden_states=hidden_states,
-            cross_states=cross_states,
-            attention_mask=cross_attn_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-        )
-        hidden_states = residual
+        # bsz, q_len, _ = hidden_states.size()
+        # _, kv_len, _ = cross_states.size()
+
+        # cross_attn_mask = torch.zeros((bsz, 1, kv_len, q_len), device=hidden_states.device)
+        # hidden_states, cross_attn_weights, _ = self.cross_attn(
+        #     hidden_states=hidden_states,
+        #     cross_states=cross_states,
+        #     attention_mask=cross_attn_mask,
+        #     position_ids=position_ids,
+        #     past_key_value=past_key_value,
+        #     output_attentions=output_attentions,
+        #     use_cache=use_cache,
+        # )
+        # hidden_states = residual + hidden_states

         # Fully Connected
         residual = hidden_states
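For context on the block being commented out, the sketch below shows the cross-attention residual pattern in isolation. It is an illustration only: torch.nn.MultiheadAttention stands in for the model's self.cross_attn module (an assumption, not the repo's implementation), and the shapes are arbitrary. It also shows why the old line 548 was buggy: assigning hidden_states = residual discards the attention output, whereas the corrected (now commented) form keeps the residual sum.

# Minimal sketch of a cross-attention residual connection.
# Assumption: torch.nn.MultiheadAttention is used here as a stand-in for
# the repo's CrossAttention module; it is not the actual implementation.
import torch
import torch.nn as nn

bsz, q_len, kv_len, hidden_size = 2, 8, 12, 64
cross_attn = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)

hidden_states = torch.randn(bsz, q_len, hidden_size)   # decoder states (queries)
cross_states = torch.randn(bsz, kv_len, hidden_size)   # external states (keys/values)

residual = hidden_states
attn_out, _ = cross_attn(hidden_states, cross_states, cross_states)

# The pre-commit line `hidden_states = residual` silently dropped attn_out;
# keeping the residual sum applies the attention update:
hidden_states = residual + attn_out
print(hidden_states.shape)  # torch.Size([2, 8, 64])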