Upload matryoshka.py
Browse files- matryoshka.py +2 -1
matryoshka.py
CHANGED
@@ -1519,7 +1519,8 @@ class MatryoshkaTransformerBlock(nn.Module):
|
|
1519 |
|
1520 |
# attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous()
|
1521 |
attn_output_cond = self.proj_out(attn_output_cond)
|
1522 |
-
attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
|
|
|
1523 |
hidden_states = hidden_states + attn_output_cond
|
1524 |
|
1525 |
if self.ff is not None:
|
|
|
1519 |
|
1520 |
# attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous()
|
1521 |
attn_output_cond = self.proj_out(attn_output_cond)
|
1522 |
+
# attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
|
1523 |
+
attn_output_cond = attn_output_cond.reshape(batch_size, channels, *spatial_dims)
|
1524 |
hidden_states = hidden_states + attn_output_cond
|
1525 |
|
1526 |
if self.ff is not None:
|