Upload matryoshka.py
Browse files- matryoshka.py +5 -5
matryoshka.py
CHANGED
@@ -1631,10 +1631,10 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
|
|
1631 |
split_size = kv.shape[-1] // 2
|
1632 |
key, value = torch.split(kv, split_size, dim=-1)
|
1633 |
|
1634 |
-
if self_attention_output is None:
|
1635 |
-
|
1636 |
-
key = key.permute(0, 2, 1)
|
1637 |
-
value = value.permute(0, 2, 1)
|
1638 |
|
1639 |
if attn.norm_q is not None:
|
1640 |
query = attn.norm_q(query)
|
@@ -1665,7 +1665,7 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
|
|
1665 |
)
|
1666 |
|
1667 |
hidden_states = hidden_states.to(query.dtype)
|
1668 |
-
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size,
|
1669 |
|
1670 |
if self_attention_output is not None:
|
1671 |
hidden_states = hidden_states + self_attention_output
|
|
|
1631 |
split_size = kv.shape[-1] // 2
|
1632 |
key, value = torch.split(kv, split_size, dim=-1)
|
1633 |
|
1634 |
+
# if self_attention_output is None:
|
1635 |
+
# query = query.permute(0, 2, 1)
|
1636 |
+
# key = key.permute(0, 2, 1)
|
1637 |
+
# value = value.permute(0, 2, 1)
|
1638 |
|
1639 |
if attn.norm_q is not None:
|
1640 |
query = attn.norm_q(query)
|
|
|
1665 |
)
|
1666 |
|
1667 |
hidden_states = hidden_states.to(query.dtype)
|
1668 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1669 |
|
1670 |
if self_attention_output is not None:
|
1671 |
hidden_states = hidden_states + self_attention_output
|