Image-to-3D
ashawkey committed on
Commit
f0e1e27
1 Parent(s): 8747d5d

further clean!

Browse files
Files changed (3) hide show
  1. main.py +2 -2
  2. mvdream/attention.py +13 -85
  3. mvdream/models.py +12 -176
main.py CHANGED
@@ -5,8 +5,8 @@ import argparse
5
  from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline
6
 
7
  pipe = MVDreamStableDiffusionPipeline.from_pretrained(
8
- # "./weights", # local weights
9
- "ashawkey/mvdream-sd2.1-diffusers",
10
  torch_dtype=torch.float16
11
  )
12
  pipe = pipe.to("cuda")
 
5
  from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline
6
 
7
  pipe = MVDreamStableDiffusionPipeline.from_pretrained(
8
+ "./weights", # local weights
9
+ # "ashawkey/mvdream-sd2.1-diffusers",
10
  torch_dtype=torch.float16
11
  )
12
  pipe = pipe.to("cuda")
mvdream/attention.py CHANGED
@@ -2,14 +2,14 @@
2
 
3
  import math
4
  import torch
 
5
  import torch.nn.functional as F
 
6
 
7
  from inspect import isfunction
8
- from torch import nn, einsum
9
- from torch.amp.autocast_mode import autocast
10
  from einops import rearrange, repeat
11
  from typing import Optional, Any
12
- from .util import checkpoint
13
 
14
  try:
15
  import xformers # type: ignore
@@ -25,28 +25,12 @@ import os
25
  _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
26
 
27
 
28
- def uniq(arr):
29
- return {el: True for el in arr}.keys()
30
-
31
-
32
  def default(val, d):
33
  if val is not None:
34
  return val
35
  return d() if isfunction(d) else d
36
 
37
 
38
- def max_neg_value(t):
39
- return -torch.finfo(t.dtype).max
40
-
41
-
42
- def init_(tensor):
43
- dim = tensor.shape[-1]
44
- std = 1 / math.sqrt(dim)
45
- tensor.uniform_(-std, std)
46
- return tensor
47
-
48
-
49
- # feedforward
50
  class GEGLU(nn.Module):
51
  def __init__(self, dim_in, dim_out):
52
  super().__init__()
@@ -76,66 +60,6 @@ class FeedForward(nn.Module):
76
  return self.net(x)
77
 
78
 
79
- def zero_module(module):
80
- """
81
- Zero out the parameters of a module and return it.
82
- """
83
- for p in module.parameters():
84
- p.detach().zero_()
85
- return module
86
-
87
-
88
- def Normalize(in_channels):
89
- return torch.nn.GroupNorm(
90
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
91
- )
92
-
93
-
94
- class SpatialSelfAttention(nn.Module):
95
- def __init__(self, in_channels):
96
- super().__init__()
97
- self.in_channels = in_channels
98
-
99
- self.norm = Normalize(in_channels)
100
- self.q = torch.nn.Conv2d(
101
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
102
- )
103
- self.k = torch.nn.Conv2d(
104
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
105
- )
106
- self.v = torch.nn.Conv2d(
107
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
108
- )
109
- self.proj_out = torch.nn.Conv2d(
110
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
111
- )
112
-
113
- def forward(self, x):
114
- h_ = x
115
- h_ = self.norm(h_)
116
- q = self.q(h_)
117
- k = self.k(h_)
118
- v = self.v(h_)
119
-
120
- # compute attention
121
- b, c, h, w = q.shape
122
- q = rearrange(q, "b c h w -> b (h w) c")
123
- k = rearrange(k, "b c h w -> b c (h w)")
124
- w_ = torch.einsum("bij,bjk->bik", q, k)
125
-
126
- w_ = w_ * (int(c) ** (-0.5))
127
- w_ = torch.nn.functional.softmax(w_, dim=2)
128
-
129
- # attend to values
130
- v = rearrange(v, "b c h w -> b c (h w)")
131
- w_ = rearrange(w_, "b i j -> b j i")
132
- h_ = torch.einsum("bij,bjk->bik", v, w_)
133
- h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
134
- h_ = self.proj_out(h_)
135
-
136
- return x + h_
137
-
138
-
139
  class CrossAttention(nn.Module):
140
  def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
141
  super().__init__()
@@ -167,9 +91,9 @@ class CrossAttention(nn.Module):
167
  if _ATTN_PRECISION == "fp32":
168
  with autocast(enabled=False, device_type="cuda"):
169
  q, k = q.float(), k.float()
170
- sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
171
  else:
172
- sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
173
 
174
  del q, k
175
 
@@ -182,7 +106,7 @@ class CrossAttention(nn.Module):
182
  # attention, what we cannot get enough of
183
  sim = sim.softmax(dim=-1)
184
 
185
- out = einsum("b i j, b j d -> b i d", sim, v)
186
  out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
187
  return self.to_out(out)
188
 
@@ -326,7 +250,9 @@ class SpatialTransformer(nn.Module):
326
  context_dim = [context_dim]
327
  self.in_channels = in_channels
328
  inner_dim = n_heads * d_head
329
- self.norm = Normalize(in_channels)
 
 
330
  if not use_linear:
331
  self.proj_in = nn.Conv2d(
332
  in_channels, inner_dim, kernel_size=1, stride=1, padding=0
@@ -410,7 +336,7 @@ class SpatialTransformer3D(nn.Module):
410
  dropout=0.0,
411
  context_dim=None,
412
  disable_self_attn=False,
413
- use_linear=False,
414
  use_checkpoint=True,
415
  ):
416
  super().__init__()
@@ -419,7 +345,9 @@ class SpatialTransformer3D(nn.Module):
419
  context_dim = [context_dim]
420
  self.in_channels = in_channels
421
  inner_dim = n_heads * d_head
422
- self.norm = Normalize(in_channels)
 
 
423
  if not use_linear:
424
  self.proj_in = nn.Conv2d(
425
  in_channels, inner_dim, kernel_size=1, stride=1, padding=0
 
2
 
3
  import math
4
  import torch
5
+ import torch.nn as nn
6
  import torch.nn.functional as F
7
+ from torch.amp.autocast_mode import autocast
8
 
9
  from inspect import isfunction
 
 
10
  from einops import rearrange, repeat
11
  from typing import Optional, Any
12
+ from .util import checkpoint, zero_module
13
 
14
  try:
15
  import xformers # type: ignore
 
25
  _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
26
 
27
 
 
 
 
 
28
  def default(val, d):
29
  if val is not None:
30
  return val
31
  return d() if isfunction(d) else d
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  class GEGLU(nn.Module):
35
  def __init__(self, dim_in, dim_out):
36
  super().__init__()
 
60
  return self.net(x)
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  class CrossAttention(nn.Module):
64
  def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
65
  super().__init__()
 
91
  if _ATTN_PRECISION == "fp32":
92
  with autocast(enabled=False, device_type="cuda"):
93
  q, k = q.float(), k.float()
94
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
95
  else:
96
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
97
 
98
  del q, k
99
 
 
106
  # attention, what we cannot get enough of
107
  sim = sim.softmax(dim=-1)
108
 
109
+ out = torch.einsum("b i j, b j d -> b i d", sim, v)
110
  out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
111
  return self.to_out(out)
112
 
 
250
  context_dim = [context_dim]
251
  self.in_channels = in_channels
252
  inner_dim = n_heads * d_head
253
+ self.norm = nn.GroupNorm(
254
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
255
+ )
256
  if not use_linear:
257
  self.proj_in = nn.Conv2d(
258
  in_channels, inner_dim, kernel_size=1, stride=1, padding=0
 
336
  dropout=0.0,
337
  context_dim=None,
338
  disable_self_attn=False,
339
+ use_linear=True,
340
  use_checkpoint=True,
341
  ):
342
  super().__init__()
 
345
  context_dim = [context_dim]
346
  self.in_channels = in_channels
347
  inner_dim = n_heads * d_head
348
+ self.norm = nn.GroupNorm(
349
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
350
+ )
351
  if not use_linear:
352
  self.proj_in = nn.Conv2d(
353
  in_channels, inner_dim, kernel_size=1, stride=1, padding=0
mvdream/models.py CHANGED
@@ -1,8 +1,7 @@
1
  # obtained and modified from https://github.com/bytedance/MVDream
2
 
3
  import math
4
- import numpy as np
5
- import torch as th
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
  from diffusers.configuration_utils import ConfigMixin
@@ -223,7 +222,7 @@ class ResBlock(TimestepBlock):
223
  emb_out = emb_out[..., None]
224
  if self.use_scale_shift_norm:
225
  out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
226
- scale, shift = th.chunk(emb_out, 2, dim=1)
227
  h = out_norm(h) * (1 + scale) + shift
228
  h = out_rest(h)
229
  else:
@@ -232,112 +231,6 @@ class ResBlock(TimestepBlock):
232
  return self.skip_connection(x) + h
233
 
234
 
235
- class AttentionBlock(nn.Module):
236
- """
237
- An attention block that allows spatial positions to attend to each other.
238
- Originally ported from here, but adapted to the N-d case.
239
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
240
- """
241
-
242
- def __init__(
243
- self,
244
- channels,
245
- num_heads=1,
246
- num_head_channels=-1,
247
- use_checkpoint=False,
248
- use_new_attention_order=False,
249
- ):
250
- super().__init__()
251
- self.channels = channels
252
- if num_head_channels == -1:
253
- self.num_heads = num_heads
254
- else:
255
- assert (
256
- channels % num_head_channels == 0
257
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
258
- self.num_heads = channels // num_head_channels
259
- self.use_checkpoint = use_checkpoint
260
- self.norm = nn.GroupNorm(32, channels)
261
- self.qkv = conv_nd(1, channels, channels * 3, 1)
262
- if use_new_attention_order:
263
- # split qkv before split heads
264
- self.attention = QKVAttention(self.num_heads)
265
- else:
266
- # split heads before split qkv
267
- self.attention = QKVAttentionLegacy(self.num_heads)
268
-
269
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
270
-
271
- def forward(self, x):
272
- return checkpoint(self._forward, (x,), self.parameters(), True)
273
-
274
- def _forward(self, x):
275
- b, c, *spatial = x.shape
276
- x = x.reshape(b, c, -1)
277
- qkv = self.qkv(self.norm(x))
278
- h = self.attention(qkv)
279
- h = self.proj_out(h)
280
- return (x + h).reshape(b, c, *spatial)
281
-
282
-
283
- class QKVAttentionLegacy(nn.Module):
284
- """
285
- A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
286
- """
287
-
288
- def __init__(self, n_heads):
289
- super().__init__()
290
- self.n_heads = n_heads
291
-
292
- def forward(self, qkv):
293
- """
294
- Apply QKV attention.
295
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
296
- :return: an [N x (H * C) x T] tensor after attention.
297
- """
298
- bs, width, length = qkv.shape
299
- assert width % (3 * self.n_heads) == 0
300
- ch = width // (3 * self.n_heads)
301
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
302
- scale = 1 / math.sqrt(math.sqrt(ch))
303
- weight = th.einsum(
304
- "bct,bcs->bts", q * scale, k * scale
305
- ) # More stable with f16 than dividing afterwards
306
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
307
- a = th.einsum("bts,bcs->bct", weight, v)
308
- return a.reshape(bs, -1, length)
309
-
310
-
311
- class QKVAttention(nn.Module):
312
- """
313
- A module which performs QKV attention and splits in a different order.
314
- """
315
-
316
- def __init__(self, n_heads):
317
- super().__init__()
318
- self.n_heads = n_heads
319
-
320
- def forward(self, qkv):
321
- """
322
- Apply QKV attention.
323
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
324
- :return: an [N x (H * C) x T] tensor after attention.
325
- """
326
- bs, width, length = qkv.shape
327
- assert width % (3 * self.n_heads) == 0
328
- ch = width // (3 * self.n_heads)
329
- q, k, v = qkv.chunk(3, dim=1)
330
- scale = 1 / math.sqrt(math.sqrt(ch))
331
- weight = th.einsum(
332
- "bct,bcs->bts",
333
- (q * scale).view(bs * self.n_heads, ch, length),
334
- (k * scale).view(bs * self.n_heads, ch, length),
335
- ) # More stable with f16 than dividing afterwards
336
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
337
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
338
- return a.reshape(bs, -1, length)
339
-
340
-
341
  class MultiViewUNetModel(ModelMixin, ConfigMixin):
342
  """
343
  The full multi-view UNet model with attention, timestep embedding and camera embedding.
@@ -388,34 +281,18 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
388
  num_heads_upsample=-1,
389
  use_scale_shift_norm=False,
390
  resblock_updown=False,
391
- use_new_attention_order=False,
392
- use_spatial_transformer=False, # custom transformer support
393
  transformer_depth=1, # custom transformer support
394
  context_dim=None, # custom transformer support
395
  n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
396
- legacy=True,
397
  disable_self_attentions=None,
398
  num_attention_blocks=None,
399
  disable_middle_self_attn=False,
400
- use_linear_in_transformer=False,
401
  adm_in_channels=None,
402
  camera_dim=None,
403
  ):
404
  super().__init__()
405
- if use_spatial_transformer:
406
- assert (
407
- context_dim is not None
408
- ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
409
-
410
- if context_dim is not None:
411
- assert (
412
- use_spatial_transformer
413
- ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
414
- from omegaconf.listconfig import ListConfig
415
-
416
- if type(context_dim) == ListConfig:
417
- context_dim = list(context_dim)
418
-
419
  if num_heads_upsample == -1:
420
  num_heads_upsample = num_heads
421
 
@@ -535,13 +412,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
535
  else:
536
  num_heads = ch // num_head_channels
537
  dim_head = num_head_channels
538
- if legacy:
539
- # num_heads = 1
540
- dim_head = (
541
- ch // num_heads
542
- if use_spatial_transformer
543
- else num_head_channels
544
- )
545
  if disable_self_attentions is not None:
546
  disabled_sa = disable_self_attentions[level]
547
  else:
@@ -549,22 +420,13 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
549
 
550
  if num_attention_blocks is None or nr < num_attention_blocks[level]:
551
  layers.append(
552
- AttentionBlock(
553
- ch,
554
- use_checkpoint=use_checkpoint,
555
- num_heads=num_heads,
556
- num_head_channels=dim_head,
557
- use_new_attention_order=use_new_attention_order,
558
- )
559
- if not use_spatial_transformer
560
- else SpatialTransformer3D(
561
  ch,
562
  num_heads,
563
  dim_head,
564
  depth=transformer_depth,
565
  context_dim=context_dim,
566
  disable_self_attn=disabled_sa,
567
- use_linear=use_linear_in_transformer,
568
  use_checkpoint=use_checkpoint,
569
  )
570
  )
@@ -601,9 +463,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
601
  else:
602
  num_heads = ch // num_head_channels
603
  dim_head = num_head_channels
604
- if legacy:
605
- # num_heads = 1
606
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
607
  self.middle_block = TimestepEmbedSequential(
608
  ResBlock(
609
  ch,
@@ -613,24 +473,15 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
613
  use_checkpoint=use_checkpoint,
614
  use_scale_shift_norm=use_scale_shift_norm,
615
  ),
616
- AttentionBlock(
617
- ch,
618
- use_checkpoint=use_checkpoint,
619
- num_heads=num_heads,
620
- num_head_channels=dim_head,
621
- use_new_attention_order=use_new_attention_order,
622
- )
623
- if not use_spatial_transformer
624
- else SpatialTransformer3D(
625
  ch,
626
  num_heads,
627
  dim_head,
628
  depth=transformer_depth,
629
  context_dim=context_dim,
630
  disable_self_attn=disable_middle_self_attn,
631
- use_linear=use_linear_in_transformer,
632
  use_checkpoint=use_checkpoint,
633
- ), # always uses a self-attn
634
  ResBlock(
635
  ch,
636
  time_embed_dim,
@@ -664,13 +515,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
664
  else:
665
  num_heads = ch // num_head_channels
666
  dim_head = num_head_channels
667
- if legacy:
668
- # num_heads = 1
669
- dim_head = (
670
- ch // num_heads
671
- if use_spatial_transformer
672
- else num_head_channels
673
- )
674
  if disable_self_attentions is not None:
675
  disabled_sa = disable_self_attentions[level]
676
  else:
@@ -678,22 +523,13 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
678
 
679
  if num_attention_blocks is None or i < num_attention_blocks[level]:
680
  layers.append(
681
- AttentionBlock(
682
- ch,
683
- use_checkpoint=use_checkpoint,
684
- num_heads=num_heads_upsample,
685
- num_head_channels=dim_head,
686
- use_new_attention_order=use_new_attention_order,
687
- )
688
- if not use_spatial_transformer
689
- else SpatialTransformer3D(
690
  ch,
691
  num_heads,
692
  dim_head,
693
  depth=transformer_depth,
694
  context_dim=context_dim,
695
  disable_self_attn=disabled_sa,
696
- use_linear=use_linear_in_transformer,
697
  use_checkpoint=use_checkpoint,
698
  )
699
  )
@@ -777,7 +613,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
777
  hs.append(h)
778
  h = self.middle_block(h, emb, context, num_frames=num_frames)
779
  for module in self.output_blocks:
780
- h = th.cat([h, hs.pop()], dim=1)
781
  h = module(h, emb, context, num_frames=num_frames)
782
  h = h.type(x.dtype)
783
  if self.predict_codebook_ids:
 
1
  # obtained and modified from https://github.com/bytedance/MVDream
2
 
3
  import math
4
+ import torch
 
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  from diffusers.configuration_utils import ConfigMixin
 
222
  emb_out = emb_out[..., None]
223
  if self.use_scale_shift_norm:
224
  out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
225
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
226
  h = out_norm(h) * (1 + scale) + shift
227
  h = out_rest(h)
228
  else:
 
231
  return self.skip_connection(x) + h
232
 
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  class MultiViewUNetModel(ModelMixin, ConfigMixin):
235
  """
236
  The full multi-view UNet model with attention, timestep embedding and camera embedding.
 
281
  num_heads_upsample=-1,
282
  use_scale_shift_norm=False,
283
  resblock_updown=False,
 
 
284
  transformer_depth=1, # custom transformer support
285
  context_dim=None, # custom transformer support
286
  n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
 
287
  disable_self_attentions=None,
288
  num_attention_blocks=None,
289
  disable_middle_self_attn=False,
 
290
  adm_in_channels=None,
291
  camera_dim=None,
292
  ):
293
  super().__init__()
294
+ assert context_dim is not None
295
+
 
 
 
 
 
 
 
 
 
 
 
 
296
  if num_heads_upsample == -1:
297
  num_heads_upsample = num_heads
298
 
 
412
  else:
413
  num_heads = ch // num_head_channels
414
  dim_head = num_head_channels
415
+
 
 
 
 
 
 
416
  if disable_self_attentions is not None:
417
  disabled_sa = disable_self_attentions[level]
418
  else:
 
420
 
421
  if num_attention_blocks is None or nr < num_attention_blocks[level]:
422
  layers.append(
423
+ SpatialTransformer3D(
 
 
 
 
 
 
 
 
424
  ch,
425
  num_heads,
426
  dim_head,
427
  depth=transformer_depth,
428
  context_dim=context_dim,
429
  disable_self_attn=disabled_sa,
 
430
  use_checkpoint=use_checkpoint,
431
  )
432
  )
 
463
  else:
464
  num_heads = ch // num_head_channels
465
  dim_head = num_head_channels
466
+
 
 
467
  self.middle_block = TimestepEmbedSequential(
468
  ResBlock(
469
  ch,
 
473
  use_checkpoint=use_checkpoint,
474
  use_scale_shift_norm=use_scale_shift_norm,
475
  ),
476
+ SpatialTransformer3D(
 
 
 
 
 
 
 
 
477
  ch,
478
  num_heads,
479
  dim_head,
480
  depth=transformer_depth,
481
  context_dim=context_dim,
482
  disable_self_attn=disable_middle_self_attn,
 
483
  use_checkpoint=use_checkpoint,
484
+ ),
485
  ResBlock(
486
  ch,
487
  time_embed_dim,
 
515
  else:
516
  num_heads = ch // num_head_channels
517
  dim_head = num_head_channels
518
+
 
 
 
 
 
 
519
  if disable_self_attentions is not None:
520
  disabled_sa = disable_self_attentions[level]
521
  else:
 
523
 
524
  if num_attention_blocks is None or i < num_attention_blocks[level]:
525
  layers.append(
526
+ SpatialTransformer3D(
 
 
 
 
 
 
 
 
527
  ch,
528
  num_heads,
529
  dim_head,
530
  depth=transformer_depth,
531
  context_dim=context_dim,
532
  disable_self_attn=disabled_sa,
 
533
  use_checkpoint=use_checkpoint,
534
  )
535
  )
 
613
  hs.append(h)
614
  h = self.middle_block(h, emb, context, num_frames=num_frames)
615
  for module in self.output_blocks:
616
+ h = torch.cat([h, hs.pop()], dim=1)
617
  h = module(h, emb, context, num_frames=num_frames)
618
  h = h.type(x.dtype)
619
  if self.predict_codebook_ids: