ascust committed
Commit 8fe92ea · verified · 1 parent: fc26b91

Upload folder using huggingface_hub
ummdit_ds1_small_singlenorm_v5.py ADDED
@@ -0,0 +1,681 @@
# apply pos emb to downsampled and upsampled feats
# add bias and scale to blockwise AdaIN params
# subattn to subsampled feat
# block list [4, 16, 4]

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.transformers import SD3Transformer2DModel
from diffusers.configuration_utils import register_to_config
# from diffusers.models.attention import JointTransformerBlock
from diffusers.utils import is_torch_version, logging
from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.normalization import AdaLayerNormSingle

from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
from diffusers.models.normalization import SD35AdaLayerNormZeroX
from diffusers.models.attention import FeedForward, _chunked_feed_forward


from einops import rearrange

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def cropped_pos_embed(pos_embed, height, width, patch_size=1, pos_embed_max_size=96):
    """Crops positional embeddings for SD3 compatibility."""
    if pos_embed_max_size is None:
        raise ValueError("`pos_embed_max_size` must be set for cropping.")

    height = height // patch_size
    width = width // patch_size
    if height > pos_embed_max_size:
        raise ValueError(
            f"Height ({height}) cannot be greater than `pos_embed_max_size`: {pos_embed_max_size}."
        )
    if width > pos_embed_max_size:
        raise ValueError(
            f"Width ({width}) cannot be greater than `pos_embed_max_size`: {pos_embed_max_size}."
        )

    top = (pos_embed_max_size - height) // 2
    left = (pos_embed_max_size - width) // 2
    spatial_pos_embed = pos_embed.reshape(1, pos_embed_max_size, pos_embed_max_size, -1)
    spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
    # spatial_pos_embed = torch.permute(spatial_pos_embed, [0, 3, 1, 2])
    # spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
    return spatial_pos_embed

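# Rough usage sketch for `cropped_pos_embed` (illustrative only, not part of the model;
# the concrete sizes are assumptions and mirror how `MMDiTTransformer2DModel.__init__`
# calls it with a 96x96 sincos grid in a recent diffusers version):
#
#     pe = get_2d_sincos_pos_embed(768, 96, base_size=16, output_type='pt')          # [96*96, 768]
#     pe = cropped_pos_embed(pe, height=32, width=32, patch_size=1, pos_embed_max_size=96)
#     # -> center crop of the grid with shape [1, 32, 32, 768];
#     #    callers flatten it afterwards to [1, 32*32, 768] before adding it to token features.
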
class JointTransformerBlockSingleNorm(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://huggingface.co/papers/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
            processing of `context` conditions.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        context_pre_only: bool = False,
        qk_norm: Optional[str] = None,
        use_dual_attention: bool = False,
        subsample_ratio=1,
        subsample_seq_len=1,
    ):
        super().__init__()

        self.use_dual_attention = use_dual_attention
        self.context_pre_only = context_pre_only
        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_single"

        if use_dual_attention:
            self.norm1 = SD35AdaLayerNormZeroX(dim)
        else:
            # self.norm1 = AdaLayerNormZero(dim)
            self.norm1 = nn.LayerNorm(dim)

        assert subsample_ratio >= 1 and subsample_seq_len >= 1
        self.subsample_ratio = subsample_ratio
        self.subsample_seq_len = subsample_seq_len

        print(self.subsample_ratio, self.subsample_seq_len)

        # if context_norm_type == "ada_norm_continous":
        #     # self.norm1_context = AdaLayerNormContinuous(
        #     #     dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
        #     # )
        # elif context_norm_type == "ada_norm_single":
        #     # self.norm1_context = AdaLayerNormZero(dim)
        #     self.norm1_context = nn.LayerNorm(dim)
        # else:
        #     raise ValueError(
        #         f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
        #     )
        self.norm1_context = nn.LayerNorm(dim)

        if hasattr(F, "scaled_dot_product_attention"):
            processor = JointAttnProcessor2_0()
        else:
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )

        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=context_pre_only,
            bias=True,
            processor=processor,
            qk_norm=qk_norm,
            eps=1e-6,
        )

        if use_dual_attention:
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=None,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                bias=True,
                processor=processor,
                qk_norm=qk_norm,
                eps=1e-6,
            )
        else:
            self.attn2 = None

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        if not context_pre_only:
            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        else:
            self.norm2_context = None
            self.ff_context = None

        self.scale_shift_bias = nn.Parameter(torch.randn(6, dim) / dim**0.5)
        self.scale_shift_scale = nn.Parameter(torch.randn(6, dim) / dim**0.5)

        if not context_pre_only:
            self.scale_shift_bias_c = nn.Parameter(torch.randn(6, dim) / dim**0.5)
            self.scale_shift_scale_c = nn.Parameter(torch.randn(6, dim) / dim**0.5)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        embedded_timestep: torch.FloatTensor = None,
    ):
        joint_attention_kwargs = joint_attention_kwargs or {}
        if self.use_dual_attention:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                hidden_states, emb=temb
            )
        else:
            # norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
            batch_size = hidden_states.shape[0]
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_bias[None] + temb.reshape(batch_size, 6, -1) * (1 + self.scale_shift_scale[None])
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
            # norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, embedded_timestep)
        else:
            # norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            #     encoder_hidden_states, emb=temb
            # )
            batch_size = hidden_states.shape[0]
            c_shift_msa, c_scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
                self.scale_shift_bias_c[None] + temb.reshape(batch_size, 6, -1) * (1 + self.scale_shift_scale_c)
            ).chunk(6, dim=1)
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa) + c_shift_msa

        if self.subsample_ratio > 1:
            norm_hidden_states = rearrange(norm_hidden_states,
                                           'b (l s n) c -> (b s) (l n) c',
                                           n=self.subsample_seq_len, s=self.subsample_ratio)
            norm_encoder_hidden_states = rearrange(norm_encoder_hidden_states,
                                                   'b (l s n) c -> (b s) (l n) c',
                                                   n=self.subsample_seq_len, s=self.subsample_ratio)

        # Attention.

        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            **joint_attention_kwargs,
        )
        if self.subsample_ratio > 1:
            attn_output = rearrange(attn_output,
                                    '(b s) (l n) c -> b (l s n) c',
                                    n=self.subsample_seq_len, s=self.subsample_ratio)
            context_attn_output = rearrange(context_attn_output,
                                            '(b s) (l n) c -> b (l s n) c',
                                            n=self.subsample_seq_len, s=self.subsample_ratio)
            # attn_output = norm_hidden_states
            # context_attn_output = norm_encoder_hidden_states

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa * attn_output
        hidden_states = hidden_states + attn_output

        if self.use_dual_attention:
            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
            attn_output2 = gate_msa2 * attn_output2
            hidden_states = hidden_states + attn_output2

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp * ff_output

        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states`.
        if self.context_pre_only:
            encoder_hidden_states = None
        else:
            context_attn_output = c_gate_msa * context_attn_output
            # print(context_attn_output.shape, encoder_hidden_states.shape)
            encoder_hidden_states = encoder_hidden_states + context_attn_output

            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
            if self._chunk_size is not None:
                # "feed_forward_chunk_size" can be used to save memory
                context_ff_output = _chunked_feed_forward(
                    self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
                )
            else:
                context_ff_output = self.ff_context(norm_encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output

        return encoder_hidden_states, hidden_states

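# Illustrative sketch (not part of the model) of the sub-sampled attention used above, assuming
# subsample_ratio=4 and subsample_seq_len=1: the sequence is split into 4 interleaved groups
# folded into the batch axis, attention runs within each group, and the inverse rearrange
# restores the original layout losslessly:
#
#     x = torch.randn(2, 1024, 768)                                           # (b, l*s*n, c), s=4, n=1
#     x_sub = rearrange(x, 'b (l s n) c -> (b s) (l n) c', s=4, n=1)          # (8, 256, 768)
#     # ... attention over the shorter sequences ...
#     x_back = rearrange(x_sub, '(b s) (l n) c -> b (l s n) c', s=4, n=1)     # (2, 1024, 768)
#     assert torch.equal(x, x_back)
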
# class TimestepEmbeddings(nn.Module):
#     def __init__(self, embedding_dim):
#         super().__init__()

#         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
#         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

#     def forward(self, timestep, dtype):
#         timesteps_proj = self.time_proj(timestep)
#         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype))  # (N, D)

#         return timesteps_emb


class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()

        self.body = nn.Sequential(
            nn.PixelUnshuffle(2),
            nn.Conv2d(n_feat*4, n_feat, kernel_size=1, stride=1, padding=0, bias=True),
            torch.nn.GELU('tanh'),
            nn.Conv2d(n_feat, n_feat, kernel_size=1, stride=1, padding=0, bias=True))

    def forward(self, x):
        return self.body(x)


class Upsample(nn.Module):
    def __init__(self, n_feat):
        super(Upsample, self).__init__()

        self.body = nn.Sequential(nn.PixelShuffle(2),
                                  nn.Conv2d(n_feat//4, n_feat, kernel_size=1, stride=1, padding=0, bias=True),
                                  torch.nn.GELU('tanh'),
                                  nn.Conv2d(n_feat, n_feat, kernel_size=1, stride=1, padding=0, bias=True))

    def forward(self, x):
        return self.body(x)

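# Shape sketch for the two resamplers above (illustrative only; the sizes are assumptions):
#
#     down = Downsample(768)
#     up = Upsample(768)
#     feat = torch.randn(1, 768, 32, 32)
#     low = down(feat)    # PixelUnshuffle(2): (1, 768*4, 16, 16) -> 1x1 convs -> (1, 768, 16, 16)
#     high = up(low)      # PixelShuffle(2):   (1, 768//4, 32, 32) -> 1x1 convs -> (1, 768, 32, 32)
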
class MMDiTTransformer2DModel(SD3Transformer2DModel):
    """
    The Transformer model introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 32): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
        out_channels (`int`, defaults to 16): Number of output channels.

    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 24,
        attention_head_dim: int = 32,
        num_attention_heads: int = 24,
        caption_channels: int = 4096,
        caption_projection_dim: int = 768,
        out_channels: int = 16,
        interpolation_scale: int = None,
        pos_embed_max_size: int = 96,
        dual_attention_layers: Tuple[
            int, ...
        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
        qk_norm: Optional[str] = None,
        repa_depth=-1,
        projector_dim=2048,
        z_dims=[768]
    ):
        super().__init__(
            sample_size=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            num_layers=num_layers,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            caption_projection_dim=caption_projection_dim,
            out_channels=out_channels,
            pos_embed_max_size=pos_embed_max_size,
            dual_attention_layers=dual_attention_layers,
            qk_norm=qk_norm,
        )

        self.time_text_embed = None

        self.patch_mixer_depth = None  # initially no masking applied
        self.mask_ratio = 0

        # self.block_split_stage = [2, 20, 2]
        self.block_split_stage = [4, 16, 4]
        # self.block_split_stage = [12, 1, 12]

        default_out_channels = in_channels
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        # Record repa_depth unconditionally so forward() can check it even when REPA is disabled.
        self.repa_depth = repa_depth
        if repa_depth != -1:
            from core.models.projector import build_projector
            self.projectors = nn.ModuleList([
                build_projector(self.inner_dim, projector_dim, z_dim) for z_dim in z_dims
            ])

            assert repa_depth >= 0 and repa_depth < num_layers

        interpolation_scale = (
            self.config.interpolation_scale
            if self.config.interpolation_scale is not None
            else max(self.config.sample_size // 16, 1)
        )

        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            interpolation_scale=interpolation_scale,
            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
        )

        pos_embed_lv0 = get_2d_sincos_pos_embed(
            self.inner_dim, pos_embed_max_size, base_size=self.config.sample_size // self.config.patch_size,
            interpolation_scale=interpolation_scale, output_type='pt'
        )  # [grid_size**2, embed_dim]

        pos_embed_lv0 = cropped_pos_embed(pos_embed_lv0,
                                          self.config.sample_size,
                                          self.config.sample_size,
                                          patch_size=1, pos_embed_max_size=pos_embed_max_size)

        pos_embed_lv1 = pos_embed_lv0.clone()[:, ::2, ::2, :]

        pos_embed_lv0 = pos_embed_lv0.reshape(1, -1, pos_embed_lv0.shape[-1])
        pos_embed_lv1 = pos_embed_lv1.reshape(1, -1, pos_embed_lv1.shape[-1])

        self.register_buffer("pos_embed_lv0", pos_embed_lv0.float(), persistent=False)
        self.register_buffer("pos_embed_lv1", pos_embed_lv1.float(), persistent=False)

        # self.time_text_embed = TimestepEmbeddings(embedding_dim=self.inner_dim)
        self.context_embedder = nn.Linear(self.config.caption_channels, self.config.caption_projection_dim)

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=False
        )

        self.transformer_blocks = None

        subsample_ratio_list = [1, 4, 4]
        seq_len_list = [1, 1, 4]
        cur_ind = 0

        self.block_groups = nn.ModuleList()
        for grp_ids, cur_bks in enumerate(self.block_split_stage):
            # cur_subample_ratio = 1
            # seq_len_list = [1]
            # if grp_ids == 1:
            #     cur_subample_ratio = 4
            #     seq_len_list = [1, 4]
            cur_group = []
            for i in range(cur_bks):
                cur_group.append(JointTransformerBlockSingleNorm(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    context_pre_only=(grp_ids == len(self.block_split_stage) - 1)
                        and (i == cur_bks - 1),
                    qk_norm=qk_norm,
                    use_dual_attention=False,
                    subsample_ratio=subsample_ratio_list[cur_ind % len(subsample_ratio_list)],
                    subsample_seq_len=seq_len_list[cur_ind % len(seq_len_list)],
                ))
                cur_ind += 1

            cur_group = nn.ModuleList(cur_group)

            # cur_group = nn.ModuleList(
            #     [
            #         JointTransformerBlockSingleNorm(
            #             dim=self.inner_dim,
            #             num_attention_heads=self.config.num_attention_heads,
            #             attention_head_dim=self.config.attention_head_dim,
            #             context_pre_only=(grp_ids==len(self.block_split_stage)-1) \
            #                 and (i == cur_bks - 1),
            #             qk_norm=qk_norm,
            #             use_dual_attention=False,
            #             subsample_ratio=cur_subample_ratio,
            #             subsample_seq_len=seq_len_list[i%len(seq_len_list)],
            #         )
            #         for i in range(cur_bks)
            #     ])
            self.block_groups.append(cur_group)

        ds_num = int(len(self.block_split_stage) // 2)
        self.downsamplers = nn.ModuleList()
        for _ in range(ds_num):
            self.downsamplers.append(Downsample(self.inner_dim))
        self.upsamplers = nn.ModuleList()
        for _ in range(ds_num):
            self.upsamplers.append(Upsample(self.inner_dim))
        self.mergers = nn.ModuleList()
        for _ in range(ds_num):
            # self.mergers.append(nn.Linear(self.inner_dim*2, self.inner_dim))
            self.mergers.append(nn.Sequential(
                nn.Linear(self.inner_dim*2, self.inner_dim),
                torch.nn.GELU('tanh'),
                nn.Linear(self.inner_dim, self.inner_dim)))

        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

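    # Descriptive sketch (comment only) of how the constructor above wires the network,
    # assuming the default block_split_stage = [4, 16, 4]:
    #   - stage 0: 4 blocks on the full token grid, then Downsample + pos_embed_lv1
    #   - stage 1: 16 blocks on the 2x-downsampled grid
    #   - stage 2: Upsample, concat with the stage-0 skip, merger, + pos_embed_lv0, then 4 blocks
    # Across the whole depth, blocks cycle through the attention patterns
    # (subsample_ratio, subsample_seq_len) = (1, 1), (4, 1), (4, 4), since cur_ind indexes the
    # lists modulo 3; only the very last block is built with context_pre_only=True.
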
    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        timestep: torch.LongTensor = None,
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        skip_layers: Optional[List[int]] = None,
        **kwargs,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`SD3Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            skip_layers (`list` of `int`, *optional*):
                A list of layer indices to skip during the forward pass.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """

        height, width = hidden_states.shape[-2:]

        cur_height = height // self.config.patch_size
        cur_width = width // self.config.patch_size

        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
        # temb = self.time_text_embed(timestep, dtype=encoder_hidden_states.dtype)
        temb, embedded_timestep = self.adaln_single(
            timestep, None, batch_size=hidden_states.shape[0], hidden_dtype=hidden_states.dtype
        )

        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        ids_keep = None
        len_keep = hidden_states.shape[1]
        zs = None

        ds_num = int(len(self.block_split_stage) // 2)
        encoder_feats = []
        for grp_ids, blocks in enumerate(self.block_groups):
            # for encoders
            for index_block, block in enumerate(blocks):
                # Skip specified layers
                is_skip = True if skip_layers is not None and index_block in skip_layers else False

                if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        joint_attention_kwargs,
                        **ckpt_kwargs,
                    )
                elif not is_skip:
                    encoder_hidden_states, hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=temb,
                        joint_attention_kwargs=joint_attention_kwargs,
                    )

                if grp_ids == 1 and index_block == self.repa_depth - self.block_split_stage[0] - 1:
                    if self.training and (self.repa_depth != -1):
                        reshaped_out = rearrange(hidden_states, "n (h w) c -> n c h w", h=cur_height, w=cur_width)
                        upsampled_out = torch.nn.functional.interpolate(reshaped_out, size=(cur_height*2, cur_width*2))
                        out_1d = rearrange(upsampled_out, "n c h w -> n (h w) c", h=cur_height*2, w=cur_width*2)
                        zs = [projector(out_1d) for projector in self.projectors]
            if grp_ids < ds_num:
                encoder_feats.append(hidden_states)

                hidden_states = self.downsamplers[grp_ids](rearrange(hidden_states, "n (h w) c -> n c h w", h=cur_height, w=cur_width))
                cur_height = int(cur_height / 2)
                cur_width = int(cur_width / 2)
                hidden_states = rearrange(hidden_states, "n c h w -> n (h w) c", h=cur_height, w=cur_width)
                hidden_states = hidden_states + self.pos_embed_lv1
            elif grp_ids < len(self.block_split_stage) - 1:
                hidden_states = self.upsamplers[grp_ids - ds_num](rearrange(hidden_states, "n (h w) c -> n c h w", h=cur_height, w=cur_width))
                cur_height = int(cur_height * 2)
                cur_width = int(cur_width * 2)
                hidden_states = rearrange(hidden_states, "n c h w -> n (h w) c", h=cur_height, w=cur_width)

                hidden_states = torch.cat([hidden_states, encoder_feats[len(encoder_feats) - 1 - (grp_ids - ds_num)]], dim=2)
                hidden_states = self.mergers[grp_ids - ds_num](hidden_states)
                hidden_states = hidden_states + self.pos_embed_lv0

        # print(hidden_states.shape, temb.shape)
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        if not self.training:
            # unpatchify
            patch_size = self.config.patch_size
            height = height // patch_size
            width = width // patch_size

            hidden_states = hidden_states.reshape(
                shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
            )

            if not return_dict:
                return (output,)

            return Transformer2DModelOutput(sample=output)

        else:
            return hidden_states, ids_keep, zs

    def enable_masking(self, depth, mask_ratio):
        # depth: apply masking after block_[depth]. should be [0, nblks-1]
        num_blocks = sum(len(group) for group in self.block_groups)
        assert depth >= 0 and depth < num_blocks
        self.patch_mixer_depth = depth
        assert mask_ratio >= 0 and mask_ratio <= 1
        self.mask_ratio = mask_ratio

    def disable_masking(self):
        self.patch_mixer_depth = None

    def enable_gradient_checkpointing(self, nblocks_to_apply_grad_checkpointing):
        # self.transformer_blocks is set to None in __init__; the blocks live in self.block_groups.
        all_blocks = [block for group in self.block_groups for block in group]
        N = len(all_blocks)

        if nblocks_to_apply_grad_checkpointing == -1:
            nblocks_to_apply_grad_checkpointing = N
        nblocks_to_apply_grad_checkpointing = min(N, nblocks_to_apply_grad_checkpointing)

        # Apply to blocks evenly spaced out
        step = N / nblocks_to_apply_grad_checkpointing if nblocks_to_apply_grad_checkpointing > 0 else 0
        indices = [int((i + 0.5) * step) for i in range(nblocks_to_apply_grad_checkpointing)]

        self.gradient_checkpointing = True
        for blk_ind, block in enumerate(all_blocks):
            block.gradient_checkpointing = (blk_ind in indices)
            print(f"Block {blk_ind} grad checkpointing set to {block.gradient_checkpointing}")
ummdit_small_ds1_singlenorm_v5_newdata_ft512_ema_checkpoint-48000_model_ema.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:366656475662eca6791d6945cf428f01406e57c7387d77f9879e3f92c4594786
size 1415138120
ummdit_small_ds1_singlenorm_v5_newdata_newsetting_ft512.yaml ADDED
@@ -0,0 +1,24 @@
exp_name: 'ummdit_small_ds1_singlenorm_v5_newdata_ft512_ema'
model:
  flashSA: 'Joint_SA' # choose from ['SACA', 'SA', None]
  arch: 'ummdit_small_ds1_singlenorm_v5'
  caption_max_seq_length: 128

training:
  save_freq: 10000
  max_iters: 50000
  transformer_ckpt: '/root/tongs/projects/efficient_diffusion_training/ummdit_small_ds1_singlenorm_v5_newdata_newsetting/checkpoints/checkpoint-100000/model.safetensors'
  use_ema: True

dataset:
  datasets: [
    '/root/tongs/data/jDB/mds_latents_dcvae_flant5large/train',
    '/root/tongs/data/jDB/mds_latents_dcvae_flant5large/valid',
    '/root/tongs/data/flux_gen_data/imglatent_mds_UCSC_part0/',
    '/root/tongs/data/flux_gen_data/imglatent_mds_UCSC_part1/',
    '/root/tongs/data/flux_gen_data/imglatent_mds_UCSC_part2/',
    '/root/tongs/data/flux_gen_data/imglatent_mds_UCSC_part3/',
    '/root/tongs/data/flux_gen_data/imglatent_mds_UCSC_part4/',
    '/root/tongs/data/flux_gen_data/imglatent_mds_UCSC_part5/',
    '/root/tongs/data/flux_gen_data/imglatent_mds_diffdb/'
  ]
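
A minimal sketch (not part of the commit) of how the EMA checkpoint above could be loaded into the model defined in ummdit_ds1_small_singlenorm_v5.py; the constructor arguments and the use of strict=False are assumptions and must be matched to the actual training configuration:

    from safetensors.torch import load_file

    from ummdit_ds1_small_singlenorm_v5 import MMDiTTransformer2DModel

    # Hypothetical sizes; replace with the values used for training.
    model = MMDiTTransformer2DModel(sample_size=32, patch_size=2, in_channels=16)
    state_dict = load_file(
        "ummdit_small_ds1_singlenorm_v5_newdata_ft512_ema_checkpoint-48000_model_ema.safetensors"
    )
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")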