npvinHnivqn commited on
Commit
653b1fb
1 Parent(s): 18207f1

update bugs

Browse files
Files changed (1) hide show
  1. modeling_stablelm_epoch.py +16 -16
modeling_stablelm_epoch.py CHANGED
@@ -529,23 +529,23 @@ class DecoderLayer(nn.Module):
529
  )
530
  hidden_states = residual + hidden_states
531
 
532
- # # Cross Attention
533
- # residual = hidden_states
534
 
535
- # bsz, q_len, _ = hidden_states.size()
536
- # _, kv_len, _ = cross_states.size()
537
-
538
- # cross_attn_mask = torch.zeros((bsz, 1, kv_len, q_len), device=hidden_states.device)
539
- # hidden_states, cross_attn_weights, _ = self.cross_attn(
540
- # hidden_states=hidden_states,
541
- # cross_states=cross_states,
542
- # attention_mask=cross_attn_mask,
543
- # position_ids=position_ids,
544
- # past_key_value=past_key_value,
545
- # output_attentions=output_attentions,
546
- # use_cache=use_cache,
547
- # )
548
- # hidden_states = residual + hidden_states
549
 
550
  # Fully Connected
551
  residual = hidden_states
 
529
  )
530
  hidden_states = residual + hidden_states
531
 
532
+ # Cross Attention
533
+ residual = hidden_states
534
 
535
+ bsz, q_len, _ = hidden_states.size()
536
+ _, kv_len, _ = cross_states.size()
537
+
538
+ cross_attn_mask = torch.zeros((bsz, 1, kv_len, q_len), device=hidden_states.device)
539
+ hidden_states, cross_attn_weights, _ = self.cross_attn(
540
+ hidden_states=hidden_states,
541
+ cross_states=cross_states,
542
+ attention_mask=cross_attn_mask,
543
+ position_ids=position_ids,
544
+ past_key_value=past_key_value,
545
+ output_attentions=output_attentions,
546
+ use_cache=use_cache,
547
+ )
548
+ hidden_states = residual + hidden_states
549
 
550
  # Fully Connected
551
  residual = hidden_states