tolgacangoz committed on
Commit
6c5aea6
·
verified ·
1 Parent(s): 40b66fb

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. matryoshka.py +5 -5
matryoshka.py CHANGED
@@ -1631,10 +1631,10 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
1631
  split_size = kv.shape[-1] // 2
1632
  key, value = torch.split(kv, split_size, dim=-1)
1633
 
1634
- if self_attention_output is None:
1635
- query = query.permute(0, 2, 1)
1636
- key = key.permute(0, 2, 1)
1637
- value = value.permute(0, 2, 1)
1638
 
1639
  if attn.norm_q is not None:
1640
  query = attn.norm_q(query)
@@ -1665,7 +1665,7 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
1665
  )
1666
 
1667
  hidden_states = hidden_states.to(query.dtype)
1668
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, height * width, channel)
1669
 
1670
  if self_attention_output is not None:
1671
  hidden_states = hidden_states + self_attention_output
 
1631
  split_size = kv.shape[-1] // 2
1632
  key, value = torch.split(kv, split_size, dim=-1)
1633
 
1634
+ # if self_attention_output is None:
1635
+ # query = query.permute(0, 2, 1)
1636
+ # key = key.permute(0, 2, 1)
1637
+ # value = value.permute(0, 2, 1)
1638
 
1639
  if attn.norm_q is not None:
1640
  query = attn.norm_q(query)
 
1665
  )
1666
 
1667
  hidden_states = hidden_states.to(query.dtype)
1668
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1669
 
1670
  if self_attention_output is not None:
1671
  hidden_states = hidden_states + self_attention_output