added option to skip mid block

#5
by NatanBagrov - opened
Files changed (1) hide show
  1. pipeline.py +28 -12
pipeline.py CHANGED
@@ -52,6 +52,19 @@ def custom_sort_order(obj):
52
  return {ResnetBlock2D: 0, Transformer2DModel: 1, FlexibleTransformer2DModel: 1}.get(obj.__class__)
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
56
  configurations = FlexibleUnetConfigurations
57
 
@@ -105,18 +118,21 @@ class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
105
  mid_block_add_upsample = self.configurations.get("add_upsample_mid_block")
106
  mid_num_attentions = self.configurations.get("mid_num_attentions")
107
  mid_num_resnets = self.configurations.get("mid_num_resnets")
108
-
109
- self.mid_block = FlexibleUNetMidBlock2DCrossAttn(in_channels=down_blocks_out_channels[-1],
110
- temb_channels=temb_dim,
111
- resnet_act_fn=resnet_act_fn,
112
- resnet_eps=resnet_eps,
113
- cross_attention_dim=cross_attention_dim,
114
- num_attention_heads=num_attention_heads,
115
- num_resnets=mid_num_resnets,
116
- num_attentions=mid_num_attentions,
117
- mix_block_in_forward=mix_block_in_forward,
118
- add_upsample=mid_block_add_upsample
119
- )
 
 
 
120
 
121
  ###############
122
  # Up blocks #
 
52
  return {ResnetBlock2D: 0, Transformer2DModel: 1, FlexibleTransformer2DModel: 1}.get(obj.__class__)
53
 
54
 
55
+ class FlexibleIdentityBlock(nn.Module):
56
+ def forward(
57
+ self,
58
+ hidden_states: torch.FloatTensor,
59
+ temb: Optional[torch.FloatTensor] = None,
60
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
61
+ attention_mask: Optional[torch.FloatTensor] = None,
62
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
63
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
64
+ ):
65
+ return hidden_states
66
+
67
+
68
  class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
69
  configurations = FlexibleUnetConfigurations
70
 
 
118
  mid_block_add_upsample = self.configurations.get("add_upsample_mid_block")
119
  mid_num_attentions = self.configurations.get("mid_num_attentions")
120
  mid_num_resnets = self.configurations.get("mid_num_resnets")
121
+
122
+ if mid_num_resnets == mid_num_attentions == 0:
123
+ self.mid_block = FlexibleIdentityBlock()
124
+ else:
125
+ self.mid_block = FlexibleUNetMidBlock2DCrossAttn(in_channels=down_blocks_out_channels[-1],
126
+ temb_channels=temb_dim,
127
+ resnet_act_fn=resnet_act_fn,
128
+ resnet_eps=resnet_eps,
129
+ cross_attention_dim=cross_attention_dim,
130
+ num_attention_heads=num_attention_heads,
131
+ num_resnets=mid_num_resnets,
132
+ num_attentions=mid_num_attentions,
133
+ mix_block_in_forward=mix_block_in_forward,
134
+ add_upsample=mid_block_add_upsample
135
+ )
136
 
137
  ###############
138
  # Up blocks #