tolgacangoz commited on
Commit
a8e0acf
·
verified ·
1 Parent(s): 727ac2d

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. matryoshka.py +2 -1
matryoshka.py CHANGED
@@ -1519,7 +1519,8 @@ class MatryoshkaTransformerBlock(nn.Module):
1519
 
1520
  # attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous()
1521
  attn_output_cond = self.proj_out(attn_output_cond)
1522
- attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
 
1523
  hidden_states = hidden_states + attn_output_cond
1524
 
1525
  if self.ff is not None:
 
1519
 
1520
  # attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous()
1521
  attn_output_cond = self.proj_out(attn_output_cond)
1522
+ # attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
1523
+ attn_output_cond = attn_output_cond.reshape(batch_size, channels, *spatial_dims)
1524
  hidden_states = hidden_states + attn_output_cond
1525
 
1526
  if self.ff is not None: