MaykaGR committed
Commit 2ed07d9 · verified · 1 Parent(s): 078a1ac

Upload 28 files
vae (1)/model.py ADDED
@@ -0,0 +1,711 @@
1
+ #original code from https://github.com/genmoai/models under apache 2.0 license
2
+ #adapted to ComfyUI
3
+
4
+ from typing import List, Optional, Tuple, Union
5
+ from functools import partial
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+
13
+ from comfy.ldm.modules.attention import optimized_attention
14
+
15
+ import comfy.ops
16
+ ops = comfy.ops.disable_weight_init
17
+
18
+ # import mochi_preview.dit.joint_model.context_parallel as cp
19
+ # from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
20
+
21
+
22
+ def cast_tuple(t, length=1):
23
+ return t if isinstance(t, tuple) else ((t,) * length)
24
+
25
+
26
+ class GroupNormSpatial(ops.GroupNorm):
27
+ """
28
+ GroupNorm applied per-frame.
29
+ """
30
+
31
+ def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
32
+ B, C, T, H, W = x.shape
33
+ x = rearrange(x, "B C T H W -> (B T) C H W")
34
+ # Run group norm in chunks.
35
+ output = torch.empty_like(x)
36
+ for b in range(0, B * T, chunk_size):
37
+ output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
38
+ return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
39
+
40
+ class PConv3d(ops.Conv3d):
41
+ def __init__(
42
+ self,
43
+ in_channels,
44
+ out_channels,
45
+ kernel_size: Union[int, Tuple[int, int, int]],
46
+ stride: Union[int, Tuple[int, int, int]],
47
+ causal: bool = True,
48
+ context_parallel: bool = True,
49
+ **kwargs,
50
+ ):
51
+ self.causal = causal
52
+ self.context_parallel = context_parallel
53
+ kernel_size = cast_tuple(kernel_size, 3)
54
+ stride = cast_tuple(stride, 3)
55
+ height_pad = (kernel_size[1] - 1) // 2
56
+ width_pad = (kernel_size[2] - 1) // 2
57
+
58
+ super().__init__(
59
+ in_channels=in_channels,
60
+ out_channels=out_channels,
61
+ kernel_size=kernel_size,
62
+ stride=stride,
63
+ dilation=(1, 1, 1),
64
+ padding=(0, height_pad, width_pad),
65
+ **kwargs,
66
+ )
67
+
68
+ def forward(self, x: torch.Tensor):
69
+ # Compute padding amounts.
70
+ context_size = self.kernel_size[0] - 1
71
+ if self.causal:
72
+ pad_front = context_size
73
+ pad_back = 0
74
+ else:
75
+ pad_front = context_size // 2
76
+ pad_back = context_size - pad_front
77
+
78
+ # Apply padding.
79
+ assert self.padding_mode == "replicate"  # only replicate padding is expected by this adaptation
80
+ mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
81
+ x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
82
+ return super().forward(x)
83
+
84
+
85
+ class Conv1x1(ops.Linear):
86
+ """1x1 Conv implemented with a linear layer."""
87
+
88
+ def __init__(self, in_features: int, out_features: int, *args, **kwargs):
89
+ super().__init__(in_features, out_features, *args, **kwargs)
90
+
91
+ def forward(self, x: torch.Tensor):
92
+ """Forward pass.
93
+
94
+ Args:
95
+ x: Input tensor. Shape: [B, C, *] or [B, *, C].
96
+
97
+ Returns:
98
+ x: Output tensor. Shape: [B, C', *] or [B, *, C'].
99
+ """
100
+ x = x.movedim(1, -1)
101
+ x = super().forward(x)
102
+ x = x.movedim(-1, 1)
103
+ return x
104
+
105
+
106
+ class DepthToSpaceTime(nn.Module):
107
+ def __init__(
108
+ self,
109
+ temporal_expansion: int,
110
+ spatial_expansion: int,
111
+ ):
112
+ super().__init__()
113
+ self.temporal_expansion = temporal_expansion
114
+ self.spatial_expansion = spatial_expansion
115
+
116
+ # When printed, this module should show the temporal and spatial expansion factors.
117
+ def extra_repr(self):
118
+ return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
119
+
120
+ def forward(self, x: torch.Tensor):
121
+ """Forward pass.
122
+
123
+ Args:
124
+ x: Input tensor. Shape: [B, C, T, H, W].
125
+
126
+ Returns:
127
+ x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
128
+ """
129
+ x = rearrange(
130
+ x,
131
+ "B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
132
+ st=self.temporal_expansion,
133
+ sh=self.spatial_expansion,
134
+ sw=self.spatial_expansion,
135
+ )
136
+
137
+ # cp_rank, _ = cp.get_cp_rank_size()
138
+ if self.temporal_expansion > 1: # and cp_rank == 0:
139
+ # Drop the first self.temporal_expansion - 1 frames.
140
+ # This is because we always want the 3x3x3 conv filter to only apply
141
+ # to the first frame, and the first frame doesn't need to be repeated.
142
+ assert all(x.shape)
143
+ x = x[:, :, self.temporal_expansion - 1 :]
144
+ assert all(x.shape)
145
+
146
+ return x
147
+
148
+
149
+ def norm_fn(
150
+ in_channels: int,
151
+ affine: bool = True,
152
+ ):
153
+ return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
154
+
155
+
156
+ class ResBlock(nn.Module):
157
+ """Residual block that preserves the spatial dimensions."""
158
+
159
+ def __init__(
160
+ self,
161
+ channels: int,
162
+ *,
163
+ affine: bool = True,
164
+ attn_block: Optional[nn.Module] = None,
165
+ causal: bool = True,
166
+ prune_bottleneck: bool = False,
167
+ padding_mode: str,
168
+ bias: bool = True,
169
+ ):
170
+ super().__init__()
171
+ self.channels = channels
172
+
173
+ assert causal
174
+ self.stack = nn.Sequential(
175
+ norm_fn(channels, affine=affine),
176
+ nn.SiLU(inplace=True),
177
+ PConv3d(
178
+ in_channels=channels,
179
+ out_channels=channels // 2 if prune_bottleneck else channels,
180
+ kernel_size=(3, 3, 3),
181
+ stride=(1, 1, 1),
182
+ padding_mode=padding_mode,
183
+ bias=bias,
184
+ causal=causal,
185
+ ),
186
+ norm_fn(channels, affine=affine),
187
+ nn.SiLU(inplace=True),
188
+ PConv3d(
189
+ in_channels=channels // 2 if prune_bottleneck else channels,
190
+ out_channels=channels,
191
+ kernel_size=(3, 3, 3),
192
+ stride=(1, 1, 1),
193
+ padding_mode=padding_mode,
194
+ bias=bias,
195
+ causal=causal,
196
+ ),
197
+ )
198
+
199
+ self.attn_block = attn_block if attn_block else nn.Identity()
200
+
201
+ def forward(self, x: torch.Tensor):
202
+ """Forward pass.
203
+
204
+ Args:
205
+ x: Input tensor. Shape: [B, C, T, H, W].
206
+ """
207
+ residual = x
208
+ x = self.stack(x)
209
+ x = x + residual
210
+ del residual
211
+
212
+ return self.attn_block(x)
213
+
214
+
215
+ class Attention(nn.Module):
216
+ def __init__(
217
+ self,
218
+ dim: int,
219
+ head_dim: int = 32,
220
+ qkv_bias: bool = False,
221
+ out_bias: bool = True,
222
+ qk_norm: bool = True,
223
+ ) -> None:
224
+ super().__init__()
225
+ self.head_dim = head_dim
226
+ self.num_heads = dim // head_dim
227
+ self.qk_norm = qk_norm
228
+
229
+ self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
230
+ self.out = nn.Linear(dim, dim, bias=out_bias)
231
+
232
+ def forward(
233
+ self,
234
+ x: torch.Tensor,
235
+ ) -> torch.Tensor:
236
+ """Compute temporal self-attention.
237
+
238
+ Args:
239
+ x: Input tensor. Shape: [B, C, T, H, W].
241
+
242
+ Returns:
243
+ x: Output tensor. Shape: [B, C, T, H, W].
244
+ """
245
+ B, _, T, H, W = x.shape
246
+
247
+ if T == 1:
248
+ # No attention for single frame.
249
+ x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
250
+ qkv = self.qkv(x)
251
+ _, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
252
+ x = self.out(x)
253
+ return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
254
+
255
+ # 1D temporal attention.
256
+ x = rearrange(x, "B C t h w -> (B h w) t C")
257
+ qkv = self.qkv(x)
258
+
259
+ # Input: qkv with shape [B, t, 3 * num_heads * head_dim]
260
+ # Output: x with shape [B, num_heads, t, head_dim]
261
+ q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)
262
+
263
+ if self.qk_norm:
264
+ q = F.normalize(q, p=2, dim=-1)
265
+ k = F.normalize(k, p=2, dim=-1)
266
+
267
+ x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
268
+
269
+ assert x.size(0) == q.size(0)
270
+
271
+ x = self.out(x)
272
+ x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
273
+ return x
274
+
275
+
276
+ class AttentionBlock(nn.Module):
277
+ def __init__(
278
+ self,
279
+ dim: int,
280
+ **attn_kwargs,
281
+ ) -> None:
282
+ super().__init__()
283
+ self.norm = norm_fn(dim)
284
+ self.attn = Attention(dim, **attn_kwargs)
285
+
286
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
287
+ return x + self.attn(self.norm(x))
288
+
289
+
290
+ class CausalUpsampleBlock(nn.Module):
291
+ def __init__(
292
+ self,
293
+ in_channels: int,
294
+ out_channels: int,
295
+ num_res_blocks: int,
296
+ *,
297
+ temporal_expansion: int = 2,
298
+ spatial_expansion: int = 2,
299
+ **block_kwargs,
300
+ ):
301
+ super().__init__()
302
+
303
+ blocks = []
304
+ for _ in range(num_res_blocks):
305
+ blocks.append(block_fn(in_channels, **block_kwargs))
306
+ self.blocks = nn.Sequential(*blocks)
307
+
308
+ self.temporal_expansion = temporal_expansion
309
+ self.spatial_expansion = spatial_expansion
310
+
311
+ # Change channels in the final convolution layer.
312
+ self.proj = Conv1x1(
313
+ in_channels,
314
+ out_channels * temporal_expansion * (spatial_expansion**2),
315
+ )
316
+
317
+ self.d2st = DepthToSpaceTime(
318
+ temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
319
+ )
320
+
321
+ def forward(self, x):
322
+ x = self.blocks(x)
323
+ x = self.proj(x)
324
+ x = self.d2st(x)
325
+ return x
326
+
327
+
328
+ def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
329
+ attn_block = AttentionBlock(channels) if has_attention else None
330
+ return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
331
+
332
+
333
+ class DownsampleBlock(nn.Module):
334
+ def __init__(
335
+ self,
336
+ in_channels: int,
337
+ out_channels: int,
338
+ num_res_blocks,
339
+ *,
340
+ temporal_reduction=2,
341
+ spatial_reduction=2,
342
+ **block_kwargs,
343
+ ):
344
+ """
345
+ Downsample block for the VAE encoder.
346
+
347
+ Args:
348
+ in_channels: Number of input channels.
349
+ out_channels: Number of output channels.
350
+ num_res_blocks: Number of residual blocks.
351
+ temporal_reduction: Temporal reduction factor.
352
+ spatial_reduction: Spatial reduction factor.
353
+ """
354
+ super().__init__()
355
+ layers = []
356
+
357
+ # Change the channel count in the strided convolution.
358
+ # This lets the ResBlock have uniform channel count,
359
+ # as in ConvNeXt.
360
+ assert in_channels != out_channels
361
+ layers.append(
362
+ PConv3d(
363
+ in_channels=in_channels,
364
+ out_channels=out_channels,
365
+ kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
366
+ stride=(temporal_reduction, spatial_reduction, spatial_reduction),
367
+ # First layer in each block always uses replicate padding
368
+ padding_mode="replicate",
369
+ bias=block_kwargs["bias"],
370
+ )
371
+ )
372
+
373
+ for _ in range(num_res_blocks):
374
+ layers.append(block_fn(out_channels, **block_kwargs))
375
+
376
+ self.layers = nn.Sequential(*layers)
377
+
378
+ def forward(self, x):
379
+ return self.layers(x)
380
+
381
+
382
+ def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
383
+ num_freqs = (stop - start) // step
384
+ assert inputs.ndim == 5
385
+ C = inputs.size(1)
386
+
387
+ # Create Base 2 Fourier features.
388
+ freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
389
+ assert num_freqs == len(freqs)
390
+ w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
391
+ C = inputs.shape[1]
392
+ w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1]
393
+
394
+ # Interleaved repeat of input channels to match w.
395
+ h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
396
+ # Scale channels by frequency.
397
+ h = w * h
398
+
399
+ return torch.cat(
400
+ [
401
+ inputs,
402
+ torch.sin(h),
403
+ torch.cos(h),
404
+ ],
405
+ dim=1,
406
+ )
407
+
408
+
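+ # Illustrative note: with the defaults start=6, stop=8, step=1 above, num_freqs == 2,
+ # so a 3-channel video becomes 3 + 2*2*3 = 15 channels -- which is why the Encoder in
+ # VideoVAE below is built with in_channels=15. For example (shapes only):
+ #   x = torch.randn(1, 3, 8, 64, 64)
+ #   add_fourier_features(x).shape  # torch.Size([1, 15, 8, 64, 64])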
409
+ class FourierFeatures(nn.Module):
410
+ def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
411
+ super().__init__()
412
+ self.start = start
413
+ self.stop = stop
414
+ self.step = step
415
+
416
+ def forward(self, inputs):
417
+ """Add Fourier features to inputs.
418
+
419
+ Args:
420
+ inputs: Input tensor. Shape: [B, C, T, H, W]
421
+
422
+ Returns:
423
+ h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
424
+ """
425
+ return add_fourier_features(inputs, self.start, self.stop, self.step)
426
+
427
+
428
+ class Decoder(nn.Module):
429
+ def __init__(
430
+ self,
431
+ *,
432
+ out_channels: int = 3,
433
+ latent_dim: int,
434
+ base_channels: int,
435
+ channel_multipliers: List[int],
436
+ num_res_blocks: List[int],
437
+ temporal_expansions: Optional[List[int]] = None,
438
+ spatial_expansions: Optional[List[int]] = None,
439
+ has_attention: List[bool],
440
+ output_norm: bool = True,
441
+ nonlinearity: str = "silu",
442
+ output_nonlinearity: str = "silu",
443
+ causal: bool = True,
444
+ **block_kwargs,
445
+ ):
446
+ super().__init__()
447
+ self.input_channels = latent_dim
448
+ self.base_channels = base_channels
449
+ self.channel_multipliers = channel_multipliers
450
+ self.num_res_blocks = num_res_blocks
451
+ self.output_nonlinearity = output_nonlinearity
452
+ assert nonlinearity == "silu"
453
+ assert causal
454
+
455
+ ch = [mult * base_channels for mult in channel_multipliers]
456
+ self.num_up_blocks = len(ch) - 1
457
+ assert len(num_res_blocks) == self.num_up_blocks + 2
458
+
459
+ blocks = []
460
+
461
+ first_block = [
462
+ ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
463
+ ] # Input layer.
464
+ # First set of blocks preserve channel count.
465
+ for _ in range(num_res_blocks[-1]):
466
+ first_block.append(
467
+ block_fn(
468
+ ch[-1],
469
+ has_attention=has_attention[-1],
470
+ causal=causal,
471
+ **block_kwargs,
472
+ )
473
+ )
474
+ blocks.append(nn.Sequential(*first_block))
475
+
476
+ assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
477
+ assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
478
+
479
+ upsample_block_fn = CausalUpsampleBlock
480
+
481
+ for i in range(self.num_up_blocks):
482
+ block = upsample_block_fn(
483
+ ch[-i - 1],
484
+ ch[-i - 2],
485
+ num_res_blocks=num_res_blocks[-i - 2],
486
+ has_attention=has_attention[-i - 2],
487
+ temporal_expansion=temporal_expansions[-i - 1],
488
+ spatial_expansion=spatial_expansions[-i - 1],
489
+ causal=causal,
490
+ **block_kwargs,
491
+ )
492
+ blocks.append(block)
493
+
494
+ assert not output_norm
495
+
496
+ # Last block. Preserve channel count.
497
+ last_block = []
498
+ for _ in range(num_res_blocks[0]):
499
+ last_block.append(
500
+ block_fn(
501
+ ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
502
+ )
503
+ )
504
+ blocks.append(nn.Sequential(*last_block))
505
+
506
+ self.blocks = nn.ModuleList(blocks)
507
+ self.output_proj = Conv1x1(ch[0], out_channels)
508
+
509
+ def forward(self, x):
510
+ """Forward pass.
511
+
512
+ Args:
513
+ x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
514
+
515
+ Returns:
516
+ x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
517
+ T + 1 = (t - 1) * 4.
518
+ H = h * 16, W = w * 16.
519
+ """
520
+ for block in self.blocks:
521
+ x = block(x)
522
+
523
+ if self.output_nonlinearity == "silu":
524
+ x = F.silu(x, inplace=not self.training)
525
+ else:
526
+ assert (
527
+ not self.output_nonlinearity
528
+ ) # StyleGAN3 omits the to-RGB nonlinearity.
529
+
530
+ return self.output_proj(x).contiguous()
531
+
532
+ class LatentDistribution:
533
+ def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
534
+ """Initialize latent distribution.
535
+
536
+ Args:
537
+ mean: Mean of the distribution. Shape: [B, C, T, H, W].
538
+ logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
539
+ """
540
+ assert mean.shape == logvar.shape
541
+ self.mean = mean
542
+ self.logvar = logvar
543
+
544
+ def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
545
+ if temperature == 0.0:
546
+ return self.mean
547
+
548
+ if noise is None:
549
+ noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
550
+ else:
551
+ assert noise.device == self.mean.device
552
+ noise = noise.to(self.mean.dtype)
553
+
554
+ if temperature != 1.0:
555
+ raise NotImplementedError(f"Temperature {temperature} is not supported.")
556
+
557
+ # Just Gaussian sample with no scaling of variance.
558
+ return noise * torch.exp(self.logvar * 0.5) + self.mean
559
+
560
+ def mode(self):
561
+ return self.mean
562
+
563
+ class Encoder(nn.Module):
564
+ def __init__(
565
+ self,
566
+ *,
567
+ in_channels: int,
568
+ base_channels: int,
569
+ channel_multipliers: List[int],
570
+ num_res_blocks: List[int],
571
+ latent_dim: int,
572
+ temporal_reductions: List[int],
573
+ spatial_reductions: List[int],
574
+ prune_bottlenecks: List[bool],
575
+ has_attentions: List[bool],
576
+ affine: bool = True,
577
+ bias: bool = True,
578
+ input_is_conv_1x1: bool = False,
579
+ padding_mode: str,
580
+ ):
581
+ super().__init__()
582
+ self.temporal_reductions = temporal_reductions
583
+ self.spatial_reductions = spatial_reductions
584
+ self.base_channels = base_channels
585
+ self.channel_multipliers = channel_multipliers
586
+ self.num_res_blocks = num_res_blocks
587
+ self.latent_dim = latent_dim
588
+
589
+ self.fourier_features = FourierFeatures()
590
+ ch = [mult * base_channels for mult in channel_multipliers]
591
+ num_down_blocks = len(ch) - 1
592
+ assert len(num_res_blocks) == num_down_blocks + 2
593
+
594
+ layers = (
595
+ [ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
596
+ if not input_is_conv_1x1
597
+ else [Conv1x1(in_channels, ch[0])]
598
+ )
599
+
600
+ assert len(prune_bottlenecks) == num_down_blocks + 2
601
+ assert len(has_attentions) == num_down_blocks + 2
602
+ block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
603
+
604
+ for _ in range(num_res_blocks[0]):
605
+ layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
606
+ prune_bottlenecks = prune_bottlenecks[1:]
607
+ has_attentions = has_attentions[1:]
608
+
609
+ assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
610
+ for i in range(num_down_blocks):
611
+ layer = DownsampleBlock(
612
+ ch[i],
613
+ ch[i + 1],
614
+ num_res_blocks=num_res_blocks[i + 1],
615
+ temporal_reduction=temporal_reductions[i],
616
+ spatial_reduction=spatial_reductions[i],
617
+ prune_bottleneck=prune_bottlenecks[i],
618
+ has_attention=has_attentions[i],
619
+ affine=affine,
620
+ bias=bias,
621
+ padding_mode=padding_mode,
622
+ )
623
+
624
+ layers.append(layer)
625
+
626
+ # Additional blocks.
627
+ for _ in range(num_res_blocks[-1]):
628
+ layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
629
+
630
+ self.layers = nn.Sequential(*layers)
631
+
632
+ # Output layers.
633
+ self.output_norm = norm_fn(ch[-1])
634
+ self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
635
+
636
+ @property
637
+ def temporal_downsample(self):
638
+ return math.prod(self.temporal_reductions)
639
+
640
+ @property
641
+ def spatial_downsample(self):
642
+ return math.prod(self.spatial_reductions)
643
+
644
+ def forward(self, x) -> LatentDistribution:
645
+ """Forward pass.
646
+
647
+ Args:
648
+ x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
649
+
650
+ Returns:
651
+ means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
652
+ h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
653
+ logvar: Shape: [B, latent_dim, t, h, w].
654
+ """
655
+ assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
656
+ x = self.fourier_features(x)
657
+
658
+ x = self.layers(x)
659
+
660
+ x = self.output_norm(x)
661
+ x = F.silu(x, inplace=True)
662
+ x = self.output_proj(x)
663
+
664
+ means, logvar = torch.chunk(x, 2, dim=1)
665
+
666
+ assert means.ndim == 5
667
+ assert logvar.shape == means.shape
668
+ assert means.size(1) == self.latent_dim
669
+
670
+ return LatentDistribution(means, logvar)
671
+
672
+
673
+ class VideoVAE(nn.Module):
674
+ def __init__(self):
675
+ super().__init__()
676
+ self.encoder = Encoder(
677
+ in_channels=15,
678
+ base_channels=64,
679
+ channel_multipliers=[1, 2, 4, 6],
680
+ num_res_blocks=[3, 3, 4, 6, 3],
681
+ latent_dim=12,
682
+ temporal_reductions=[1, 2, 3],
683
+ spatial_reductions=[2, 2, 2],
684
+ prune_bottlenecks=[False, False, False, False, False],
685
+ has_attentions=[False, True, True, True, True],
686
+ affine=True,
687
+ bias=True,
688
+ input_is_conv_1x1=True,
689
+ padding_mode="replicate"
690
+ )
691
+ self.decoder = Decoder(
692
+ out_channels=3,
693
+ base_channels=128,
694
+ channel_multipliers=[1, 2, 4, 6],
695
+ temporal_expansions=[1, 2, 3],
696
+ spatial_expansions=[2, 2, 2],
697
+ num_res_blocks=[3, 3, 4, 6, 3],
698
+ latent_dim=12,
699
+ has_attention=[False, False, False, False, False],
700
+ padding_mode="replicate",
701
+ output_norm=False,
702
+ nonlinearity="silu",
703
+ output_nonlinearity="silu",
704
+ causal=True,
705
+ )
706
+
707
+ def encode(self, x):
708
+ return self.encoder(x).mode()
709
+
710
+ def decode(self, x):
711
+ return self.decoder(x)
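A minimal usage sketch for the VideoVAE defined above, assuming it is imported inside a ComfyUI environment where comfy.ops and optimized_attention resolve; the input shape is hypothetical and the shape comments follow the Encoder/Decoder docstrings (6x temporal and 8x spatial reduction into a 12-channel latent).

    import torch

    vae = VideoVAE()                          # encoder/decoder configured as above
    video = torch.randn(1, 3, 13, 256, 256)   # [B, C, T, H, W], T of the form 6*k + 1
    latent = vae.encode(video)                # -> [1, 12, 3, 32, 32] per the encoder docstring
    recon = vae.decode(latent)                # decodes back toward [B, 3, T', H', W']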
vae (1)/put_vae_here ADDED
File without changes
vae (2)/causal_conv3d.py ADDED
@@ -0,0 +1,64 @@
1
+ from typing import Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import comfy.ops
6
+ ops = comfy.ops.disable_weight_init
7
+
8
+
9
+ class CausalConv3d(nn.Module):
10
+ def __init__(
11
+ self,
12
+ in_channels,
13
+ out_channels,
14
+ kernel_size: int = 3,
15
+ stride: Union[int, Tuple[int]] = 1,
16
+ dilation: int = 1,
17
+ groups: int = 1,
18
+ **kwargs,
19
+ ):
20
+ super().__init__()
21
+
22
+ self.in_channels = in_channels
23
+ self.out_channels = out_channels
24
+
25
+ kernel_size = (kernel_size, kernel_size, kernel_size)
26
+ self.time_kernel_size = kernel_size[0]
27
+
28
+ dilation = (dilation, 1, 1)
29
+
30
+ height_pad = kernel_size[1] // 2
31
+ width_pad = kernel_size[2] // 2
32
+ padding = (0, height_pad, width_pad)
33
+
34
+ self.conv = ops.Conv3d(
35
+ in_channels,
36
+ out_channels,
37
+ kernel_size,
38
+ stride=stride,
39
+ dilation=dilation,
40
+ padding=padding,
41
+ padding_mode="zeros",
42
+ groups=groups,
43
+ )
44
+
45
+ def forward(self, x, causal: bool = True):
46
+ if causal:
47
+ first_frame_pad = x[:, :, :1, :, :].repeat(
48
+ (1, 1, self.time_kernel_size - 1, 1, 1)
49
+ )
50
+ x = torch.concatenate((first_frame_pad, x), dim=2)
51
+ else:
52
+ first_frame_pad = x[:, :, :1, :, :].repeat(
53
+ (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
54
+ )
55
+ last_frame_pad = x[:, :, -1:, :, :].repeat(
56
+ (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
57
+ )
58
+ x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
59
+ x = self.conv(x)
60
+ return x
61
+
62
+ @property
63
+ def weight(self):
64
+ return self.conv.weight
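A small sketch of the causal padding behaviour of CausalConv3d above, assuming comfy.ops is importable; the shapes are illustrative. With causal=True the first frame is replicated time_kernel_size - 1 times at the front, so a stride-1 convolution preserves the temporal length.

    import torch

    conv = CausalConv3d(in_channels=8, out_channels=8, kernel_size=3, stride=1)
    x = torch.randn(1, 8, 5, 16, 16)   # [B, C, T, H, W]
    y = conv(x, causal=True)           # front-padded with two copies of frame 0
    print(y.shape)                     # torch.Size([1, 8, 5, 16, 16])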
vae (2)/causal_video_autoencoder.py ADDED
@@ -0,0 +1,907 @@
1
+ import torch
2
+ from torch import nn
3
+ from functools import partial
4
+ import math
5
+ from einops import rearrange
6
+ from typing import Optional, Tuple, Union
7
+ from .conv_nd_factory import make_conv_nd, make_linear_nd
8
+ from .pixel_norm import PixelNorm
9
+ from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
10
+ import comfy.ops
11
+ ops = comfy.ops.disable_weight_init
12
+
13
+ class Encoder(nn.Module):
14
+ r"""
15
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
16
+
17
+ Args:
18
+ dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
19
+ The number of dimensions to use in convolutions.
20
+ in_channels (`int`, *optional*, defaults to 3):
21
+ The number of input channels.
22
+ out_channels (`int`, *optional*, defaults to 3):
23
+ The number of output channels.
24
+ blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
25
+ The blocks to use. Each block is a tuple of the block name and the number of layers.
26
+ base_channels (`int`, *optional*, defaults to 128):
27
+ The number of output channels for the first convolutional layer.
28
+ norm_num_groups (`int`, *optional*, defaults to 32):
29
+ The number of groups for normalization.
30
+ patch_size (`int`, *optional*, defaults to 1):
31
+ The patch size to use. Should be a power of 2.
32
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
33
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
34
+ latent_log_var (`str`, *optional*, defaults to `per_channel`):
35
+ The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ dims: Union[int, Tuple[int, int]] = 3,
41
+ in_channels: int = 3,
42
+ out_channels: int = 3,
43
+ blocks=[("res_x", 1)],
44
+ base_channels: int = 128,
45
+ norm_num_groups: int = 32,
46
+ patch_size: Union[int, Tuple[int]] = 1,
47
+ norm_layer: str = "group_norm", # group_norm, pixel_norm
48
+ latent_log_var: str = "per_channel",
49
+ ):
50
+ super().__init__()
51
+ self.patch_size = patch_size
52
+ self.norm_layer = norm_layer
53
+ self.latent_channels = out_channels
54
+ self.latent_log_var = latent_log_var
55
+ self.blocks_desc = blocks
56
+
57
+ in_channels = in_channels * patch_size**2
58
+ output_channel = base_channels
59
+
60
+ self.conv_in = make_conv_nd(
61
+ dims=dims,
62
+ in_channels=in_channels,
63
+ out_channels=output_channel,
64
+ kernel_size=3,
65
+ stride=1,
66
+ padding=1,
67
+ causal=True,
68
+ )
69
+
70
+ self.down_blocks = nn.ModuleList([])
71
+
72
+ for block_name, block_params in blocks:
73
+ input_channel = output_channel
74
+ if isinstance(block_params, int):
75
+ block_params = {"num_layers": block_params}
76
+
77
+ if block_name == "res_x":
78
+ block = UNetMidBlock3D(
79
+ dims=dims,
80
+ in_channels=input_channel,
81
+ num_layers=block_params["num_layers"],
82
+ resnet_eps=1e-6,
83
+ resnet_groups=norm_num_groups,
84
+ norm_layer=norm_layer,
85
+ )
86
+ elif block_name == "res_x_y":
87
+ output_channel = block_params.get("multiplier", 2) * output_channel
88
+ block = ResnetBlock3D(
89
+ dims=dims,
90
+ in_channels=input_channel,
91
+ out_channels=output_channel,
92
+ eps=1e-6,
93
+ groups=norm_num_groups,
94
+ norm_layer=norm_layer,
95
+ )
96
+ elif block_name == "compress_time":
97
+ block = make_conv_nd(
98
+ dims=dims,
99
+ in_channels=input_channel,
100
+ out_channels=output_channel,
101
+ kernel_size=3,
102
+ stride=(2, 1, 1),
103
+ causal=True,
104
+ )
105
+ elif block_name == "compress_space":
106
+ block = make_conv_nd(
107
+ dims=dims,
108
+ in_channels=input_channel,
109
+ out_channels=output_channel,
110
+ kernel_size=3,
111
+ stride=(1, 2, 2),
112
+ causal=True,
113
+ )
114
+ elif block_name == "compress_all":
115
+ block = make_conv_nd(
116
+ dims=dims,
117
+ in_channels=input_channel,
118
+ out_channels=output_channel,
119
+ kernel_size=3,
120
+ stride=(2, 2, 2),
121
+ causal=True,
122
+ )
123
+ elif block_name == "compress_all_x_y":
124
+ output_channel = block_params.get("multiplier", 2) * output_channel
125
+ block = make_conv_nd(
126
+ dims=dims,
127
+ in_channels=input_channel,
128
+ out_channels=output_channel,
129
+ kernel_size=3,
130
+ stride=(2, 2, 2),
131
+ causal=True,
132
+ )
133
+ else:
134
+ raise ValueError(f"unknown block: {block_name}")
135
+
136
+ self.down_blocks.append(block)
137
+
138
+ # out
139
+ if norm_layer == "group_norm":
140
+ self.conv_norm_out = nn.GroupNorm(
141
+ num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
142
+ )
143
+ elif norm_layer == "pixel_norm":
144
+ self.conv_norm_out = PixelNorm()
145
+ elif norm_layer == "layer_norm":
146
+ self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
147
+
148
+ self.conv_act = nn.SiLU()
149
+
150
+ conv_out_channels = out_channels
151
+ if latent_log_var == "per_channel":
152
+ conv_out_channels *= 2
153
+ elif latent_log_var == "uniform":
154
+ conv_out_channels += 1
155
+ elif latent_log_var != "none":
156
+ raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
157
+ self.conv_out = make_conv_nd(
158
+ dims, output_channel, conv_out_channels, 3, padding=1, causal=True
159
+ )
160
+
161
+ self.gradient_checkpointing = False
162
+
163
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
164
+ r"""The forward method of the `Encoder` class."""
165
+
166
+ sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
167
+ sample = self.conv_in(sample)
168
+
169
+ checkpoint_fn = (
170
+ partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
171
+ if self.gradient_checkpointing and self.training
172
+ else lambda x: x
173
+ )
174
+
175
+ for down_block in self.down_blocks:
176
+ sample = checkpoint_fn(down_block)(sample)
177
+
178
+ sample = self.conv_norm_out(sample)
179
+ sample = self.conv_act(sample)
180
+ sample = self.conv_out(sample)
181
+
182
+ if self.latent_log_var == "uniform":
183
+ last_channel = sample[:, -1:, ...]
184
+ num_dims = sample.dim()
185
+
186
+ if num_dims == 4:
187
+ # For shape (B, C, H, W)
188
+ repeated_last_channel = last_channel.repeat(
189
+ 1, sample.shape[1] - 2, 1, 1
190
+ )
191
+ sample = torch.cat([sample, repeated_last_channel], dim=1)
192
+ elif num_dims == 5:
193
+ # For shape (B, C, F, H, W)
194
+ repeated_last_channel = last_channel.repeat(
195
+ 1, sample.shape[1] - 2, 1, 1, 1
196
+ )
197
+ sample = torch.cat([sample, repeated_last_channel], dim=1)
198
+ else:
199
+ raise ValueError(f"Invalid input shape: {sample.shape}")
200
+
201
+ return sample
202
+
203
+
204
+ class Decoder(nn.Module):
205
+ r"""
206
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
207
+
208
+ Args:
209
+ dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
210
+ The number of dimensions to use in convolutions.
211
+ in_channels (`int`, *optional*, defaults to 3):
212
+ The number of input channels.
213
+ out_channels (`int`, *optional*, defaults to 3):
214
+ The number of output channels.
215
+ blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
216
+ The blocks to use. Each block is a tuple of the block name and the number of layers.
217
+ base_channels (`int`, *optional*, defaults to 128):
218
+ The number of output channels for the first convolutional layer.
219
+ norm_num_groups (`int`, *optional*, defaults to 32):
220
+ The number of groups for normalization.
221
+ patch_size (`int`, *optional*, defaults to 1):
222
+ The patch size to use. Should be a power of 2.
223
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
224
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
225
+ causal (`bool`, *optional*, defaults to `True`):
226
+ Whether to use causal convolutions or not.
227
+ """
228
+
229
+ def __init__(
230
+ self,
231
+ dims,
232
+ in_channels: int = 3,
233
+ out_channels: int = 3,
234
+ blocks=[("res_x", 1)],
235
+ base_channels: int = 128,
236
+ layers_per_block: int = 2,
237
+ norm_num_groups: int = 32,
238
+ patch_size: int = 1,
239
+ norm_layer: str = "group_norm",
240
+ causal: bool = True,
241
+ timestep_conditioning: bool = False,
242
+ ):
243
+ super().__init__()
244
+ self.patch_size = patch_size
245
+ self.layers_per_block = layers_per_block
246
+ out_channels = out_channels * patch_size**2
247
+ self.causal = causal
248
+ self.blocks_desc = blocks
249
+
250
+ # Compute output channel to be product of all channel-multiplier blocks
251
+ output_channel = base_channels
252
+ for block_name, block_params in list(reversed(blocks)):
253
+ block_params = block_params if isinstance(block_params, dict) else {}
254
+ if block_name == "res_x_y":
255
+ output_channel = output_channel * block_params.get("multiplier", 2)
256
+ if block_name == "compress_all":
257
+ output_channel = output_channel * block_params.get("multiplier", 1)
258
+
259
+ self.conv_in = make_conv_nd(
260
+ dims,
261
+ in_channels,
262
+ output_channel,
263
+ kernel_size=3,
264
+ stride=1,
265
+ padding=1,
266
+ causal=True,
267
+ )
268
+
269
+ self.up_blocks = nn.ModuleList([])
270
+
271
+ for block_name, block_params in list(reversed(blocks)):
272
+ input_channel = output_channel
273
+ if isinstance(block_params, int):
274
+ block_params = {"num_layers": block_params}
275
+
276
+ if block_name == "res_x":
277
+ block = UNetMidBlock3D(
278
+ dims=dims,
279
+ in_channels=input_channel,
280
+ num_layers=block_params["num_layers"],
281
+ resnet_eps=1e-6,
282
+ resnet_groups=norm_num_groups,
283
+ norm_layer=norm_layer,
284
+ inject_noise=block_params.get("inject_noise", False),
285
+ timestep_conditioning=timestep_conditioning,
286
+ )
287
+ elif block_name == "attn_res_x":
288
+ block = UNetMidBlock3D(
289
+ dims=dims,
290
+ in_channels=input_channel,
291
+ num_layers=block_params["num_layers"],
292
+ resnet_groups=norm_num_groups,
293
+ norm_layer=norm_layer,
294
+ inject_noise=block_params.get("inject_noise", False),
295
+ timestep_conditioning=timestep_conditioning,
296
+ attention_head_dim=block_params["attention_head_dim"],
297
+ )
298
+ elif block_name == "res_x_y":
299
+ output_channel = output_channel // block_params.get("multiplier", 2)
300
+ block = ResnetBlock3D(
301
+ dims=dims,
302
+ in_channels=input_channel,
303
+ out_channels=output_channel,
304
+ eps=1e-6,
305
+ groups=norm_num_groups,
306
+ norm_layer=norm_layer,
307
+ inject_noise=block_params.get("inject_noise", False),
308
+ timestep_conditioning=False,
309
+ )
310
+ elif block_name == "compress_time":
311
+ block = DepthToSpaceUpsample(
312
+ dims=dims, in_channels=input_channel, stride=(2, 1, 1)
313
+ )
314
+ elif block_name == "compress_space":
315
+ block = DepthToSpaceUpsample(
316
+ dims=dims, in_channels=input_channel, stride=(1, 2, 2)
317
+ )
318
+ elif block_name == "compress_all":
319
+ output_channel = output_channel // block_params.get("multiplier", 1)
320
+ block = DepthToSpaceUpsample(
321
+ dims=dims,
322
+ in_channels=input_channel,
323
+ stride=(2, 2, 2),
324
+ residual=block_params.get("residual", False),
325
+ out_channels_reduction_factor=block_params.get("multiplier", 1),
326
+ )
327
+ else:
328
+ raise ValueError(f"unknown layer: {block_name}")
329
+
330
+ self.up_blocks.append(block)
331
+
332
+ if norm_layer == "group_norm":
333
+ self.conv_norm_out = nn.GroupNorm(
334
+ num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
335
+ )
336
+ elif norm_layer == "pixel_norm":
337
+ self.conv_norm_out = PixelNorm()
338
+ elif norm_layer == "layer_norm":
339
+ self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
340
+
341
+ self.conv_act = nn.SiLU()
342
+ self.conv_out = make_conv_nd(
343
+ dims, output_channel, out_channels, 3, padding=1, causal=True
344
+ )
345
+
346
+ self.gradient_checkpointing = False
347
+
348
+ self.timestep_conditioning = timestep_conditioning
349
+
350
+ if timestep_conditioning:
351
+ self.timestep_scale_multiplier = nn.Parameter(
352
+ torch.tensor(1000.0, dtype=torch.float32)
353
+ )
354
+ self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
355
+ output_channel * 2, 0, operations=ops,
356
+ )
357
+ self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
358
+
359
+ # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
360
+ def forward(
361
+ self,
362
+ sample: torch.FloatTensor,
363
+ timestep: Optional[torch.Tensor] = None,
364
+ ) -> torch.FloatTensor:
365
+ r"""The forward method of the `Decoder` class."""
366
+ batch_size = sample.shape[0]
367
+
368
+ sample = self.conv_in(sample, causal=self.causal)
369
+
370
+ checkpoint_fn = (
371
+ partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
372
+ if self.gradient_checkpointing and self.training
373
+ else lambda x: x
374
+ )
375
+
376
+ scaled_timestep = None
377
+ if self.timestep_conditioning:
378
+ assert (
379
+ timestep is not None
380
+ ), "should pass timestep with timestep_conditioning=True"
381
+ scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
382
+
383
+ for up_block in self.up_blocks:
384
+ if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
385
+ sample = checkpoint_fn(up_block)(
386
+ sample, causal=self.causal, timestep=scaled_timestep
387
+ )
388
+ else:
389
+ sample = checkpoint_fn(up_block)(sample, causal=self.causal)
390
+
391
+ sample = self.conv_norm_out(sample)
392
+
393
+ if self.timestep_conditioning:
394
+ embedded_timestep = self.last_time_embedder(
395
+ timestep=scaled_timestep.flatten(),
396
+ resolution=None,
397
+ aspect_ratio=None,
398
+ batch_size=sample.shape[0],
399
+ hidden_dtype=sample.dtype,
400
+ )
401
+ embedded_timestep = embedded_timestep.view(
402
+ batch_size, embedded_timestep.shape[-1], 1, 1, 1
403
+ )
404
+ ada_values = self.last_scale_shift_table[
405
+ None, ..., None, None, None
406
+ ].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(
407
+ batch_size,
408
+ 2,
409
+ -1,
410
+ embedded_timestep.shape[-3],
411
+ embedded_timestep.shape[-2],
412
+ embedded_timestep.shape[-1],
413
+ )
414
+ shift, scale = ada_values.unbind(dim=1)
415
+ sample = sample * (1 + scale) + shift
416
+
417
+ sample = self.conv_act(sample)
418
+ sample = self.conv_out(sample, causal=self.causal)
419
+
420
+ sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
421
+
422
+ return sample
423
+
424
+
425
+ class UNetMidBlock3D(nn.Module):
426
+ """
427
+ A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
428
+
429
+ Args:
430
+ in_channels (`int`): The number of input channels.
431
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
432
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
433
+ resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
434
+ resnet_groups (`int`, *optional*, defaults to 32):
435
+ The number of groups to use in the group normalization layers of the resnet blocks.
436
+
437
+ Returns:
438
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
439
+ in_channels, height, width)`.
440
+
441
+ """
442
+
443
+ def __init__(
444
+ self,
445
+ dims: Union[int, Tuple[int, int]],
446
+ in_channels: int,
447
+ dropout: float = 0.0,
448
+ num_layers: int = 1,
449
+ resnet_eps: float = 1e-6,
450
+ resnet_groups: int = 32,
451
+ norm_layer: str = "group_norm",
452
+ inject_noise: bool = False,
453
+ timestep_conditioning: bool = False,
454
+ ):
455
+ super().__init__()
456
+ resnet_groups = (
457
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
458
+ )
459
+
460
+ self.timestep_conditioning = timestep_conditioning
461
+
462
+ if timestep_conditioning:
463
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
464
+ in_channels * 4, 0, operations=ops,
465
+ )
466
+
467
+ self.res_blocks = nn.ModuleList(
468
+ [
469
+ ResnetBlock3D(
470
+ dims=dims,
471
+ in_channels=in_channels,
472
+ out_channels=in_channels,
473
+ eps=resnet_eps,
474
+ groups=resnet_groups,
475
+ dropout=dropout,
476
+ norm_layer=norm_layer,
477
+ inject_noise=inject_noise,
478
+ timestep_conditioning=timestep_conditioning,
479
+ )
480
+ for _ in range(num_layers)
481
+ ]
482
+ )
483
+
484
+ def forward(
485
+ self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
486
+ ) -> torch.FloatTensor:
487
+ timestep_embed = None
488
+ if self.timestep_conditioning:
489
+ assert (
490
+ timestep is not None
491
+ ), "should pass timestep with timestep_conditioning=True"
492
+ batch_size = hidden_states.shape[0]
493
+ timestep_embed = self.time_embedder(
494
+ timestep=timestep.flatten(),
495
+ resolution=None,
496
+ aspect_ratio=None,
497
+ batch_size=batch_size,
498
+ hidden_dtype=hidden_states.dtype,
499
+ )
500
+ timestep_embed = timestep_embed.view(
501
+ batch_size, timestep_embed.shape[-1], 1, 1, 1
502
+ )
503
+
504
+ for resnet in self.res_blocks:
505
+ hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed)
506
+
507
+ return hidden_states
508
+
509
+
510
+ class DepthToSpaceUpsample(nn.Module):
511
+ def __init__(
512
+ self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
513
+ ):
514
+ super().__init__()
515
+ self.stride = stride
516
+ self.out_channels = (
517
+ math.prod(stride) * in_channels // out_channels_reduction_factor
518
+ )
519
+ self.conv = make_conv_nd(
520
+ dims=dims,
521
+ in_channels=in_channels,
522
+ out_channels=self.out_channels,
523
+ kernel_size=3,
524
+ stride=1,
525
+ causal=True,
526
+ )
527
+ self.residual = residual
528
+ self.out_channels_reduction_factor = out_channels_reduction_factor
529
+
530
+ def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
531
+ if self.residual:
532
+ # Reshape and duplicate the input to match the output shape
533
+ x_in = rearrange(
534
+ x,
535
+ "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
536
+ p1=self.stride[0],
537
+ p2=self.stride[1],
538
+ p3=self.stride[2],
539
+ )
540
+ num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
541
+ x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
542
+ if self.stride[0] == 2:
543
+ x_in = x_in[:, :, 1:, :, :]
544
+ x = self.conv(x, causal=causal)
545
+ x = rearrange(
546
+ x,
547
+ "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
548
+ p1=self.stride[0],
549
+ p2=self.stride[1],
550
+ p3=self.stride[2],
551
+ )
552
+ if self.stride[0] == 2:
553
+ x = x[:, :, 1:, :, :]
554
+ if self.residual:
555
+ x = x + x_in
556
+ return x
557
+
558
+ class LayerNorm(nn.Module):
559
+ def __init__(self, dim, eps, elementwise_affine=True) -> None:
560
+ super().__init__()
561
+ self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
562
+
563
+ def forward(self, x):
564
+ x = rearrange(x, "b c d h w -> b d h w c")
565
+ x = self.norm(x)
566
+ x = rearrange(x, "b d h w c -> b c d h w")
567
+ return x
568
+
569
+
570
+ class ResnetBlock3D(nn.Module):
571
+ r"""
572
+ A Resnet block.
573
+
574
+ Parameters:
575
+ in_channels (`int`): The number of channels in the input.
576
+ out_channels (`int`, *optional*, default to be `None`):
577
+ The number of output channels for the first conv layer. If None, same as `in_channels`.
578
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
579
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
580
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
581
+ """
582
+
583
+ def __init__(
584
+ self,
585
+ dims: Union[int, Tuple[int, int]],
586
+ in_channels: int,
587
+ out_channels: Optional[int] = None,
588
+ dropout: float = 0.0,
589
+ groups: int = 32,
590
+ eps: float = 1e-6,
591
+ norm_layer: str = "group_norm",
592
+ inject_noise: bool = False,
593
+ timestep_conditioning: bool = False,
594
+ ):
595
+ super().__init__()
596
+ self.in_channels = in_channels
597
+ out_channels = in_channels if out_channels is None else out_channels
598
+ self.out_channels = out_channels
599
+ self.inject_noise = inject_noise
600
+
601
+ if norm_layer == "group_norm":
602
+ self.norm1 = nn.GroupNorm(
603
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
604
+ )
605
+ elif norm_layer == "pixel_norm":
606
+ self.norm1 = PixelNorm()
607
+ elif norm_layer == "layer_norm":
608
+ self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
609
+
610
+ self.non_linearity = nn.SiLU()
611
+
612
+ self.conv1 = make_conv_nd(
613
+ dims,
614
+ in_channels,
615
+ out_channels,
616
+ kernel_size=3,
617
+ stride=1,
618
+ padding=1,
619
+ causal=True,
620
+ )
621
+
622
+ if inject_noise:
623
+ self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
624
+
625
+ if norm_layer == "group_norm":
626
+ self.norm2 = nn.GroupNorm(
627
+ num_groups=groups, num_channels=out_channels, eps=eps, affine=True
628
+ )
629
+ elif norm_layer == "pixel_norm":
630
+ self.norm2 = PixelNorm()
631
+ elif norm_layer == "layer_norm":
632
+ self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
633
+
634
+ self.dropout = torch.nn.Dropout(dropout)
635
+
636
+ self.conv2 = make_conv_nd(
637
+ dims,
638
+ out_channels,
639
+ out_channels,
640
+ kernel_size=3,
641
+ stride=1,
642
+ padding=1,
643
+ causal=True,
644
+ )
645
+
646
+ if inject_noise:
647
+ self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
648
+
649
+ self.conv_shortcut = (
650
+ make_linear_nd(
651
+ dims=dims, in_channels=in_channels, out_channels=out_channels
652
+ )
653
+ if in_channels != out_channels
654
+ else nn.Identity()
655
+ )
656
+
657
+ self.norm3 = (
658
+ LayerNorm(in_channels, eps=eps, elementwise_affine=True)
659
+ if in_channels != out_channels
660
+ else nn.Identity()
661
+ )
662
+
663
+ self.timestep_conditioning = timestep_conditioning
664
+
665
+ if timestep_conditioning:
666
+ self.scale_shift_table = nn.Parameter(
667
+ torch.randn(4, in_channels) / in_channels**0.5
668
+ )
669
+
670
+ def _feed_spatial_noise(
671
+ self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
672
+ ) -> torch.FloatTensor:
673
+ spatial_shape = hidden_states.shape[-2:]
674
+ device = hidden_states.device
675
+ dtype = hidden_states.dtype
676
+
677
+ # similar to the "explicit noise inputs" method in style-gan
678
+ spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
679
+ scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
680
+ hidden_states = hidden_states + scaled_noise
681
+
682
+ return hidden_states
683
+
684
+ def forward(
685
+ self,
686
+ input_tensor: torch.FloatTensor,
687
+ causal: bool = True,
688
+ timestep: Optional[torch.Tensor] = None,
689
+ ) -> torch.FloatTensor:
690
+ hidden_states = input_tensor
691
+ batch_size = hidden_states.shape[0]
692
+
693
+ hidden_states = self.norm1(hidden_states)
694
+ if self.timestep_conditioning:
695
+ assert (
696
+ timestep is not None
697
+ ), "should pass timestep with timestep_conditioning=True"
698
+ ada_values = self.scale_shift_table[
699
+ None, ..., None, None, None
700
+ ].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(
701
+ batch_size,
702
+ 4,
703
+ -1,
704
+ timestep.shape[-3],
705
+ timestep.shape[-2],
706
+ timestep.shape[-1],
707
+ )
708
+ shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
709
+
710
+ hidden_states = hidden_states * (1 + scale1) + shift1
711
+
712
+ hidden_states = self.non_linearity(hidden_states)
713
+
714
+ hidden_states = self.conv1(hidden_states, causal=causal)
715
+
716
+ if self.inject_noise:
717
+ hidden_states = self._feed_spatial_noise(
718
+ hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype)
719
+ )
720
+
721
+ hidden_states = self.norm2(hidden_states)
722
+
723
+ if self.timestep_conditioning:
724
+ hidden_states = hidden_states * (1 + scale2) + shift2
725
+
726
+ hidden_states = self.non_linearity(hidden_states)
727
+
728
+ hidden_states = self.dropout(hidden_states)
729
+
730
+ hidden_states = self.conv2(hidden_states, causal=causal)
731
+
732
+ if self.inject_noise:
733
+ hidden_states = self._feed_spatial_noise(
734
+ hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype)
735
+ )
736
+
737
+ input_tensor = self.norm3(input_tensor)
738
+
739
+ batch_size = input_tensor.shape[0]
740
+
741
+ input_tensor = self.conv_shortcut(input_tensor)
742
+
743
+ output_tensor = input_tensor + hidden_states
744
+
745
+ return output_tensor
746
+
747
+
748
+ def patchify(x, patch_size_hw, patch_size_t=1):
749
+ if patch_size_hw == 1 and patch_size_t == 1:
750
+ return x
751
+ if x.dim() == 4:
752
+ x = rearrange(
753
+ x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
754
+ )
755
+ elif x.dim() == 5:
756
+ x = rearrange(
757
+ x,
758
+ "b c (f p) (h q) (w r) -> b (c p r q) f h w",
759
+ p=patch_size_t,
760
+ q=patch_size_hw,
761
+ r=patch_size_hw,
762
+ )
763
+ else:
764
+ raise ValueError(f"Invalid input shape: {x.shape}")
765
+
766
+ return x
767
+
768
+
769
+ def unpatchify(x, patch_size_hw, patch_size_t=1):
770
+ if patch_size_hw == 1 and patch_size_t == 1:
771
+ return x
772
+
773
+ if x.dim() == 4:
774
+ x = rearrange(
775
+ x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
776
+ )
777
+ elif x.dim() == 5:
778
+ x = rearrange(
779
+ x,
780
+ "b (c p r q) f h w -> b c (f p) (h q) (w r)",
781
+ p=patch_size_t,
782
+ q=patch_size_hw,
783
+ r=patch_size_hw,
784
+ )
785
+
786
+ return x
787
+
788
+ class processor(nn.Module):
789
+ def __init__(self):
790
+ super().__init__()
791
+ self.register_buffer("std-of-means", torch.empty(128))
792
+ self.register_buffer("mean-of-means", torch.empty(128))
793
+ self.register_buffer("mean-of-stds", torch.empty(128))
794
+ self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
795
+ self.register_buffer("channel", torch.empty(128))
796
+
797
+ def un_normalize(self, x):
798
+ return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
799
+
800
+ def normalize(self, x):
801
+ return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
802
+
803
+ class VideoVAE(nn.Module):
804
+ def __init__(self, version=0):
805
+ super().__init__()
806
+
807
+ if version == 0:
808
+ config = {
809
+ "_class_name": "CausalVideoAutoencoder",
810
+ "dims": 3,
811
+ "in_channels": 3,
812
+ "out_channels": 3,
813
+ "latent_channels": 128,
814
+ "blocks": [
815
+ ["res_x", 4],
816
+ ["compress_all", 1],
817
+ ["res_x_y", 1],
818
+ ["res_x", 3],
819
+ ["compress_all", 1],
820
+ ["res_x_y", 1],
821
+ ["res_x", 3],
822
+ ["compress_all", 1],
823
+ ["res_x", 3],
824
+ ["res_x", 4],
825
+ ],
826
+ "scaling_factor": 1.0,
827
+ "norm_layer": "pixel_norm",
828
+ "patch_size": 4,
829
+ "latent_log_var": "uniform",
830
+ "use_quant_conv": False,
831
+ "causal_decoder": False,
832
+ }
833
+ else:
834
+ config = {
835
+ "_class_name": "CausalVideoAutoencoder",
836
+ "dims": 3,
837
+ "in_channels": 3,
838
+ "out_channels": 3,
839
+ "latent_channels": 128,
840
+ "decoder_blocks": [
841
+ ["res_x", {"num_layers": 5, "inject_noise": True}],
842
+ ["compress_all", {"residual": True, "multiplier": 2}],
843
+ ["res_x", {"num_layers": 6, "inject_noise": True}],
844
+ ["compress_all", {"residual": True, "multiplier": 2}],
845
+ ["res_x", {"num_layers": 7, "inject_noise": True}],
846
+ ["compress_all", {"residual": True, "multiplier": 2}],
847
+ ["res_x", {"num_layers": 8, "inject_noise": False}]
848
+ ],
849
+ "encoder_blocks": [
850
+ ["res_x", {"num_layers": 4}],
851
+ ["compress_all", {}],
852
+ ["res_x_y", 1],
853
+ ["res_x", {"num_layers": 3}],
854
+ ["compress_all", {}],
855
+ ["res_x_y", 1],
856
+ ["res_x", {"num_layers": 3}],
857
+ ["compress_all", {}],
858
+ ["res_x", {"num_layers": 3}],
859
+ ["res_x", {"num_layers": 4}]
860
+ ],
861
+ "scaling_factor": 1.0,
862
+ "norm_layer": "pixel_norm",
863
+ "patch_size": 4,
864
+ "latent_log_var": "uniform",
865
+ "use_quant_conv": False,
866
+ "causal_decoder": False,
867
+ "timestep_conditioning": True,
868
+ }
869
+
870
+ double_z = config.get("double_z", True)
871
+ latent_log_var = config.get(
872
+ "latent_log_var", "per_channel" if double_z else "none"
873
+ )
874
+
875
+ self.encoder = Encoder(
876
+ dims=config["dims"],
877
+ in_channels=config.get("in_channels", 3),
878
+ out_channels=config["latent_channels"],
879
+ blocks=config.get("encoder_blocks", config.get("blocks")),
880
+ patch_size=config.get("patch_size", 1),
881
+ latent_log_var=latent_log_var,
882
+ norm_layer=config.get("norm_layer", "group_norm"),
883
+ )
884
+
885
+ self.decoder = Decoder(
886
+ dims=config["dims"],
887
+ in_channels=config["latent_channels"],
888
+ out_channels=config.get("out_channels", 3),
889
+ blocks=config.get("decoder_blocks", config.get("blocks")),
890
+ patch_size=config.get("patch_size", 1),
891
+ norm_layer=config.get("norm_layer", "group_norm"),
892
+ causal=config.get("causal_decoder", False),
893
+ timestep_conditioning=config.get("timestep_conditioning", False),
894
+ )
895
+
896
+ self.timestep_conditioning = config.get("timestep_conditioning", False)
897
+ self.per_channel_statistics = processor()
898
+
899
+ def encode(self, x):
900
+ means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
901
+ return self.per_channel_statistics.normalize(means)
902
+
903
+ def decode(self, x, timestep=0.05, noise_scale=0.025):
904
+ if self.timestep_conditioning: #TODO: seed
905
+ x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
906
+ return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
907
+
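A short illustration of the patchify/unpatchify helpers defined above (shapes are hypothetical): with patch_size_hw=4 each 4x4 spatial patch is folded into the channel dimension, which is why the Encoder's effective in_channels is in_channels * patch_size**2 and the Decoder unpatchifies its output.

    import torch

    x = torch.randn(1, 3, 9, 64, 64)                    # [B, C, F, H, W]
    p = patchify(x, patch_size_hw=4, patch_size_t=1)    # -> [1, 48, 9, 16, 16]
    assert unpatchify(p, patch_size_hw=4, patch_size_t=1).shape == x.shape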
vae (2)/conv_nd_factory.py ADDED
@@ -0,0 +1,82 @@
1
+ from typing import Tuple, Union
2
+
3
+
4
+ from .dual_conv3d import DualConv3d
5
+ from .causal_conv3d import CausalConv3d
6
+ import comfy.ops
7
+ ops = comfy.ops.disable_weight_init
8
+
9
+ def make_conv_nd(
10
+ dims: Union[int, Tuple[int, int]],
11
+ in_channels: int,
12
+ out_channels: int,
13
+ kernel_size: int,
14
+ stride=1,
15
+ padding=0,
16
+ dilation=1,
17
+ groups=1,
18
+ bias=True,
19
+ causal=False,
20
+ ):
21
+ if dims == 2:
22
+ return ops.Conv2d(
23
+ in_channels=in_channels,
24
+ out_channels=out_channels,
25
+ kernel_size=kernel_size,
26
+ stride=stride,
27
+ padding=padding,
28
+ dilation=dilation,
29
+ groups=groups,
30
+ bias=bias,
31
+ )
32
+ elif dims == 3:
33
+ if causal:
34
+ return CausalConv3d(
35
+ in_channels=in_channels,
36
+ out_channels=out_channels,
37
+ kernel_size=kernel_size,
38
+ stride=stride,
39
+ padding=padding,
40
+ dilation=dilation,
41
+ groups=groups,
42
+ bias=bias,
43
+ )
44
+ return ops.Conv3d(
45
+ in_channels=in_channels,
46
+ out_channels=out_channels,
47
+ kernel_size=kernel_size,
48
+ stride=stride,
49
+ padding=padding,
50
+ dilation=dilation,
51
+ groups=groups,
52
+ bias=bias,
53
+ )
54
+ elif dims == (2, 1):
55
+ return DualConv3d(
56
+ in_channels=in_channels,
57
+ out_channels=out_channels,
58
+ kernel_size=kernel_size,
59
+ stride=stride,
60
+ padding=padding,
61
+ bias=bias,
62
+ )
63
+ else:
64
+ raise ValueError(f"unsupported dimensions: {dims}")
65
+
66
+
67
+ def make_linear_nd(
68
+ dims: int,
69
+ in_channels: int,
70
+ out_channels: int,
71
+ bias=True,
72
+ ):
73
+ if dims == 2:
74
+ return ops.Conv2d(
75
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
76
+ )
77
+ elif dims == 3 or dims == (2, 1):
78
+ return ops.Conv3d(
79
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
80
+ )
81
+ else:
82
+ raise ValueError(f"unsupported dimensions: {dims}")
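The factory dispatches purely on `dims`: 2 gives a plain 2D conv, 3 gives either a `CausalConv3d` or a plain 3D conv depending on `causal`, and the tuple `(2, 1)` gives the factored `DualConv3d`. A small sketch of the dispatch paths, assuming the module is importable as part of a package (it uses relative imports and ComfyUI's `comfy.ops`):

    from vae.conv_nd_factory import make_conv_nd  # hypothetical package path

    conv2d   = make_conv_nd(dims=2, in_channels=8, out_channels=16, kernel_size=3, padding=1)
    conv3d   = make_conv_nd(dims=3, in_channels=8, out_channels=16, kernel_size=3, padding=1)
    causal3d = make_conv_nd(dims=3, in_channels=8, out_channels=16, kernel_size=3, causal=True)
    dual3d   = make_conv_nd(dims=(2, 1), in_channels=8, out_channels=16, kernel_size=3, padding=1)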
vae (2)/dual_conv3d.py ADDED
@@ -0,0 +1,195 @@
1
+ import math
2
+ from typing import Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+
9
+
10
+ class DualConv3d(nn.Module):
11
+ def __init__(
12
+ self,
13
+ in_channels,
14
+ out_channels,
15
+ kernel_size,
16
+ stride: Union[int, Tuple[int, int, int]] = 1,
17
+ padding: Union[int, Tuple[int, int, int]] = 0,
18
+ dilation: Union[int, Tuple[int, int, int]] = 1,
19
+ groups=1,
20
+ bias=True,
21
+ ):
22
+ super(DualConv3d, self).__init__()
23
+
24
+ self.in_channels = in_channels
25
+ self.out_channels = out_channels
26
+ # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
27
+ if isinstance(kernel_size, int):
28
+ kernel_size = (kernel_size, kernel_size, kernel_size)
29
+ if kernel_size == (1, 1, 1):
30
+ raise ValueError(
31
+ "kernel_size must be greater than 1. Use make_linear_nd instead."
32
+ )
33
+ if isinstance(stride, int):
34
+ stride = (stride, stride, stride)
35
+ if isinstance(padding, int):
36
+ padding = (padding, padding, padding)
37
+ if isinstance(dilation, int):
38
+ dilation = (dilation, dilation, dilation)
39
+
40
+ # Set parameters for convolutions
41
+ self.groups = groups
42
+ self.bias = bias
43
+
44
+ # Define the size of the channels after the first convolution
45
+ intermediate_channels = (
46
+ out_channels if in_channels < out_channels else in_channels
47
+ )
48
+
49
+ # Define parameters for the first convolution
50
+ self.weight1 = nn.Parameter(
51
+ torch.Tensor(
52
+ intermediate_channels,
53
+ in_channels // groups,
54
+ 1,
55
+ kernel_size[1],
56
+ kernel_size[2],
57
+ )
58
+ )
59
+ self.stride1 = (1, stride[1], stride[2])
60
+ self.padding1 = (0, padding[1], padding[2])
61
+ self.dilation1 = (1, dilation[1], dilation[2])
62
+ if bias:
63
+ self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
64
+ else:
65
+ self.register_parameter("bias1", None)
66
+
67
+ # Define parameters for the second convolution
68
+ self.weight2 = nn.Parameter(
69
+ torch.Tensor(
70
+ out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
71
+ )
72
+ )
73
+ self.stride2 = (stride[0], 1, 1)
74
+ self.padding2 = (padding[0], 0, 0)
75
+ self.dilation2 = (dilation[0], 1, 1)
76
+ if bias:
77
+ self.bias2 = nn.Parameter(torch.Tensor(out_channels))
78
+ else:
79
+ self.register_parameter("bias2", None)
80
+
81
+ # Initialize weights and biases
82
+ self.reset_parameters()
83
+
84
+ def reset_parameters(self):
85
+ nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
86
+ nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
87
+ if self.bias:
88
+ fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
89
+ bound1 = 1 / math.sqrt(fan_in1)
90
+ nn.init.uniform_(self.bias1, -bound1, bound1)
91
+ fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
92
+ bound2 = 1 / math.sqrt(fan_in2)
93
+ nn.init.uniform_(self.bias2, -bound2, bound2)
94
+
95
+ def forward(self, x, use_conv3d=False, skip_time_conv=False):
96
+ if use_conv3d:
97
+ return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
98
+ else:
99
+ return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
100
+
101
+ def forward_with_3d(self, x, skip_time_conv):
102
+ # First convolution
103
+ x = F.conv3d(
104
+ x,
105
+ self.weight1,
106
+ self.bias1,
107
+ self.stride1,
108
+ self.padding1,
109
+ self.dilation1,
110
+ self.groups,
111
+ )
112
+
113
+ if skip_time_conv:
114
+ return x
115
+
116
+ # Second convolution
117
+ x = F.conv3d(
118
+ x,
119
+ self.weight2,
120
+ self.bias2,
121
+ self.stride2,
122
+ self.padding2,
123
+ self.dilation2,
124
+ self.groups,
125
+ )
126
+
127
+ return x
128
+
129
+ def forward_with_2d(self, x, skip_time_conv):
130
+ b, c, d, h, w = x.shape
131
+
132
+ # First 2D convolution
133
+ x = rearrange(x, "b c d h w -> (b d) c h w")
134
+ # Squeeze the depth dimension out of weight1 since it's 1
135
+ weight1 = self.weight1.squeeze(2)
136
+ # Select stride, padding, and dilation for the 2D convolution
137
+ stride1 = (self.stride1[1], self.stride1[2])
138
+ padding1 = (self.padding1[1], self.padding1[2])
139
+ dilation1 = (self.dilation1[1], self.dilation1[2])
140
+ x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
141
+
142
+ _, _, h, w = x.shape
143
+
144
+ if skip_time_conv:
145
+ x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
146
+ return x
147
+
148
+ # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
149
+ x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
150
+
151
+ # Reshape weight2 to match the expected dimensions for conv1d
152
+ weight2 = self.weight2.squeeze(-1).squeeze(-1)
153
+ # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
154
+ stride2 = self.stride2[0]
155
+ padding2 = self.padding2[0]
156
+ dilation2 = self.dilation2[0]
157
+ x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
158
+ x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
159
+
160
+ return x
161
+
162
+ @property
163
+ def weight(self):
164
+ return self.weight2
165
+
166
+
167
+ def test_dual_conv3d_consistency():
168
+ # Initialize parameters
169
+ in_channels = 3
170
+ out_channels = 5
171
+ kernel_size = (3, 3, 3)
172
+ stride = (2, 2, 2)
173
+ padding = (1, 1, 1)
174
+
175
+ # Create an instance of the DualConv3d class
176
+ dual_conv3d = DualConv3d(
177
+ in_channels=in_channels,
178
+ out_channels=out_channels,
179
+ kernel_size=kernel_size,
180
+ stride=stride,
181
+ padding=padding,
182
+ bias=True,
183
+ )
184
+
185
+ # Example input tensor
186
+ test_input = torch.randn(1, 3, 10, 10, 10)
187
+
188
+ # Perform forward passes with both 3D and 2D settings
189
+ output_conv3d = dual_conv3d(test_input, use_conv3d=True)
190
+ output_2d = dual_conv3d(test_input, use_conv3d=False)
191
+
192
+ # Assert that the outputs from both methods are sufficiently close
193
+ assert torch.allclose(
194
+ output_conv3d, output_2d, atol=1e-6
195
+ ), "Outputs are not consistent between 3D and 2D convolutions."
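`DualConv3d` factors a full 3D convolution into a spatial (1, k, k) convolution followed by a temporal (k, 1, 1) convolution, so the same weights can be applied either through two `conv3d` calls or through a `conv2d` plus `conv1d` path; `test_dual_conv3d_consistency` above checks that both paths agree. A minimal driver for that check, assuming the module is run directly:

    if __name__ == "__main__":
        # Runs the consistency check defined above on random weights and input.
        test_dual_conv3d_consistency()
        print("DualConv3d: conv3d and conv2d+conv1d paths agree")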
vae (2)/pixel_norm.py ADDED
@@ -0,0 +1,12 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class PixelNorm(nn.Module):
6
+ def __init__(self, dim=1, eps=1e-8):
7
+ super(PixelNorm, self).__init__()
8
+ self.dim = dim
9
+ self.eps = eps
10
+
11
+ def forward(self, x):
12
+ return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
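`PixelNorm` divides every position by the RMS of its values over the channel dimension, with no learned parameters. A quick sketch, assuming the module is importable under the hypothetical path shown:

    import torch
    from vae.pixel_norm import PixelNorm  # hypothetical package path

    norm = PixelNorm(dim=1)
    x = torch.randn(2, 128, 4, 32, 32)            # [B, C, T, H, W]
    y = norm(x)
    # After normalization the per-position RMS over channels is ~1.
    print(y.pow(2).mean(dim=1).sqrt().mean())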
vae (2)/put_vae_here ADDED
File without changes
vae/causal_conv3d.py ADDED
@@ -0,0 +1,64 @@
1
+ from typing import Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import comfy.ops
6
+ ops = comfy.ops.disable_weight_init
7
+
8
+
9
+ class CausalConv3d(nn.Module):
10
+ def __init__(
11
+ self,
12
+ in_channels,
13
+ out_channels,
14
+ kernel_size: int = 3,
15
+ stride: Union[int, Tuple[int]] = 1,
16
+ dilation: int = 1,
17
+ groups: int = 1,
18
+ **kwargs,
19
+ ):
20
+ super().__init__()
21
+
22
+ self.in_channels = in_channels
23
+ self.out_channels = out_channels
24
+
25
+ kernel_size = (kernel_size, kernel_size, kernel_size)
26
+ self.time_kernel_size = kernel_size[0]
27
+
28
+ dilation = (dilation, 1, 1)
29
+
30
+ height_pad = kernel_size[1] // 2
31
+ width_pad = kernel_size[2] // 2
32
+ padding = (0, height_pad, width_pad)
33
+
34
+ self.conv = ops.Conv3d(
35
+ in_channels,
36
+ out_channels,
37
+ kernel_size,
38
+ stride=stride,
39
+ dilation=dilation,
40
+ padding=padding,
41
+ padding_mode="zeros",
42
+ groups=groups,
43
+ )
44
+
45
+ def forward(self, x, causal: bool = True):
46
+ if causal:
47
+ first_frame_pad = x[:, :, :1, :, :].repeat(
48
+ (1, 1, self.time_kernel_size - 1, 1, 1)
49
+ )
50
+ x = torch.concatenate((first_frame_pad, x), dim=2)
51
+ else:
52
+ first_frame_pad = x[:, :, :1, :, :].repeat(
53
+ (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
54
+ )
55
+ last_frame_pad = x[:, :, -1:, :, :].repeat(
56
+ (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
57
+ )
58
+ x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
59
+ x = self.conv(x)
60
+ return x
61
+
62
+ @property
63
+ def weight(self):
64
+ return self.conv.weight
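`CausalConv3d` keeps spatial padding symmetric but pads the time axis by replicating edge frames: with `causal=True` the first frame is repeated `kernel_size - 1` times in front, so an output frame depends only on current and past inputs; with `causal=False` the padding is split between the first and last frames. A shape-level sketch, assuming a ComfyUI environment and the hypothetical package path:

    import torch
    from vae.causal_conv3d import CausalConv3d  # hypothetical package path

    conv = CausalConv3d(in_channels=4, out_channels=4, kernel_size=3, stride=1)
    x = torch.randn(1, 4, 8, 16, 16)      # [B, C, T, H, W]
    y_causal = conv(x, causal=True)       # two copies of the first frame padded in front
    y_sym = conv(x, causal=False)         # one padded frame at each end instead
    print(y_causal.shape, y_sym.shape)    # both keep T=8 at stride 1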
vae/causal_video_autoencoder.py ADDED
@@ -0,0 +1,907 @@
1
+ import torch
2
+ from torch import nn
3
+ from functools import partial
4
+ import math
5
+ from einops import rearrange
6
+ from typing import Optional, Tuple, Union
7
+ from .conv_nd_factory import make_conv_nd, make_linear_nd
8
+ from .pixel_norm import PixelNorm
9
+ from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
10
+ import comfy.ops
11
+ ops = comfy.ops.disable_weight_init
12
+
13
+ class Encoder(nn.Module):
14
+ r"""
15
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
16
+
17
+ Args:
18
+ dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
19
+ The number of dimensions to use in convolutions.
20
+ in_channels (`int`, *optional*, defaults to 3):
21
+ The number of input channels.
22
+ out_channels (`int`, *optional*, defaults to 3):
23
+ The number of output channels.
24
+ blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
25
+ The blocks to use. Each block is a tuple of the block name and the number of layers.
26
+ base_channels (`int`, *optional*, defaults to 128):
27
+ The number of output channels for the first convolutional layer.
28
+ norm_num_groups (`int`, *optional*, defaults to 32):
29
+ The number of groups for normalization.
30
+ patch_size (`int`, *optional*, defaults to 1):
31
+ The patch size to use. Should be a power of 2.
32
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
33
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
34
+ latent_log_var (`str`, *optional*, defaults to `per_channel`):
35
+ The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ dims: Union[int, Tuple[int, int]] = 3,
41
+ in_channels: int = 3,
42
+ out_channels: int = 3,
43
+ blocks=[("res_x", 1)],
44
+ base_channels: int = 128,
45
+ norm_num_groups: int = 32,
46
+ patch_size: Union[int, Tuple[int]] = 1,
47
+ norm_layer: str = "group_norm", # group_norm, pixel_norm
48
+ latent_log_var: str = "per_channel",
49
+ ):
50
+ super().__init__()
51
+ self.patch_size = patch_size
52
+ self.norm_layer = norm_layer
53
+ self.latent_channels = out_channels
54
+ self.latent_log_var = latent_log_var
55
+ self.blocks_desc = blocks
56
+
57
+ in_channels = in_channels * patch_size**2
58
+ output_channel = base_channels
59
+
60
+ self.conv_in = make_conv_nd(
61
+ dims=dims,
62
+ in_channels=in_channels,
63
+ out_channels=output_channel,
64
+ kernel_size=3,
65
+ stride=1,
66
+ padding=1,
67
+ causal=True,
68
+ )
69
+
70
+ self.down_blocks = nn.ModuleList([])
71
+
72
+ for block_name, block_params in blocks:
73
+ input_channel = output_channel
74
+ if isinstance(block_params, int):
75
+ block_params = {"num_layers": block_params}
76
+
77
+ if block_name == "res_x":
78
+ block = UNetMidBlock3D(
79
+ dims=dims,
80
+ in_channels=input_channel,
81
+ num_layers=block_params["num_layers"],
82
+ resnet_eps=1e-6,
83
+ resnet_groups=norm_num_groups,
84
+ norm_layer=norm_layer,
85
+ )
86
+ elif block_name == "res_x_y":
87
+ output_channel = block_params.get("multiplier", 2) * output_channel
88
+ block = ResnetBlock3D(
89
+ dims=dims,
90
+ in_channels=input_channel,
91
+ out_channels=output_channel,
92
+ eps=1e-6,
93
+ groups=norm_num_groups,
94
+ norm_layer=norm_layer,
95
+ )
96
+ elif block_name == "compress_time":
97
+ block = make_conv_nd(
98
+ dims=dims,
99
+ in_channels=input_channel,
100
+ out_channels=output_channel,
101
+ kernel_size=3,
102
+ stride=(2, 1, 1),
103
+ causal=True,
104
+ )
105
+ elif block_name == "compress_space":
106
+ block = make_conv_nd(
107
+ dims=dims,
108
+ in_channels=input_channel,
109
+ out_channels=output_channel,
110
+ kernel_size=3,
111
+ stride=(1, 2, 2),
112
+ causal=True,
113
+ )
114
+ elif block_name == "compress_all":
115
+ block = make_conv_nd(
116
+ dims=dims,
117
+ in_channels=input_channel,
118
+ out_channels=output_channel,
119
+ kernel_size=3,
120
+ stride=(2, 2, 2),
121
+ causal=True,
122
+ )
123
+ elif block_name == "compress_all_x_y":
124
+ output_channel = block_params.get("multiplier", 2) * output_channel
125
+ block = make_conv_nd(
126
+ dims=dims,
127
+ in_channels=input_channel,
128
+ out_channels=output_channel,
129
+ kernel_size=3,
130
+ stride=(2, 2, 2),
131
+ causal=True,
132
+ )
133
+ else:
134
+ raise ValueError(f"unknown block: {block_name}")
135
+
136
+ self.down_blocks.append(block)
137
+
138
+ # out
139
+ if norm_layer == "group_norm":
140
+ self.conv_norm_out = nn.GroupNorm(
141
+ num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
142
+ )
143
+ elif norm_layer == "pixel_norm":
144
+ self.conv_norm_out = PixelNorm()
145
+ elif norm_layer == "layer_norm":
146
+ self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
147
+
148
+ self.conv_act = nn.SiLU()
149
+
150
+ conv_out_channels = out_channels
151
+ if latent_log_var == "per_channel":
152
+ conv_out_channels *= 2
153
+ elif latent_log_var == "uniform":
154
+ conv_out_channels += 1
155
+ elif latent_log_var != "none":
156
+ raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
157
+ self.conv_out = make_conv_nd(
158
+ dims, output_channel, conv_out_channels, 3, padding=1, causal=True
159
+ )
160
+
161
+ self.gradient_checkpointing = False
162
+
163
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
164
+ r"""The forward method of the `Encoder` class."""
165
+
166
+ sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
167
+ sample = self.conv_in(sample)
168
+
169
+ checkpoint_fn = (
170
+ partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
171
+ if self.gradient_checkpointing and self.training
172
+ else lambda x: x
173
+ )
174
+
175
+ for down_block in self.down_blocks:
176
+ sample = checkpoint_fn(down_block)(sample)
177
+
178
+ sample = self.conv_norm_out(sample)
179
+ sample = self.conv_act(sample)
180
+ sample = self.conv_out(sample)
181
+
182
+ if self.latent_log_var == "uniform":
183
+ last_channel = sample[:, -1:, ...]
184
+ num_dims = sample.dim()
185
+
186
+ if num_dims == 4:
187
+ # For shape (B, C, H, W)
188
+ repeated_last_channel = last_channel.repeat(
189
+ 1, sample.shape[1] - 2, 1, 1
190
+ )
191
+ sample = torch.cat([sample, repeated_last_channel], dim=1)
192
+ elif num_dims == 5:
193
+ # For shape (B, C, F, H, W)
194
+ repeated_last_channel = last_channel.repeat(
195
+ 1, sample.shape[1] - 2, 1, 1, 1
196
+ )
197
+ sample = torch.cat([sample, repeated_last_channel], dim=1)
198
+ else:
199
+ raise ValueError(f"Invalid input shape: {sample.shape}")
200
+
201
+ return sample
202
+
203
+
204
+ class Decoder(nn.Module):
205
+ r"""
206
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
207
+
208
+ Args:
209
+ dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
210
+ The number of dimensions to use in convolutions.
211
+ in_channels (`int`, *optional*, defaults to 3):
212
+ The number of input channels.
213
+ out_channels (`int`, *optional*, defaults to 3):
214
+ The number of output channels.
215
+ blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
216
+ The blocks to use. Each block is a tuple of the block name and the number of layers.
217
+ base_channels (`int`, *optional*, defaults to 128):
218
+ The number of output channels for the first convolutional layer.
219
+ norm_num_groups (`int`, *optional*, defaults to 32):
220
+ The number of groups for normalization.
221
+ patch_size (`int`, *optional*, defaults to 1):
222
+ The patch size to use. Should be a power of 2.
223
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
224
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
225
+ causal (`bool`, *optional*, defaults to `True`):
226
+ Whether to use causal convolutions or not.
227
+ """
228
+
229
+ def __init__(
230
+ self,
231
+ dims,
232
+ in_channels: int = 3,
233
+ out_channels: int = 3,
234
+ blocks=[("res_x", 1)],
235
+ base_channels: int = 128,
236
+ layers_per_block: int = 2,
237
+ norm_num_groups: int = 32,
238
+ patch_size: int = 1,
239
+ norm_layer: str = "group_norm",
240
+ causal: bool = True,
241
+ timestep_conditioning: bool = False,
242
+ ):
243
+ super().__init__()
244
+ self.patch_size = patch_size
245
+ self.layers_per_block = layers_per_block
246
+ out_channels = out_channels * patch_size**2
247
+ self.causal = causal
248
+ self.blocks_desc = blocks
249
+
250
+ # Compute output channel to be product of all channel-multiplier blocks
251
+ output_channel = base_channels
252
+ for block_name, block_params in list(reversed(blocks)):
253
+ block_params = block_params if isinstance(block_params, dict) else {}
254
+ if block_name == "res_x_y":
255
+ output_channel = output_channel * block_params.get("multiplier", 2)
256
+ if block_name == "compress_all":
257
+ output_channel = output_channel * block_params.get("multiplier", 1)
258
+
259
+ self.conv_in = make_conv_nd(
260
+ dims,
261
+ in_channels,
262
+ output_channel,
263
+ kernel_size=3,
264
+ stride=1,
265
+ padding=1,
266
+ causal=True,
267
+ )
268
+
269
+ self.up_blocks = nn.ModuleList([])
270
+
271
+ for block_name, block_params in list(reversed(blocks)):
272
+ input_channel = output_channel
273
+ if isinstance(block_params, int):
274
+ block_params = {"num_layers": block_params}
275
+
276
+ if block_name == "res_x":
277
+ block = UNetMidBlock3D(
278
+ dims=dims,
279
+ in_channels=input_channel,
280
+ num_layers=block_params["num_layers"],
281
+ resnet_eps=1e-6,
282
+ resnet_groups=norm_num_groups,
283
+ norm_layer=norm_layer,
284
+ inject_noise=block_params.get("inject_noise", False),
285
+ timestep_conditioning=timestep_conditioning,
286
+ )
287
+ elif block_name == "attn_res_x":
288
+ block = UNetMidBlock3D(
289
+ dims=dims,
290
+ in_channels=input_channel,
291
+ num_layers=block_params["num_layers"],
292
+ resnet_groups=norm_num_groups,
293
+ norm_layer=norm_layer,
294
+ inject_noise=block_params.get("inject_noise", False),
295
+ timestep_conditioning=timestep_conditioning,
296
+ attention_head_dim=block_params["attention_head_dim"],
297
+ )
298
+ elif block_name == "res_x_y":
299
+ output_channel = output_channel // block_params.get("multiplier", 2)
300
+ block = ResnetBlock3D(
301
+ dims=dims,
302
+ in_channels=input_channel,
303
+ out_channels=output_channel,
304
+ eps=1e-6,
305
+ groups=norm_num_groups,
306
+ norm_layer=norm_layer,
307
+ inject_noise=block_params.get("inject_noise", False),
308
+ timestep_conditioning=False,
309
+ )
310
+ elif block_name == "compress_time":
311
+ block = DepthToSpaceUpsample(
312
+ dims=dims, in_channels=input_channel, stride=(2, 1, 1)
313
+ )
314
+ elif block_name == "compress_space":
315
+ block = DepthToSpaceUpsample(
316
+ dims=dims, in_channels=input_channel, stride=(1, 2, 2)
317
+ )
318
+ elif block_name == "compress_all":
319
+ output_channel = output_channel // block_params.get("multiplier", 1)
320
+ block = DepthToSpaceUpsample(
321
+ dims=dims,
322
+ in_channels=input_channel,
323
+ stride=(2, 2, 2),
324
+ residual=block_params.get("residual", False),
325
+ out_channels_reduction_factor=block_params.get("multiplier", 1),
326
+ )
327
+ else:
328
+ raise ValueError(f"unknown layer: {block_name}")
329
+
330
+ self.up_blocks.append(block)
331
+
332
+ if norm_layer == "group_norm":
333
+ self.conv_norm_out = nn.GroupNorm(
334
+ num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
335
+ )
336
+ elif norm_layer == "pixel_norm":
337
+ self.conv_norm_out = PixelNorm()
338
+ elif norm_layer == "layer_norm":
339
+ self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
340
+
341
+ self.conv_act = nn.SiLU()
342
+ self.conv_out = make_conv_nd(
343
+ dims, output_channel, out_channels, 3, padding=1, causal=True
344
+ )
345
+
346
+ self.gradient_checkpointing = False
347
+
348
+ self.timestep_conditioning = timestep_conditioning
349
+
350
+ if timestep_conditioning:
351
+ self.timestep_scale_multiplier = nn.Parameter(
352
+ torch.tensor(1000.0, dtype=torch.float32)
353
+ )
354
+ self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
355
+ output_channel * 2, 0, operations=ops,
356
+ )
357
+ self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
358
+
359
+ # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
360
+ def forward(
361
+ self,
362
+ sample: torch.FloatTensor,
363
+ timestep: Optional[torch.Tensor] = None,
364
+ ) -> torch.FloatTensor:
365
+ r"""The forward method of the `Decoder` class."""
366
+ batch_size = sample.shape[0]
367
+
368
+ sample = self.conv_in(sample, causal=self.causal)
369
+
370
+ checkpoint_fn = (
371
+ partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
372
+ if self.gradient_checkpointing and self.training
373
+ else lambda x: x
374
+ )
375
+
376
+ scaled_timestep = None
377
+ if self.timestep_conditioning:
378
+ assert (
379
+ timestep is not None
380
+ ), "should pass timestep with timestep_conditioning=True"
381
+ scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
382
+
383
+ for up_block in self.up_blocks:
384
+ if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
385
+ sample = checkpoint_fn(up_block)(
386
+ sample, causal=self.causal, timestep=scaled_timestep
387
+ )
388
+ else:
389
+ sample = checkpoint_fn(up_block)(sample, causal=self.causal)
390
+
391
+ sample = self.conv_norm_out(sample)
392
+
393
+ if self.timestep_conditioning:
394
+ embedded_timestep = self.last_time_embedder(
395
+ timestep=scaled_timestep.flatten(),
396
+ resolution=None,
397
+ aspect_ratio=None,
398
+ batch_size=sample.shape[0],
399
+ hidden_dtype=sample.dtype,
400
+ )
401
+ embedded_timestep = embedded_timestep.view(
402
+ batch_size, embedded_timestep.shape[-1], 1, 1, 1
403
+ )
404
+ ada_values = self.last_scale_shift_table[
405
+ None, ..., None, None, None
406
+ ].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(
407
+ batch_size,
408
+ 2,
409
+ -1,
410
+ embedded_timestep.shape[-3],
411
+ embedded_timestep.shape[-2],
412
+ embedded_timestep.shape[-1],
413
+ )
414
+ shift, scale = ada_values.unbind(dim=1)
415
+ sample = sample * (1 + scale) + shift
416
+
417
+ sample = self.conv_act(sample)
418
+ sample = self.conv_out(sample, causal=self.causal)
419
+
420
+ sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
421
+
422
+ return sample
423
+
424
+
425
+ class UNetMidBlock3D(nn.Module):
426
+ """
427
+ A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
428
+
429
+ Args:
430
+ in_channels (`int`): The number of input channels.
431
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
432
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
433
+ resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
434
+ resnet_groups (`int`, *optional*, defaults to 32):
435
+ The number of groups to use in the group normalization layers of the resnet blocks.
436
+
437
+ Returns:
438
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
439
+ in_channels, height, width)`.
440
+
441
+ """
442
+
443
+ def __init__(
444
+ self,
445
+ dims: Union[int, Tuple[int, int]],
446
+ in_channels: int,
447
+ dropout: float = 0.0,
448
+ num_layers: int = 1,
449
+ resnet_eps: float = 1e-6,
450
+ resnet_groups: int = 32,
451
+ norm_layer: str = "group_norm",
452
+ inject_noise: bool = False,
453
+ timestep_conditioning: bool = False,
454
+ ):
455
+ super().__init__()
456
+ resnet_groups = (
457
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
458
+ )
459
+
460
+ self.timestep_conditioning = timestep_conditioning
461
+
462
+ if timestep_conditioning:
463
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
464
+ in_channels * 4, 0, operations=ops,
465
+ )
466
+
467
+ self.res_blocks = nn.ModuleList(
468
+ [
469
+ ResnetBlock3D(
470
+ dims=dims,
471
+ in_channels=in_channels,
472
+ out_channels=in_channels,
473
+ eps=resnet_eps,
474
+ groups=resnet_groups,
475
+ dropout=dropout,
476
+ norm_layer=norm_layer,
477
+ inject_noise=inject_noise,
478
+ timestep_conditioning=timestep_conditioning,
479
+ )
480
+ for _ in range(num_layers)
481
+ ]
482
+ )
483
+
484
+ def forward(
485
+ self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
486
+ ) -> torch.FloatTensor:
487
+ timestep_embed = None
488
+ if self.timestep_conditioning:
489
+ assert (
490
+ timestep is not None
491
+ ), "should pass timestep with timestep_conditioning=True"
492
+ batch_size = hidden_states.shape[0]
493
+ timestep_embed = self.time_embedder(
494
+ timestep=timestep.flatten(),
495
+ resolution=None,
496
+ aspect_ratio=None,
497
+ batch_size=batch_size,
498
+ hidden_dtype=hidden_states.dtype,
499
+ )
500
+ timestep_embed = timestep_embed.view(
501
+ batch_size, timestep_embed.shape[-1], 1, 1, 1
502
+ )
503
+
504
+ for resnet in self.res_blocks:
505
+ hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed)
506
+
507
+ return hidden_states
508
+
509
+
510
+ class DepthToSpaceUpsample(nn.Module):
511
+ def __init__(
512
+ self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
513
+ ):
514
+ super().__init__()
515
+ self.stride = stride
516
+ self.out_channels = (
517
+ math.prod(stride) * in_channels // out_channels_reduction_factor
518
+ )
519
+ self.conv = make_conv_nd(
520
+ dims=dims,
521
+ in_channels=in_channels,
522
+ out_channels=self.out_channels,
523
+ kernel_size=3,
524
+ stride=1,
525
+ causal=True,
526
+ )
527
+ self.residual = residual
528
+ self.out_channels_reduction_factor = out_channels_reduction_factor
529
+
530
+ def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
531
+ if self.residual:
532
+ # Reshape and duplicate the input to match the output shape
533
+ x_in = rearrange(
534
+ x,
535
+ "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
536
+ p1=self.stride[0],
537
+ p2=self.stride[1],
538
+ p3=self.stride[2],
539
+ )
540
+ num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
541
+ x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
542
+ if self.stride[0] == 2:
543
+ x_in = x_in[:, :, 1:, :, :]
544
+ x = self.conv(x, causal=causal)
545
+ x = rearrange(
546
+ x,
547
+ "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
548
+ p1=self.stride[0],
549
+ p2=self.stride[1],
550
+ p3=self.stride[2],
551
+ )
552
+ if self.stride[0] == 2:
553
+ x = x[:, :, 1:, :, :]
554
+ if self.residual:
555
+ x = x + x_in
556
+ return x
557
+
558
+ class LayerNorm(nn.Module):
559
+ def __init__(self, dim, eps, elementwise_affine=True) -> None:
560
+ super().__init__()
561
+ self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
562
+
563
+ def forward(self, x):
564
+ x = rearrange(x, "b c d h w -> b d h w c")
565
+ x = self.norm(x)
566
+ x = rearrange(x, "b d h w c -> b c d h w")
567
+ return x
568
+
569
+
570
+ class ResnetBlock3D(nn.Module):
571
+ r"""
572
+ A Resnet block.
573
+
574
+ Parameters:
575
+ in_channels (`int`): The number of channels in the input.
576
+ out_channels (`int`, *optional*, default to be `None`):
577
+ The number of output channels for the first conv layer. If None, same as `in_channels`.
578
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
579
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
580
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
581
+ """
582
+
583
+ def __init__(
584
+ self,
585
+ dims: Union[int, Tuple[int, int]],
586
+ in_channels: int,
587
+ out_channels: Optional[int] = None,
588
+ dropout: float = 0.0,
589
+ groups: int = 32,
590
+ eps: float = 1e-6,
591
+ norm_layer: str = "group_norm",
592
+ inject_noise: bool = False,
593
+ timestep_conditioning: bool = False,
594
+ ):
595
+ super().__init__()
596
+ self.in_channels = in_channels
597
+ out_channels = in_channels if out_channels is None else out_channels
598
+ self.out_channels = out_channels
599
+ self.inject_noise = inject_noise
600
+
601
+ if norm_layer == "group_norm":
602
+ self.norm1 = nn.GroupNorm(
603
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
604
+ )
605
+ elif norm_layer == "pixel_norm":
606
+ self.norm1 = PixelNorm()
607
+ elif norm_layer == "layer_norm":
608
+ self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
609
+
610
+ self.non_linearity = nn.SiLU()
611
+
612
+ self.conv1 = make_conv_nd(
613
+ dims,
614
+ in_channels,
615
+ out_channels,
616
+ kernel_size=3,
617
+ stride=1,
618
+ padding=1,
619
+ causal=True,
620
+ )
621
+
622
+ if inject_noise:
623
+ self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
624
+
625
+ if norm_layer == "group_norm":
626
+ self.norm2 = nn.GroupNorm(
627
+ num_groups=groups, num_channels=out_channels, eps=eps, affine=True
628
+ )
629
+ elif norm_layer == "pixel_norm":
630
+ self.norm2 = PixelNorm()
631
+ elif norm_layer == "layer_norm":
632
+ self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
633
+
634
+ self.dropout = torch.nn.Dropout(dropout)
635
+
636
+ self.conv2 = make_conv_nd(
637
+ dims,
638
+ out_channels,
639
+ out_channels,
640
+ kernel_size=3,
641
+ stride=1,
642
+ padding=1,
643
+ causal=True,
644
+ )
645
+
646
+ if inject_noise:
647
+ self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
648
+
649
+ self.conv_shortcut = (
650
+ make_linear_nd(
651
+ dims=dims, in_channels=in_channels, out_channels=out_channels
652
+ )
653
+ if in_channels != out_channels
654
+ else nn.Identity()
655
+ )
656
+
657
+ self.norm3 = (
658
+ LayerNorm(in_channels, eps=eps, elementwise_affine=True)
659
+ if in_channels != out_channels
660
+ else nn.Identity()
661
+ )
662
+
663
+ self.timestep_conditioning = timestep_conditioning
664
+
665
+ if timestep_conditioning:
666
+ self.scale_shift_table = nn.Parameter(
667
+ torch.randn(4, in_channels) / in_channels**0.5
668
+ )
669
+
670
+ def _feed_spatial_noise(
671
+ self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
672
+ ) -> torch.FloatTensor:
673
+ spatial_shape = hidden_states.shape[-2:]
674
+ device = hidden_states.device
675
+ dtype = hidden_states.dtype
676
+
677
+ # similar to the "explicit noise inputs" method in style-gan
678
+ spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
679
+ scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
680
+ hidden_states = hidden_states + scaled_noise
681
+
682
+ return hidden_states
683
+
684
+ def forward(
685
+ self,
686
+ input_tensor: torch.FloatTensor,
687
+ causal: bool = True,
688
+ timestep: Optional[torch.Tensor] = None,
689
+ ) -> torch.FloatTensor:
690
+ hidden_states = input_tensor
691
+ batch_size = hidden_states.shape[0]
692
+
693
+ hidden_states = self.norm1(hidden_states)
694
+ if self.timestep_conditioning:
695
+ assert (
696
+ timestep is not None
697
+ ), "should pass timestep with timestep_conditioning=True"
698
+ ada_values = self.scale_shift_table[
699
+ None, ..., None, None, None
700
+ ].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(
701
+ batch_size,
702
+ 4,
703
+ -1,
704
+ timestep.shape[-3],
705
+ timestep.shape[-2],
706
+ timestep.shape[-1],
707
+ )
708
+ shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
709
+
710
+ hidden_states = hidden_states * (1 + scale1) + shift1
711
+
712
+ hidden_states = self.non_linearity(hidden_states)
713
+
714
+ hidden_states = self.conv1(hidden_states, causal=causal)
715
+
716
+ if self.inject_noise:
717
+ hidden_states = self._feed_spatial_noise(
718
+ hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype)
719
+ )
720
+
721
+ hidden_states = self.norm2(hidden_states)
722
+
723
+ if self.timestep_conditioning:
724
+ hidden_states = hidden_states * (1 + scale2) + shift2
725
+
726
+ hidden_states = self.non_linearity(hidden_states)
727
+
728
+ hidden_states = self.dropout(hidden_states)
729
+
730
+ hidden_states = self.conv2(hidden_states, causal=causal)
731
+
732
+ if self.inject_noise:
733
+ hidden_states = self._feed_spatial_noise(
734
+ hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype)
735
+ )
736
+
737
+ input_tensor = self.norm3(input_tensor)
738
+
739
+ batch_size = input_tensor.shape[0]
740
+
741
+ input_tensor = self.conv_shortcut(input_tensor)
742
+
743
+ output_tensor = input_tensor + hidden_states
744
+
745
+ return output_tensor
746
+
747
+
748
+ def patchify(x, patch_size_hw, patch_size_t=1):
749
+ if patch_size_hw == 1 and patch_size_t == 1:
750
+ return x
751
+ if x.dim() == 4:
752
+ x = rearrange(
753
+ x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
754
+ )
755
+ elif x.dim() == 5:
756
+ x = rearrange(
757
+ x,
758
+ "b c (f p) (h q) (w r) -> b (c p r q) f h w",
759
+ p=patch_size_t,
760
+ q=patch_size_hw,
761
+ r=patch_size_hw,
762
+ )
763
+ else:
764
+ raise ValueError(f"Invalid input shape: {x.shape}")
765
+
766
+ return x
767
+
768
+
769
+ def unpatchify(x, patch_size_hw, patch_size_t=1):
770
+ if patch_size_hw == 1 and patch_size_t == 1:
771
+ return x
772
+
773
+ if x.dim() == 4:
774
+ x = rearrange(
775
+ x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
776
+ )
777
+ elif x.dim() == 5:
778
+ x = rearrange(
779
+ x,
780
+ "b (c p r q) f h w -> b c (f p) (h q) (w r)",
781
+ p=patch_size_t,
782
+ q=patch_size_hw,
783
+ r=patch_size_hw,
784
+ )
785
+
786
+ return x
787
+
788
+ class processor(nn.Module):
789
+ def __init__(self):
790
+ super().__init__()
791
+ self.register_buffer("std-of-means", torch.empty(128))
792
+ self.register_buffer("mean-of-means", torch.empty(128))
793
+ self.register_buffer("mean-of-stds", torch.empty(128))
794
+ self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
795
+ self.register_buffer("channel", torch.empty(128))
796
+
797
+ def un_normalize(self, x):
798
+ return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
799
+
800
+ def normalize(self, x):
801
+ return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
802
+
803
+ class VideoVAE(nn.Module):
804
+ def __init__(self, version=0):
805
+ super().__init__()
806
+
807
+ if version == 0:
808
+ config = {
809
+ "_class_name": "CausalVideoAutoencoder",
810
+ "dims": 3,
811
+ "in_channels": 3,
812
+ "out_channels": 3,
813
+ "latent_channels": 128,
814
+ "blocks": [
815
+ ["res_x", 4],
816
+ ["compress_all", 1],
817
+ ["res_x_y", 1],
818
+ ["res_x", 3],
819
+ ["compress_all", 1],
820
+ ["res_x_y", 1],
821
+ ["res_x", 3],
822
+ ["compress_all", 1],
823
+ ["res_x", 3],
824
+ ["res_x", 4],
825
+ ],
826
+ "scaling_factor": 1.0,
827
+ "norm_layer": "pixel_norm",
828
+ "patch_size": 4,
829
+ "latent_log_var": "uniform",
830
+ "use_quant_conv": False,
831
+ "causal_decoder": False,
832
+ }
833
+ else:
834
+ config = {
835
+ "_class_name": "CausalVideoAutoencoder",
836
+ "dims": 3,
837
+ "in_channels": 3,
838
+ "out_channels": 3,
839
+ "latent_channels": 128,
840
+ "decoder_blocks": [
841
+ ["res_x", {"num_layers": 5, "inject_noise": True}],
842
+ ["compress_all", {"residual": True, "multiplier": 2}],
843
+ ["res_x", {"num_layers": 6, "inject_noise": True}],
844
+ ["compress_all", {"residual": True, "multiplier": 2}],
845
+ ["res_x", {"num_layers": 7, "inject_noise": True}],
846
+ ["compress_all", {"residual": True, "multiplier": 2}],
847
+ ["res_x", {"num_layers": 8, "inject_noise": False}]
848
+ ],
849
+ "encoder_blocks": [
850
+ ["res_x", {"num_layers": 4}],
851
+ ["compress_all", {}],
852
+ ["res_x_y", 1],
853
+ ["res_x", {"num_layers": 3}],
854
+ ["compress_all", {}],
855
+ ["res_x_y", 1],
856
+ ["res_x", {"num_layers": 3}],
857
+ ["compress_all", {}],
858
+ ["res_x", {"num_layers": 3}],
859
+ ["res_x", {"num_layers": 4}]
860
+ ],
861
+ "scaling_factor": 1.0,
862
+ "norm_layer": "pixel_norm",
863
+ "patch_size": 4,
864
+ "latent_log_var": "uniform",
865
+ "use_quant_conv": False,
866
+ "causal_decoder": False,
867
+ "timestep_conditioning": True,
868
+ }
869
+
870
+ double_z = config.get("double_z", True)
871
+ latent_log_var = config.get(
872
+ "latent_log_var", "per_channel" if double_z else "none"
873
+ )
874
+
875
+ self.encoder = Encoder(
876
+ dims=config["dims"],
877
+ in_channels=config.get("in_channels", 3),
878
+ out_channels=config["latent_channels"],
879
+ blocks=config.get("encoder_blocks", config.get("blocks")),
880
+ patch_size=config.get("patch_size", 1),
881
+ latent_log_var=latent_log_var,
882
+ norm_layer=config.get("norm_layer", "group_norm"),
883
+ )
884
+
885
+ self.decoder = Decoder(
886
+ dims=config["dims"],
887
+ in_channels=config["latent_channels"],
888
+ out_channels=config.get("out_channels", 3),
889
+ blocks=config.get("decoder_blocks", config.get("blocks")),
890
+ patch_size=config.get("patch_size", 1),
891
+ norm_layer=config.get("norm_layer", "group_norm"),
892
+ causal=config.get("causal_decoder", False),
893
+ timestep_conditioning=config.get("timestep_conditioning", False),
894
+ )
895
+
896
+ self.timestep_conditioning = config.get("timestep_conditioning", False)
897
+ self.per_channel_statistics = processor()
898
+
899
+ def encode(self, x):
900
+ means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
901
+ return self.per_channel_statistics.normalize(means)
902
+
903
+ def decode(self, x, timestep=0.05, noise_scale=0.025):
904
+ if self.timestep_conditioning: #TODO: seed
905
+ x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
906
+ return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
907
+
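`patchify` folds each `patch_size x patch_size` spatial patch into the channel dimension before the encoder's first convolution, and `unpatchify` inverts it at the decoder output; the two are exact inverses since they are pure einops rearrangements. A round-trip sketch using the helpers defined above:

    import torch

    x = torch.randn(1, 3, 9, 64, 64)                   # [B, C, T, H, W]
    p = patchify(x, patch_size_hw=4, patch_size_t=1)   # -> [1, 48, 9, 16, 16]
    assert torch.equal(unpatchify(p, patch_size_hw=4, patch_size_t=1), x)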
vae/conv_nd_factory.py ADDED
@@ -0,0 +1,82 @@
1
+ from typing import Tuple, Union
2
+
3
+
4
+ from .dual_conv3d import DualConv3d
5
+ from .causal_conv3d import CausalConv3d
6
+ import comfy.ops
7
+ ops = comfy.ops.disable_weight_init
8
+
9
+ def make_conv_nd(
10
+ dims: Union[int, Tuple[int, int]],
11
+ in_channels: int,
12
+ out_channels: int,
13
+ kernel_size: int,
14
+ stride=1,
15
+ padding=0,
16
+ dilation=1,
17
+ groups=1,
18
+ bias=True,
19
+ causal=False,
20
+ ):
21
+ if dims == 2:
22
+ return ops.Conv2d(
23
+ in_channels=in_channels,
24
+ out_channels=out_channels,
25
+ kernel_size=kernel_size,
26
+ stride=stride,
27
+ padding=padding,
28
+ dilation=dilation,
29
+ groups=groups,
30
+ bias=bias,
31
+ )
32
+ elif dims == 3:
33
+ if causal:
34
+ return CausalConv3d(
35
+ in_channels=in_channels,
36
+ out_channels=out_channels,
37
+ kernel_size=kernel_size,
38
+ stride=stride,
39
+ padding=padding,
40
+ dilation=dilation,
41
+ groups=groups,
42
+ bias=bias,
43
+ )
44
+ return ops.Conv3d(
45
+ in_channels=in_channels,
46
+ out_channels=out_channels,
47
+ kernel_size=kernel_size,
48
+ stride=stride,
49
+ padding=padding,
50
+ dilation=dilation,
51
+ groups=groups,
52
+ bias=bias,
53
+ )
54
+ elif dims == (2, 1):
55
+ return DualConv3d(
56
+ in_channels=in_channels,
57
+ out_channels=out_channels,
58
+ kernel_size=kernel_size,
59
+ stride=stride,
60
+ padding=padding,
61
+ bias=bias,
62
+ )
63
+ else:
64
+ raise ValueError(f"unsupported dimensions: {dims}")
65
+
66
+
67
+ def make_linear_nd(
68
+ dims: int,
69
+ in_channels: int,
70
+ out_channels: int,
71
+ bias=True,
72
+ ):
73
+ if dims == 2:
74
+ return ops.Conv2d(
75
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
76
+ )
77
+ elif dims == 3 or dims == (2, 1):
78
+ return ops.Conv3d(
79
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
80
+ )
81
+ else:
82
+ raise ValueError(f"unsupported dimensions: {dims}")
vae/dual_conv3d.py ADDED
@@ -0,0 +1,195 @@
1
+ import math
2
+ from typing import Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+
9
+
10
+ class DualConv3d(nn.Module):
11
+ def __init__(
12
+ self,
13
+ in_channels,
14
+ out_channels,
15
+ kernel_size,
16
+ stride: Union[int, Tuple[int, int, int]] = 1,
17
+ padding: Union[int, Tuple[int, int, int]] = 0,
18
+ dilation: Union[int, Tuple[int, int, int]] = 1,
19
+ groups=1,
20
+ bias=True,
21
+ ):
22
+ super(DualConv3d, self).__init__()
23
+
24
+ self.in_channels = in_channels
25
+ self.out_channels = out_channels
26
+ # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
27
+ if isinstance(kernel_size, int):
28
+ kernel_size = (kernel_size, kernel_size, kernel_size)
29
+ if kernel_size == (1, 1, 1):
30
+ raise ValueError(
31
+ "kernel_size must be greater than 1. Use make_linear_nd instead."
32
+ )
33
+ if isinstance(stride, int):
34
+ stride = (stride, stride, stride)
35
+ if isinstance(padding, int):
36
+ padding = (padding, padding, padding)
37
+ if isinstance(dilation, int):
38
+ dilation = (dilation, dilation, dilation)
39
+
40
+ # Set parameters for convolutions
41
+ self.groups = groups
42
+ self.bias = bias
43
+
44
+ # Define the size of the channels after the first convolution
45
+ intermediate_channels = (
46
+ out_channels if in_channels < out_channels else in_channels
47
+ )
48
+
49
+ # Define parameters for the first convolution
50
+ self.weight1 = nn.Parameter(
51
+ torch.Tensor(
52
+ intermediate_channels,
53
+ in_channels // groups,
54
+ 1,
55
+ kernel_size[1],
56
+ kernel_size[2],
57
+ )
58
+ )
59
+ self.stride1 = (1, stride[1], stride[2])
60
+ self.padding1 = (0, padding[1], padding[2])
61
+ self.dilation1 = (1, dilation[1], dilation[2])
62
+ if bias:
63
+ self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
64
+ else:
65
+ self.register_parameter("bias1", None)
66
+
67
+ # Define parameters for the second convolution
68
+ self.weight2 = nn.Parameter(
69
+ torch.Tensor(
70
+ out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
71
+ )
72
+ )
73
+ self.stride2 = (stride[0], 1, 1)
74
+ self.padding2 = (padding[0], 0, 0)
75
+ self.dilation2 = (dilation[0], 1, 1)
76
+ if bias:
77
+ self.bias2 = nn.Parameter(torch.Tensor(out_channels))
78
+ else:
79
+ self.register_parameter("bias2", None)
80
+
81
+ # Initialize weights and biases
82
+ self.reset_parameters()
83
+
84
+ def reset_parameters(self):
85
+ nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
86
+ nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
87
+ if self.bias:
88
+ fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
89
+ bound1 = 1 / math.sqrt(fan_in1)
90
+ nn.init.uniform_(self.bias1, -bound1, bound1)
91
+ fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
92
+ bound2 = 1 / math.sqrt(fan_in2)
93
+ nn.init.uniform_(self.bias2, -bound2, bound2)
94
+
95
+ def forward(self, x, use_conv3d=False, skip_time_conv=False):
96
+ if use_conv3d:
97
+ return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
98
+ else:
99
+ return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
100
+
101
+ def forward_with_3d(self, x, skip_time_conv):
102
+ # First convolution
103
+ x = F.conv3d(
104
+ x,
105
+ self.weight1,
106
+ self.bias1,
107
+ self.stride1,
108
+ self.padding1,
109
+ self.dilation1,
110
+ self.groups,
111
+ )
112
+
113
+ if skip_time_conv:
114
+ return x
115
+
116
+ # Second convolution
117
+ x = F.conv3d(
118
+ x,
119
+ self.weight2,
120
+ self.bias2,
121
+ self.stride2,
122
+ self.padding2,
123
+ self.dilation2,
124
+ self.groups,
125
+ )
126
+
127
+ return x
128
+
129
+ def forward_with_2d(self, x, skip_time_conv):
130
+ b, c, d, h, w = x.shape
131
+
132
+ # First 2D convolution
133
+ x = rearrange(x, "b c d h w -> (b d) c h w")
134
+ # Squeeze the depth dimension out of weight1 since it's 1
135
+ weight1 = self.weight1.squeeze(2)
136
+ # Select stride, padding, and dilation for the 2D convolution
137
+ stride1 = (self.stride1[1], self.stride1[2])
138
+ padding1 = (self.padding1[1], self.padding1[2])
139
+ dilation1 = (self.dilation1[1], self.dilation1[2])
140
+ x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
141
+
142
+ _, _, h, w = x.shape
143
+
144
+ if skip_time_conv:
145
+ x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
146
+ return x
147
+
148
+ # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
149
+ x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
150
+
151
+ # Reshape weight2 to match the expected dimensions for conv1d
152
+ weight2 = self.weight2.squeeze(-1).squeeze(-1)
153
+ # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
154
+ stride2 = self.stride2[0]
155
+ padding2 = self.padding2[0]
156
+ dilation2 = self.dilation2[0]
157
+ x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
158
+ x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
159
+
160
+ return x
161
+
162
+ @property
163
+ def weight(self):
164
+ return self.weight2
165
+
166
+
167
+ def test_dual_conv3d_consistency():
168
+ # Initialize parameters
169
+ in_channels = 3
170
+ out_channels = 5
171
+ kernel_size = (3, 3, 3)
172
+ stride = (2, 2, 2)
173
+ padding = (1, 1, 1)
174
+
175
+ # Create an instance of the DualConv3d class
176
+ dual_conv3d = DualConv3d(
177
+ in_channels=in_channels,
178
+ out_channels=out_channels,
179
+ kernel_size=kernel_size,
180
+ stride=stride,
181
+ padding=padding,
182
+ bias=True,
183
+ )
184
+
185
+ # Example input tensor
186
+ test_input = torch.randn(1, 3, 10, 10, 10)
187
+
188
+ # Perform forward passes with both 3D and 2D settings
189
+ output_conv3d = dual_conv3d(test_input, use_conv3d=True)
190
+ output_2d = dual_conv3d(test_input, use_conv3d=False)
191
+
192
+ # Assert that the outputs from both methods are sufficiently close
193
+ assert torch.allclose(
194
+ output_conv3d, output_2d, atol=1e-6
195
+ ), "Outputs are not consistent between 3D and 2D convolutions."
vae/model.py ADDED
@@ -0,0 +1,711 @@
1
+ #original code from https://github.com/genmoai/models under apache 2.0 license
2
+ #adapted to ComfyUI
3
+
4
+ from typing import List, Optional, Tuple, Union
5
+ from functools import partial
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+
13
+ from comfy.ldm.modules.attention import optimized_attention
14
+
15
+ import comfy.ops
16
+ ops = comfy.ops.disable_weight_init
17
+
18
+ # import mochi_preview.dit.joint_model.context_parallel as cp
19
+ # from mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
20
+
21
+
22
+ def cast_tuple(t, length=1):
23
+ return t if isinstance(t, tuple) else ((t,) * length)
24
+
25
+
26
+ class GroupNormSpatial(ops.GroupNorm):
27
+ """
28
+ GroupNorm applied per-frame.
29
+ """
30
+
31
+ def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
32
+ B, C, T, H, W = x.shape
33
+ x = rearrange(x, "B C T H W -> (B T) C H W")
34
+ # Run group norm in chunks.
35
+ output = torch.empty_like(x)
36
+ for b in range(0, B * T, chunk_size):
37
+ output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
38
+ return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
39
+
40
+ class PConv3d(ops.Conv3d):
41
+ def __init__(
42
+ self,
43
+ in_channels,
44
+ out_channels,
45
+ kernel_size: Union[int, Tuple[int, int, int]],
46
+ stride: Union[int, Tuple[int, int, int]],
47
+ causal: bool = True,
48
+ context_parallel: bool = True,
49
+ **kwargs,
50
+ ):
51
+ self.causal = causal
52
+ self.context_parallel = context_parallel
53
+ kernel_size = cast_tuple(kernel_size, 3)
54
+ stride = cast_tuple(stride, 3)
55
+ height_pad = (kernel_size[1] - 1) // 2
56
+ width_pad = (kernel_size[2] - 1) // 2
57
+
58
+ super().__init__(
59
+ in_channels=in_channels,
60
+ out_channels=out_channels,
61
+ kernel_size=kernel_size,
62
+ stride=stride,
63
+ dilation=(1, 1, 1),
64
+ padding=(0, height_pad, width_pad),
65
+ **kwargs,
66
+ )
67
+
68
+ def forward(self, x: torch.Tensor):
69
+ # Compute padding amounts.
70
+ context_size = self.kernel_size[0] - 1
71
+ if self.causal:
72
+ pad_front = context_size
73
+ pad_back = 0
74
+ else:
75
+ pad_front = context_size // 2
76
+ pad_back = context_size - pad_front
77
+
78
+ # Apply padding.
79
+ assert self.padding_mode == "replicate" # DEBUG
80
+ mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
81
+ x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
82
+ return super().forward(x)
83
+
84
+
85
+ class Conv1x1(ops.Linear):
86
+ """1x1 Conv implemented with a linear layer."""
87
+
88
+ def __init__(self, in_features: int, out_features: int, *args, **kwargs):
89
+ super().__init__(in_features, out_features, *args, **kwargs)
90
+
91
+ def forward(self, x: torch.Tensor):
92
+ """Forward pass.
93
+
94
+ Args:
95
+ x: Input tensor. Shape: [B, C, *] or [B, *, C].
96
+
97
+ Returns:
98
+ x: Output tensor. Shape: [B, C', *] or [B, *, C'].
99
+ """
100
+ x = x.movedim(1, -1)
101
+ x = super().forward(x)
102
+ x = x.movedim(-1, 1)
103
+ return x
104
+
+
+ class DepthToSpaceTime(nn.Module):
+     def __init__(
+         self,
+         temporal_expansion: int,
+         spatial_expansion: int,
+     ):
+         super().__init__()
+         self.temporal_expansion = temporal_expansion
+         self.spatial_expansion = spatial_expansion
+
+     # When printed, this module should show the temporal and spatial expansion factors.
+     def extra_repr(self):
+         return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
+
+     def forward(self, x: torch.Tensor):
+         """Forward pass.
+
+         Args:
+             x: Input tensor. Shape: [B, C, T, H, W].
+
+         Returns:
+             x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
+         """
+         x = rearrange(
+             x,
+             "B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
+             st=self.temporal_expansion,
+             sh=self.spatial_expansion,
+             sw=self.spatial_expansion,
+         )
+
+         # cp_rank, _ = cp.get_cp_rank_size()
+         if self.temporal_expansion > 1:  # and cp_rank == 0:
+             # Drop the first self.temporal_expansion - 1 frames.
+             # This is because we always want the 3x3x3 conv filter to only apply
+             # to the first frame, and the first frame doesn't need to be repeated.
+             assert all(x.shape)
+             x = x[:, :, self.temporal_expansion - 1 :]
+             assert all(x.shape)
+
+         return x
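A rough sketch of the rearrangement above (factors assumed only for illustration): channels are folded out into time and space, and the first temporal_expansion - 1 frames are then dropped so the first latent frame maps to a single output frame.

    d2st = DepthToSpaceTime(temporal_expansion=2, spatial_expansion=2)
    x = torch.randn(1, 8 * 2 * 2 * 2, 5, 4, 4)   # [B, C*st*sh*sw, t, h, w]
    y = d2st(x)                                  # [1, 8, 9, 8, 8]: T = t*st - (st - 1)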
+
+
+ def norm_fn(
+     in_channels: int,
+     affine: bool = True,
+ ):
+     return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
+
+
+ class ResBlock(nn.Module):
+     """Residual block that preserves the spatial dimensions."""
+
+     def __init__(
+         self,
+         channels: int,
+         *,
+         affine: bool = True,
+         attn_block: Optional[nn.Module] = None,
+         causal: bool = True,
+         prune_bottleneck: bool = False,
+         padding_mode: str,
+         bias: bool = True,
+     ):
+         super().__init__()
+         self.channels = channels
+
+         assert causal
+         self.stack = nn.Sequential(
+             norm_fn(channels, affine=affine),
+             nn.SiLU(inplace=True),
+             PConv3d(
+                 in_channels=channels,
+                 out_channels=channels // 2 if prune_bottleneck else channels,
+                 kernel_size=(3, 3, 3),
+                 stride=(1, 1, 1),
+                 padding_mode=padding_mode,
+                 bias=bias,
+                 causal=causal,
+             ),
+             norm_fn(channels, affine=affine),
+             nn.SiLU(inplace=True),
+             PConv3d(
+                 in_channels=channels // 2 if prune_bottleneck else channels,
+                 out_channels=channels,
+                 kernel_size=(3, 3, 3),
+                 stride=(1, 1, 1),
+                 padding_mode=padding_mode,
+                 bias=bias,
+                 causal=causal,
+             ),
+         )
+
+         self.attn_block = attn_block if attn_block else nn.Identity()
+
+     def forward(self, x: torch.Tensor):
+         """Forward pass.
+
+         Args:
+             x: Input tensor. Shape: [B, C, T, H, W].
+         """
+         residual = x
+         x = self.stack(x)
+         x = x + residual
+         del residual
+
+         return self.attn_block(x)
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         head_dim: int = 32,
+         qkv_bias: bool = False,
+         out_bias: bool = True,
+         qk_norm: bool = True,
+     ) -> None:
+         super().__init__()
+         self.head_dim = head_dim
+         self.num_heads = dim // head_dim
+         self.qk_norm = qk_norm
+
+         self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
+         self.out = nn.Linear(dim, dim, bias=out_bias)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+     ) -> torch.Tensor:
+         """Compute temporal self-attention.
+
+         Args:
+             x: Input tensor. Shape: [B, C, T, H, W].
+
+         Returns:
+             x: Output tensor. Shape: [B, C, T, H, W].
+         """
+         B, _, T, H, W = x.shape
+
+         if T == 1:
+             # No attention for single frame.
+             x = x.movedim(1, -1)  # [B, C, T, H, W] -> [B, T, H, W, C]
+             qkv = self.qkv(x)
+             _, _, x = qkv.chunk(3, dim=-1)  # Throw away queries and keys.
+             x = self.out(x)
+             return x.movedim(-1, 1)  # [B, T, H, W, C] -> [B, C, T, H, W]
+
+         # 1D temporal attention.
+         x = rearrange(x, "B C t h w -> (B h w) t C")
+         qkv = self.qkv(x)
+
+         # Input: qkv with shape [B, t, 3 * num_heads * head_dim]
+         # Output: q, k, v each with shape [B, num_heads, t, head_dim]
+         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)
+
+         if self.qk_norm:
+             q = F.normalize(q, p=2, dim=-1)
+             k = F.normalize(k, p=2, dim=-1)
+
+         x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
+
+         assert x.size(0) == q.size(0)
+
+         x = self.out(x)
+         x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
+         return x
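The attention above is purely temporal: each spatial location attends over the frame axis independently (the batch becomes B*H*W), so the cost scales with T^2 rather than with the full T*H*W sequence length. A rough shape sketch (sizes assumed for illustration):

    attn = Attention(dim=128, head_dim=32)   # 4 heads
    x = torch.randn(2, 128, 8, 16, 16)       # [B, C, T, H, W]
    y = attn(x)                              # [2, 128, 8, 16, 16]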
+
+
+ class AttentionBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         **attn_kwargs,
+     ) -> None:
+         super().__init__()
+         self.norm = norm_fn(dim)
+         self.attn = Attention(dim, **attn_kwargs)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return x + self.attn(self.norm(x))
+
+
+ class CausalUpsampleBlock(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         num_res_blocks: int,
+         *,
+         temporal_expansion: int = 2,
+         spatial_expansion: int = 2,
+         **block_kwargs,
+     ):
+         super().__init__()
+
+         blocks = []
+         for _ in range(num_res_blocks):
+             blocks.append(block_fn(in_channels, **block_kwargs))
+         self.blocks = nn.Sequential(*blocks)
+
+         self.temporal_expansion = temporal_expansion
+         self.spatial_expansion = spatial_expansion
+
+         # Change channels in the final convolution layer.
+         self.proj = Conv1x1(
+             in_channels,
+             out_channels * temporal_expansion * (spatial_expansion**2),
+         )
+
+         self.d2st = DepthToSpaceTime(
+             temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
+         )
+
+     def forward(self, x):
+         x = self.blocks(x)
+         x = self.proj(x)
+         x = self.d2st(x)
+         return x
+
+
+ def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
+     attn_block = AttentionBlock(channels) if has_attention else None
+     return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
+
+
+ class DownsampleBlock(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         num_res_blocks,
+         *,
+         temporal_reduction=2,
+         spatial_reduction=2,
+         **block_kwargs,
+     ):
+         """
+         Downsample block for the VAE encoder.
+
+         Args:
+             in_channels: Number of input channels.
+             out_channels: Number of output channels.
+             num_res_blocks: Number of residual blocks.
+             temporal_reduction: Temporal reduction factor.
+             spatial_reduction: Spatial reduction factor.
+         """
+         super().__init__()
+         layers = []
+
+         # Change the channel count in the strided convolution.
+         # This lets the ResBlock have uniform channel count,
+         # as in ConvNeXt.
+         assert in_channels != out_channels
+         layers.append(
+             PConv3d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
+                 stride=(temporal_reduction, spatial_reduction, spatial_reduction),
+                 # First layer in each block always uses replicate padding
+                 padding_mode="replicate",
+                 bias=block_kwargs["bias"],
+             )
+         )
+
+         for _ in range(num_res_blocks):
+             layers.append(block_fn(out_channels, **block_kwargs))
+
+         self.layers = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.layers(x)
+
+
+ def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
+     num_freqs = (stop - start) // step
+     assert inputs.ndim == 5
+     C = inputs.size(1)
+
+     # Create Base 2 Fourier features.
+     freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
+     assert num_freqs == len(freqs)
+     w = torch.pow(2.0, freqs) * (2 * torch.pi)  # [num_freqs]
+     C = inputs.shape[1]
+     w = w.repeat(C)[None, :, None, None, None]  # [1, C * num_freqs, 1, 1, 1]
+
+     # Interleaved repeat of input channels to match w.
+     h = inputs.repeat_interleave(num_freqs, dim=1)  # [B, C * num_freqs, T, H, W]
+     # Scale channels by frequency.
+     h = w * h
+
+     return torch.cat(
+         [
+             inputs,
+             torch.sin(h),
+             torch.cos(h),
+         ],
+         dim=1,
+     )
+
+
+ class FourierFeatures(nn.Module):
+     def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
+         super().__init__()
+         self.start = start
+         self.stop = stop
+         self.step = step
+
+     def forward(self, inputs):
+         """Add Fourier features to inputs.
+
+         Args:
+             inputs: Input tensor. Shape: [B, C, T, H, W]
+
+         Returns:
+             h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
+         """
+         return add_fourier_features(inputs, self.start, self.stop, self.step)
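With the defaults (start=6, stop=8, step=1, i.e. num_freqs = 2), every input channel is kept and also contributes a sin/cos pair per frequency, which is where the encoder's in_channels=15 for RGB video comes from. A rough sketch (shapes assumed for illustration):

    ff = FourierFeatures()
    x = torch.randn(1, 3, 4, 8, 8)   # RGB video, [B, C, T, H, W]
    y = ff(x)                        # [1, 15, 4, 8, 8]: 3 + 2 * 2 * 3 channels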
+
+
+ class Decoder(nn.Module):
+     def __init__(
+         self,
+         *,
+         out_channels: int = 3,
+         latent_dim: int,
+         base_channels: int,
+         channel_multipliers: List[int],
+         num_res_blocks: List[int],
+         temporal_expansions: Optional[List[int]] = None,
+         spatial_expansions: Optional[List[int]] = None,
+         has_attention: List[bool],
+         output_norm: bool = True,
+         nonlinearity: str = "silu",
+         output_nonlinearity: str = "silu",
+         causal: bool = True,
+         **block_kwargs,
+     ):
+         super().__init__()
+         self.input_channels = latent_dim
+         self.base_channels = base_channels
+         self.channel_multipliers = channel_multipliers
+         self.num_res_blocks = num_res_blocks
+         self.output_nonlinearity = output_nonlinearity
+         assert nonlinearity == "silu"
+         assert causal
+
+         ch = [mult * base_channels for mult in channel_multipliers]
+         self.num_up_blocks = len(ch) - 1
+         assert len(num_res_blocks) == self.num_up_blocks + 2
+
+         blocks = []
+
+         first_block = [
+             ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
+         ]  # Input layer.
+         # First set of blocks preserve channel count.
+         for _ in range(num_res_blocks[-1]):
+             first_block.append(
+                 block_fn(
+                     ch[-1],
+                     has_attention=has_attention[-1],
+                     causal=causal,
+                     **block_kwargs,
+                 )
+             )
+         blocks.append(nn.Sequential(*first_block))
+
+         assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
+         assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
+
+         upsample_block_fn = CausalUpsampleBlock
+
+         for i in range(self.num_up_blocks):
+             block = upsample_block_fn(
+                 ch[-i - 1],
+                 ch[-i - 2],
+                 num_res_blocks=num_res_blocks[-i - 2],
+                 has_attention=has_attention[-i - 2],
+                 temporal_expansion=temporal_expansions[-i - 1],
+                 spatial_expansion=spatial_expansions[-i - 1],
+                 causal=causal,
+                 **block_kwargs,
+             )
+             blocks.append(block)
+
+         assert not output_norm
+
+         # Last block. Preserve channel count.
+         last_block = []
+         for _ in range(num_res_blocks[0]):
+             last_block.append(
+                 block_fn(
+                     ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
+                 )
+             )
+         blocks.append(nn.Sequential(*last_block))
+
+         self.blocks = nn.ModuleList(blocks)
+         self.output_proj = Conv1x1(ch[0], out_channels)
+
+     def forward(self, x):
+         """Forward pass.
+
+         Args:
+             x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
+
+         Returns:
+             x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
+                With the expansions used in this file: T - 1 = (t - 1) * 6,
+                H = h * 8, W = w * 8.
+         """
+         for block in self.blocks:
+             x = block(x)
+
+         if self.output_nonlinearity == "silu":
+             x = F.silu(x, inplace=not self.training)
+         else:
+             assert (
+                 not self.output_nonlinearity
+             )  # StyleGAN3 omits the to-RGB nonlinearity.
+
+         return self.output_proj(x).contiguous()
+
+ class LatentDistribution:
+     def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
+         """Initialize latent distribution.
+
+         Args:
+             mean: Mean of the distribution. Shape: [B, C, T, H, W].
+             logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
+         """
+         assert mean.shape == logvar.shape
+         self.mean = mean
+         self.logvar = logvar
+
+     def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
+         if temperature == 0.0:
+             return self.mean
+
+         if noise is None:
+             noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
+         else:
+             assert noise.device == self.mean.device
+             noise = noise.to(self.mean.dtype)
+
+         if temperature != 1.0:
+             raise NotImplementedError(f"Temperature {temperature} is not supported.")
+
+         # Just Gaussian sample with no scaling of variance.
+         return noise * torch.exp(self.logvar * 0.5) + self.mean
+
+     def mode(self):
+         return self.mean
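LatentDistribution is the usual diagonal-Gaussian reparameterization: sample() returns mean + exp(logvar / 2) * noise, while mode() just returns the mean (which is what VideoVAE.encode below uses). A minimal sketch (shapes assumed for illustration):

    mean = torch.zeros(1, 12, 3, 4, 4)
    dist = LatentDistribution(mean, torch.zeros_like(mean))
    z = dist.sample(generator=torch.Generator().manual_seed(0))   # unit-variance sample here
    z0 = dist.mode()                                              # deterministic: the mean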
+
+ class Encoder(nn.Module):
+     def __init__(
+         self,
+         *,
+         in_channels: int,
+         base_channels: int,
+         channel_multipliers: List[int],
+         num_res_blocks: List[int],
+         latent_dim: int,
+         temporal_reductions: List[int],
+         spatial_reductions: List[int],
+         prune_bottlenecks: List[bool],
+         has_attentions: List[bool],
+         affine: bool = True,
+         bias: bool = True,
+         input_is_conv_1x1: bool = False,
+         padding_mode: str,
+     ):
+         super().__init__()
+         self.temporal_reductions = temporal_reductions
+         self.spatial_reductions = spatial_reductions
+         self.base_channels = base_channels
+         self.channel_multipliers = channel_multipliers
+         self.num_res_blocks = num_res_blocks
+         self.latent_dim = latent_dim
+
+         self.fourier_features = FourierFeatures()
+         ch = [mult * base_channels for mult in channel_multipliers]
+         num_down_blocks = len(ch) - 1
+         assert len(num_res_blocks) == num_down_blocks + 2
+
+         layers = (
+             [ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
+             if not input_is_conv_1x1
+             else [Conv1x1(in_channels, ch[0])]
+         )
+
+         assert len(prune_bottlenecks) == num_down_blocks + 2
+         assert len(has_attentions) == num_down_blocks + 2
+         block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
+
+         for _ in range(num_res_blocks[0]):
+             layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
+         prune_bottlenecks = prune_bottlenecks[1:]
+         has_attentions = has_attentions[1:]
+
+         assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
+         for i in range(num_down_blocks):
+             layer = DownsampleBlock(
+                 ch[i],
+                 ch[i + 1],
+                 num_res_blocks=num_res_blocks[i + 1],
+                 temporal_reduction=temporal_reductions[i],
+                 spatial_reduction=spatial_reductions[i],
+                 prune_bottleneck=prune_bottlenecks[i],
+                 has_attention=has_attentions[i],
+                 affine=affine,
+                 bias=bias,
+                 padding_mode=padding_mode,
+             )
+
+             layers.append(layer)
+
+         # Additional blocks.
+         for _ in range(num_res_blocks[-1]):
+             layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
+
+         self.layers = nn.Sequential(*layers)
+
+         # Output layers.
+         self.output_norm = norm_fn(ch[-1])
+         self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
+
+     @property
+     def temporal_downsample(self):
+         return math.prod(self.temporal_reductions)
+
+     @property
+     def spatial_downsample(self):
+         return math.prod(self.spatial_reductions)
+
+     def forward(self, x) -> LatentDistribution:
+         """Forward pass.
+
+         Args:
+             x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
+
+         Returns:
+             means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
+                    h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
+             logvar: Shape: [B, latent_dim, t, h, w].
+         """
+         assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
+         x = self.fourier_features(x)
+
+         x = self.layers(x)
+
+         x = self.output_norm(x)
+         x = F.silu(x, inplace=True)
+         x = self.output_proj(x)
+
+         means, logvar = torch.chunk(x, 2, dim=1)
+
+         assert means.ndim == 5
+         assert logvar.shape == means.shape
+         assert means.size(1) == self.latent_dim
+
+         return LatentDistribution(means, logvar)
+
+
+ class VideoVAE(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.encoder = Encoder(
+             in_channels=15,
+             base_channels=64,
+             channel_multipliers=[1, 2, 4, 6],
+             num_res_blocks=[3, 3, 4, 6, 3],
+             latent_dim=12,
+             temporal_reductions=[1, 2, 3],
+             spatial_reductions=[2, 2, 2],
+             prune_bottlenecks=[False, False, False, False, False],
+             has_attentions=[False, True, True, True, True],
+             affine=True,
+             bias=True,
+             input_is_conv_1x1=True,
+             padding_mode="replicate",
+         )
+         self.decoder = Decoder(
+             out_channels=3,
+             base_channels=128,
+             channel_multipliers=[1, 2, 4, 6],
+             temporal_expansions=[1, 2, 3],
+             spatial_expansions=[2, 2, 2],
+             num_res_blocks=[3, 3, 4, 6, 3],
+             latent_dim=12,
+             has_attention=[False, False, False, False, False],
+             padding_mode="replicate",
+             output_norm=False,
+             nonlinearity="silu",
+             output_nonlinearity="silu",
+             causal=True,
+         )
+
+     def encode(self, x):
+         return self.encoder(x).mode()
+
+     def decode(self, x):
+         return self.decoder(x)
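For orientation, a minimal end-to-end sketch of how the pieces above fit together (sizes assumed for illustration; with these ops the weights are left uninitialized, so this only exercises shapes, and in ComfyUI the model is normally driven through its VAE loading path rather than called like this):

    vae = VideoVAE()
    video = torch.randn(1, 3, 7, 64, 64)   # [B, C, T, H, W] with (T - 1) divisible by 6
    latent = vae.encode(video)             # [1, 12, 2, 8, 8]; Fourier features are added inside the encoder
    recon = vae.decode(latent)             # [1, 3, 7, 64, 64]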
vae/pixel_norm.py ADDED
@@ -0,0 +1,12 @@
+ import torch
+ from torch import nn
+
+
+ class PixelNorm(nn.Module):
+     def __init__(self, dim=1, eps=1e-8):
+         super(PixelNorm, self).__init__()
+         self.dim = dim
+         self.eps = eps
+
+     def forward(self, x):
+         return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)