rakkateichou dylanebert HF staff commited on
Commit
12835dc
·
verified ·
0 Parent(s):

Duplicate from dylanebert/multi-view-diffusion

Browse files

Co-authored-by: Dylan Ebert <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ *.pt
2
+ *.yaml
3
+ **/__pycache__
4
+ *.pyc
5
+
6
+ venv/
README.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: openrail
3
+ pipeline_tag: image-to-3d
4
+ ---
5
+
6
+ This is a copy of [ashawkey/imagedream-ipmv-diffusers](https://huggingface.co/ashawkey/imagedream-ipmv-diffusers).
7
+
8
+ It is hosted here for persistence throughout the ML for 3D course.
9
+
10
+ # MVDream-diffusers Model Card
11
+
12
+ This is a port of https://huggingface.co/Peng-Wang/ImageDream into diffusers.
13
+
14
+ For usage, please check: https://github.com/ashawkey/mvdream_diffusers
15
+
16
+ ## Citation
17
+
18
+ ```
19
+ @article{wang2023imagedream,
20
+ title={ImageDream: Image-Prompt Multi-view Diffusion for 3D Generation},
21
+ author={Wang, Peng and Shi, Yichun},
22
+ journal={arXiv preprint arXiv:2312.02201},
23
+ year={2023}
24
+ }
25
+ ```
26
+
27
+ ## Misuse, Malicious Use, and Out-of-Scope Use
28
+
29
+ The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feature_extractor_type": "CLIPFeatureExtractor",
12
+ "image_mean": [
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ],
17
+ "image_processor_type": "CLIPImageProcessor",
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ },
28
+ "use_square_size": false
29
+ }
image_encoder/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
3
+ "architectures": [
4
+ "CLIPVisionModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "dropout": 0.0,
8
+ "hidden_act": "gelu",
9
+ "hidden_size": 1280,
10
+ "image_size": 224,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 5120,
14
+ "layer_norm_eps": 1e-05,
15
+ "model_type": "clip_vision_model",
16
+ "num_attention_heads": 16,
17
+ "num_channels": 3,
18
+ "num_hidden_layers": 32,
19
+ "patch_size": 14,
20
+ "projection_dim": 1024,
21
+ "torch_dtype": "float16",
22
+ "transformers_version": "4.35.2"
23
+ }
image_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a56cfd4ffcf40be097c430324ec184cc37187f6dafef128ef9225438a3c03c4
3
+ size 1261595704
model_index.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MVDreamPipeline",
3
+ "_diffusers_version": "0.25.0",
4
+ "feature_extractor": [
5
+ "transformers",
6
+ "CLIPImageProcessor"
7
+ ],
8
+ "image_encoder": [
9
+ "transformers",
10
+ "CLIPVisionModel"
11
+ ],
12
+ "requires_safety_checker": false,
13
+ "scheduler": [
14
+ "diffusers",
15
+ "DDIMScheduler"
16
+ ],
17
+ "text_encoder": [
18
+ "transformers",
19
+ "CLIPTextModel"
20
+ ],
21
+ "tokenizer": [
22
+ "transformers",
23
+ "CLIPTokenizer"
24
+ ],
25
+ "unet": [
26
+ "mv_unet",
27
+ "MultiViewUNetModel"
28
+ ],
29
+ "vae": [
30
+ "diffusers",
31
+ "AutoencoderKL"
32
+ ]
33
+ }
pipeline.py ADDED
@@ -0,0 +1,1583 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import math
3
+ from inspect import isfunction
4
+ from typing import Any, Callable, List, Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ # require xformers!
11
+ import xformers
12
+ import xformers.ops
13
+ from diffusers import AutoencoderKL, DiffusionPipeline
14
+ from diffusers.configuration_utils import ConfigMixin, FrozenDict
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.schedulers import DDIMScheduler
17
+ from diffusers.utils import (deprecate, is_accelerate_available,
18
+ is_accelerate_version, logging)
19
+ from diffusers.utils.torch_utils import randn_tensor
20
+ from einops import rearrange, repeat
21
+ from kiui.cam import orbit_camera
22
+ from transformers import (CLIPImageProcessor, CLIPTextModel, CLIPTokenizer,
23
+ CLIPVisionModel)
24
+
25
+
26
+ def get_camera(
27
+ num_frames,
28
+ elevation=15,
29
+ azimuth_start=0,
30
+ azimuth_span=360,
31
+ blender_coord=True,
32
+ extra_view=False,
33
+ ):
34
+ angle_gap = azimuth_span / num_frames
35
+ cameras = []
36
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
37
+
38
+ pose = orbit_camera(
39
+ -elevation, azimuth, radius=1
40
+ ) # kiui's elevation is negated, [4, 4]
41
+
42
+ # opengl to blender
43
+ if blender_coord:
44
+ pose[2] *= -1
45
+ pose[[1, 2]] = pose[[2, 1]]
46
+
47
+ cameras.append(pose.flatten())
48
+
49
+ if extra_view:
50
+ cameras.append(np.zeros_like(cameras[0]))
51
+
52
+ return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
53
+
54
+
55
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
56
+ """
57
+ Create sinusoidal timestep embeddings.
58
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
59
+ These may be fractional.
60
+ :param dim: the dimension of the output.
61
+ :param max_period: controls the minimum frequency of the embeddings.
62
+ :return: an [N x dim] Tensor of positional embeddings.
63
+ """
64
+ if not repeat_only:
65
+ half = dim // 2
66
+ freqs = torch.exp(
67
+ -math.log(max_period)
68
+ * torch.arange(start=0, end=half, dtype=torch.float32)
69
+ / half
70
+ ).to(device=timesteps.device)
71
+ args = timesteps[:, None] * freqs[None]
72
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
73
+ if dim % 2:
74
+ embedding = torch.cat(
75
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
76
+ )
77
+ else:
78
+ embedding = repeat(timesteps, "b -> b d", d=dim)
79
+ # import pdb; pdb.set_trace()
80
+ return embedding
81
+
82
+
83
+ def zero_module(module):
84
+ """
85
+ Zero out the parameters of a module and return it.
86
+ """
87
+ for p in module.parameters():
88
+ p.detach().zero_()
89
+ return module
90
+
91
+
92
+ def conv_nd(dims, *args, **kwargs):
93
+ """
94
+ Create a 1D, 2D, or 3D convolution module.
95
+ """
96
+ if dims == 1:
97
+ return nn.Conv1d(*args, **kwargs)
98
+ elif dims == 2:
99
+ return nn.Conv2d(*args, **kwargs)
100
+ elif dims == 3:
101
+ return nn.Conv3d(*args, **kwargs)
102
+ raise ValueError(f"unsupported dimensions: {dims}")
103
+
104
+
105
+ def avg_pool_nd(dims, *args, **kwargs):
106
+ """
107
+ Create a 1D, 2D, or 3D average pooling module.
108
+ """
109
+ if dims == 1:
110
+ return nn.AvgPool1d(*args, **kwargs)
111
+ elif dims == 2:
112
+ return nn.AvgPool2d(*args, **kwargs)
113
+ elif dims == 3:
114
+ return nn.AvgPool3d(*args, **kwargs)
115
+ raise ValueError(f"unsupported dimensions: {dims}")
116
+
117
+
118
+ def default(val, d):
119
+ if val is not None:
120
+ return val
121
+ return d() if isfunction(d) else d
122
+
123
+
124
+ class GEGLU(nn.Module):
125
+ def __init__(self, dim_in, dim_out):
126
+ super().__init__()
127
+ self.proj = nn.Linear(dim_in, dim_out * 2)
128
+
129
+ def forward(self, x):
130
+ x, gate = self.proj(x).chunk(2, dim=-1)
131
+ return x * F.gelu(gate)
132
+
133
+
134
+ class FeedForward(nn.Module):
135
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
136
+ super().__init__()
137
+ inner_dim = int(dim * mult)
138
+ dim_out = default(dim_out, dim)
139
+ project_in = (
140
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
141
+ if not glu
142
+ else GEGLU(dim, inner_dim)
143
+ )
144
+
145
+ self.net = nn.Sequential(
146
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
147
+ )
148
+
149
+ def forward(self, x):
150
+ return self.net(x)
151
+
152
+
153
+ class MemoryEfficientCrossAttention(nn.Module):
154
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
155
+ def __init__(
156
+ self,
157
+ query_dim,
158
+ context_dim=None,
159
+ heads=8,
160
+ dim_head=64,
161
+ dropout=0.0,
162
+ ip_dim=0,
163
+ ip_weight=1,
164
+ ):
165
+ super().__init__()
166
+
167
+ inner_dim = dim_head * heads
168
+ context_dim = default(context_dim, query_dim)
169
+
170
+ self.heads = heads
171
+ self.dim_head = dim_head
172
+
173
+ self.ip_dim = ip_dim
174
+ self.ip_weight = ip_weight
175
+
176
+ if self.ip_dim > 0:
177
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
178
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
179
+
180
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
181
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
182
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
183
+
184
+ self.to_out = nn.Sequential(
185
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
186
+ )
187
+ self.attention_op: Optional[Any] = None
188
+
189
+ def forward(self, x, context=None):
190
+ q = self.to_q(x)
191
+ context = default(context, x)
192
+
193
+ if self.ip_dim > 0:
194
+ # context: [B, 77 + 16(ip), 1024]
195
+ token_len = context.shape[1]
196
+ context_ip = context[:, -self.ip_dim :, :]
197
+ k_ip = self.to_k_ip(context_ip)
198
+ v_ip = self.to_v_ip(context_ip)
199
+ context = context[:, : (token_len - self.ip_dim), :]
200
+
201
+ k = self.to_k(context)
202
+ v = self.to_v(context)
203
+
204
+ b, _, _ = q.shape
205
+ q, k, v = map(
206
+ lambda t: t.unsqueeze(3)
207
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
208
+ .permute(0, 2, 1, 3)
209
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
210
+ .contiguous(),
211
+ (q, k, v),
212
+ )
213
+
214
+ # actually compute the attention, what we cannot get enough of
215
+ out = xformers.ops.memory_efficient_attention(
216
+ q, k, v, attn_bias=None, op=self.attention_op
217
+ )
218
+
219
+ if self.ip_dim > 0:
220
+ k_ip, v_ip = map(
221
+ lambda t: t.unsqueeze(3)
222
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
223
+ .permute(0, 2, 1, 3)
224
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
225
+ .contiguous(),
226
+ (k_ip, v_ip),
227
+ )
228
+ # actually compute the attention, what we cannot get enough of
229
+ out_ip = xformers.ops.memory_efficient_attention(
230
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
231
+ )
232
+ out = out + self.ip_weight * out_ip
233
+
234
+ out = (
235
+ out.unsqueeze(0)
236
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
237
+ .permute(0, 2, 1, 3)
238
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
239
+ )
240
+ return self.to_out(out)
241
+
242
+
243
+ class BasicTransformerBlock3D(nn.Module):
244
+
245
+ def __init__(
246
+ self,
247
+ dim,
248
+ n_heads,
249
+ d_head,
250
+ context_dim,
251
+ dropout=0.0,
252
+ gated_ff=True,
253
+ ip_dim=0,
254
+ ip_weight=1,
255
+ ):
256
+ super().__init__()
257
+
258
+ self.attn1 = MemoryEfficientCrossAttention(
259
+ query_dim=dim,
260
+ context_dim=None, # self-attention
261
+ heads=n_heads,
262
+ dim_head=d_head,
263
+ dropout=dropout,
264
+ )
265
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
266
+ self.attn2 = MemoryEfficientCrossAttention(
267
+ query_dim=dim,
268
+ context_dim=context_dim,
269
+ heads=n_heads,
270
+ dim_head=d_head,
271
+ dropout=dropout,
272
+ # ip only applies to cross-attention
273
+ ip_dim=ip_dim,
274
+ ip_weight=ip_weight,
275
+ )
276
+ self.norm1 = nn.LayerNorm(dim)
277
+ self.norm2 = nn.LayerNorm(dim)
278
+ self.norm3 = nn.LayerNorm(dim)
279
+
280
+ def forward(self, x, context=None, num_frames=1):
281
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
282
+ x = self.attn1(self.norm1(x), context=None) + x
283
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
284
+ x = self.attn2(self.norm2(x), context=context) + x
285
+ x = self.ff(self.norm3(x)) + x
286
+ return x
287
+
288
+
289
+ class SpatialTransformer3D(nn.Module):
290
+
291
+ def __init__(
292
+ self,
293
+ in_channels,
294
+ n_heads,
295
+ d_head,
296
+ context_dim, # cross attention input dim
297
+ depth=1,
298
+ dropout=0.0,
299
+ ip_dim=0,
300
+ ip_weight=1,
301
+ ):
302
+ super().__init__()
303
+
304
+ if not isinstance(context_dim, list):
305
+ context_dim = [context_dim]
306
+
307
+ self.in_channels = in_channels
308
+
309
+ inner_dim = n_heads * d_head
310
+ self.norm = nn.GroupNorm(
311
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
312
+ )
313
+ self.proj_in = nn.Linear(in_channels, inner_dim)
314
+
315
+ self.transformer_blocks = nn.ModuleList(
316
+ [
317
+ BasicTransformerBlock3D(
318
+ inner_dim,
319
+ n_heads,
320
+ d_head,
321
+ context_dim=context_dim[d],
322
+ dropout=dropout,
323
+ ip_dim=ip_dim,
324
+ ip_weight=ip_weight,
325
+ )
326
+ for d in range(depth)
327
+ ]
328
+ )
329
+
330
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
331
+
332
+ def forward(self, x, context=None, num_frames=1):
333
+ # note: if no context is given, cross-attention defaults to self-attention
334
+ if not isinstance(context, list):
335
+ context = [context]
336
+ b, c, h, w = x.shape
337
+ x_in = x
338
+ x = self.norm(x)
339
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
340
+ x = self.proj_in(x)
341
+ for i, block in enumerate(self.transformer_blocks):
342
+ x = block(x, context=context[i], num_frames=num_frames)
343
+ x = self.proj_out(x)
344
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
345
+
346
+ return x + x_in
347
+
348
+
349
+ class PerceiverAttention(nn.Module):
350
+ def __init__(self, *, dim, dim_head=64, heads=8):
351
+ super().__init__()
352
+ self.scale = dim_head**-0.5
353
+ self.dim_head = dim_head
354
+ self.heads = heads
355
+ inner_dim = dim_head * heads
356
+
357
+ self.norm1 = nn.LayerNorm(dim)
358
+ self.norm2 = nn.LayerNorm(dim)
359
+
360
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
361
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
362
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
363
+
364
+ def forward(self, x, latents):
365
+ """
366
+ Args:
367
+ x (torch.Tensor): image features
368
+ shape (b, n1, D)
369
+ latent (torch.Tensor): latent features
370
+ shape (b, n2, D)
371
+ """
372
+ x = self.norm1(x)
373
+ latents = self.norm2(latents)
374
+
375
+ b, h, _ = latents.shape
376
+
377
+ q = self.to_q(latents)
378
+ kv_input = torch.cat((x, latents), dim=-2)
379
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
380
+
381
+ q, k, v = map(
382
+ lambda t: t.reshape(b, t.shape[1], self.heads, -1)
383
+ .transpose(1, 2)
384
+ .reshape(b, self.heads, t.shape[1], -1)
385
+ .contiguous(),
386
+ (q, k, v),
387
+ )
388
+
389
+ # attention
390
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
391
+ weight = (q * scale) @ (k * scale).transpose(
392
+ -2, -1
393
+ ) # More stable with f16 than dividing afterwards
394
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
395
+ out = weight @ v
396
+
397
+ out = out.permute(0, 2, 1, 3).reshape(b, h, -1)
398
+
399
+ return self.to_out(out)
400
+
401
+
402
+ class Resampler(nn.Module):
403
+ def __init__(
404
+ self,
405
+ dim=1024,
406
+ depth=8,
407
+ dim_head=64,
408
+ heads=16,
409
+ num_queries=8,
410
+ embedding_dim=768,
411
+ output_dim=1024,
412
+ ff_mult=4,
413
+ ):
414
+ super().__init__()
415
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
416
+ self.proj_in = nn.Linear(embedding_dim, dim)
417
+ self.proj_out = nn.Linear(dim, output_dim)
418
+ self.norm_out = nn.LayerNorm(output_dim)
419
+
420
+ self.layers = nn.ModuleList([])
421
+ for _ in range(depth):
422
+ self.layers.append(
423
+ nn.ModuleList(
424
+ [
425
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
426
+ nn.Sequential(
427
+ nn.LayerNorm(dim),
428
+ nn.Linear(dim, dim * ff_mult, bias=False),
429
+ nn.GELU(),
430
+ nn.Linear(dim * ff_mult, dim, bias=False),
431
+ ),
432
+ ]
433
+ )
434
+ )
435
+
436
+ def forward(self, x):
437
+ latents = self.latents.repeat(x.size(0), 1, 1)
438
+ x = self.proj_in(x)
439
+ for attn, ff in self.layers:
440
+ latents = attn(x, latents) + latents
441
+ latents = ff(latents) + latents
442
+
443
+ latents = self.proj_out(latents)
444
+ return self.norm_out(latents)
445
+
446
+
447
+ class CondSequential(nn.Sequential):
448
+ """
449
+ A sequential module that passes timestep embeddings to the children that
450
+ support it as an extra input.
451
+ """
452
+
453
+ def forward(self, x, emb, context=None, num_frames=1):
454
+ for layer in self:
455
+ if isinstance(layer, ResBlock):
456
+ x = layer(x, emb)
457
+ elif isinstance(layer, SpatialTransformer3D):
458
+ x = layer(x, context, num_frames=num_frames)
459
+ else:
460
+ x = layer(x)
461
+ return x
462
+
463
+
464
+ class Upsample(nn.Module):
465
+ """
466
+ An upsampling layer with an optional convolution.
467
+ :param channels: channels in the inputs and outputs.
468
+ :param use_conv: a bool determining if a convolution is applied.
469
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
470
+ upsampling occurs in the inner-two dimensions.
471
+ """
472
+
473
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
474
+ super().__init__()
475
+ self.channels = channels
476
+ self.out_channels = out_channels or channels
477
+ self.use_conv = use_conv
478
+ self.dims = dims
479
+ if use_conv:
480
+ self.conv = conv_nd(
481
+ dims, self.channels, self.out_channels, 3, padding=padding
482
+ )
483
+
484
+ def forward(self, x):
485
+ assert x.shape[1] == self.channels
486
+ if self.dims == 3:
487
+ x = F.interpolate(
488
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
489
+ )
490
+ else:
491
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
492
+ if self.use_conv:
493
+ x = self.conv(x)
494
+ return x
495
+
496
+
497
+ class Downsample(nn.Module):
498
+ """
499
+ A downsampling layer with an optional convolution.
500
+ :param channels: channels in the inputs and outputs.
501
+ :param use_conv: a bool determining if a convolution is applied.
502
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
503
+ downsampling occurs in the inner-two dimensions.
504
+ """
505
+
506
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
507
+ super().__init__()
508
+ self.channels = channels
509
+ self.out_channels = out_channels or channels
510
+ self.use_conv = use_conv
511
+ self.dims = dims
512
+ stride = 2 if dims != 3 else (1, 2, 2)
513
+ if use_conv:
514
+ self.op = conv_nd(
515
+ dims,
516
+ self.channels,
517
+ self.out_channels,
518
+ 3,
519
+ stride=stride,
520
+ padding=padding,
521
+ )
522
+ else:
523
+ assert self.channels == self.out_channels
524
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
525
+
526
+ def forward(self, x):
527
+ assert x.shape[1] == self.channels
528
+ return self.op(x)
529
+
530
+
531
+ class ResBlock(nn.Module):
532
+ """
533
+ A residual block that can optionally change the number of channels.
534
+ :param channels: the number of input channels.
535
+ :param emb_channels: the number of timestep embedding channels.
536
+ :param dropout: the rate of dropout.
537
+ :param out_channels: if specified, the number of out channels.
538
+ :param use_conv: if True and out_channels is specified, use a spatial
539
+ convolution instead of a smaller 1x1 convolution to change the
540
+ channels in the skip connection.
541
+ :param dims: determines if the signal is 1D, 2D, or 3D.
542
+ :param up: if True, use this block for upsampling.
543
+ :param down: if True, use this block for downsampling.
544
+ """
545
+
546
+ def __init__(
547
+ self,
548
+ channels,
549
+ emb_channels,
550
+ dropout,
551
+ out_channels=None,
552
+ use_conv=False,
553
+ use_scale_shift_norm=False,
554
+ dims=2,
555
+ up=False,
556
+ down=False,
557
+ ):
558
+ super().__init__()
559
+ self.channels = channels
560
+ self.emb_channels = emb_channels
561
+ self.dropout = dropout
562
+ self.out_channels = out_channels or channels
563
+ self.use_conv = use_conv
564
+ self.use_scale_shift_norm = use_scale_shift_norm
565
+
566
+ self.in_layers = nn.Sequential(
567
+ nn.GroupNorm(32, channels),
568
+ nn.SiLU(),
569
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
570
+ )
571
+
572
+ self.updown = up or down
573
+
574
+ if up:
575
+ self.h_upd = Upsample(channels, False, dims)
576
+ self.x_upd = Upsample(channels, False, dims)
577
+ elif down:
578
+ self.h_upd = Downsample(channels, False, dims)
579
+ self.x_upd = Downsample(channels, False, dims)
580
+ else:
581
+ self.h_upd = self.x_upd = nn.Identity()
582
+
583
+ self.emb_layers = nn.Sequential(
584
+ nn.SiLU(),
585
+ nn.Linear(
586
+ emb_channels,
587
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
588
+ ),
589
+ )
590
+ self.out_layers = nn.Sequential(
591
+ nn.GroupNorm(32, self.out_channels),
592
+ nn.SiLU(),
593
+ nn.Dropout(p=dropout),
594
+ zero_module(
595
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
596
+ ),
597
+ )
598
+
599
+ if self.out_channels == channels:
600
+ self.skip_connection = nn.Identity()
601
+ elif use_conv:
602
+ self.skip_connection = conv_nd(
603
+ dims, channels, self.out_channels, 3, padding=1
604
+ )
605
+ else:
606
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
607
+
608
+ def forward(self, x, emb):
609
+ if self.updown:
610
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
611
+ h = in_rest(x)
612
+ h = self.h_upd(h)
613
+ x = self.x_upd(x)
614
+ h = in_conv(h)
615
+ else:
616
+ h = self.in_layers(x)
617
+ emb_out = self.emb_layers(emb).type(h.dtype)
618
+ while len(emb_out.shape) < len(h.shape):
619
+ emb_out = emb_out[..., None]
620
+ if self.use_scale_shift_norm:
621
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
622
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
623
+ h = out_norm(h) * (1 + scale) + shift
624
+ h = out_rest(h)
625
+ else:
626
+ h = h + emb_out
627
+ h = self.out_layers(h)
628
+ return self.skip_connection(x) + h
629
+
630
+
631
+ class MultiViewUNetModel(ModelMixin, ConfigMixin):
632
+ """
633
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
634
+ :param in_channels: channels in the input Tensor.
635
+ :param model_channels: base channel count for the model.
636
+ :param out_channels: channels in the output Tensor.
637
+ :param num_res_blocks: number of residual blocks per downsample.
638
+ :param attention_resolutions: a collection of downsample rates at which
639
+ attention will take place. May be a set, list, or tuple.
640
+ For example, if this contains 4, then at 4x downsampling, attention
641
+ will be used.
642
+ :param dropout: the dropout probability.
643
+ :param channel_mult: channel multiplier for each level of the UNet.
644
+ :param conv_resample: if True, use learned convolutions for upsampling and
645
+ downsampling.
646
+ :param dims: determines if the signal is 1D, 2D, or 3D.
647
+ :param num_classes: if specified (as an int), then this model will be
648
+ class-conditional with `num_classes` classes.
649
+ :param num_heads: the number of attention heads in each attention layer.
650
+ :param num_heads_channels: if specified, ignore num_heads and instead use
651
+ a fixed channel width per attention head.
652
+ :param num_heads_upsample: works with num_heads to set a different number
653
+ of heads for upsampling. Deprecated.
654
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
655
+ :param resblock_updown: use residual blocks for up/downsampling.
656
+ :param use_new_attention_order: use a different attention pattern for potentially
657
+ increased efficiency.
658
+ :param camera_dim: dimensionality of camera input.
659
+ """
660
+
661
+ def __init__(
662
+ self,
663
+ image_size,
664
+ in_channels,
665
+ model_channels,
666
+ out_channels,
667
+ num_res_blocks,
668
+ attention_resolutions,
669
+ dropout=0,
670
+ channel_mult=(1, 2, 4, 8),
671
+ conv_resample=True,
672
+ dims=2,
673
+ num_classes=None,
674
+ num_heads=-1,
675
+ num_head_channels=-1,
676
+ num_heads_upsample=-1,
677
+ use_scale_shift_norm=False,
678
+ resblock_updown=False,
679
+ transformer_depth=1,
680
+ context_dim=None,
681
+ n_embed=None,
682
+ num_attention_blocks=None,
683
+ adm_in_channels=None,
684
+ camera_dim=None,
685
+ ip_dim=0, # imagedream uses ip_dim > 0
686
+ ip_weight=1.0,
687
+ **kwargs,
688
+ ):
689
+ super().__init__()
690
+ assert context_dim is not None
691
+
692
+ if num_heads_upsample == -1:
693
+ num_heads_upsample = num_heads
694
+
695
+ if num_heads == -1:
696
+ assert (
697
+ num_head_channels != -1
698
+ ), "Either num_heads or num_head_channels has to be set"
699
+
700
+ if num_head_channels == -1:
701
+ assert (
702
+ num_heads != -1
703
+ ), "Either num_heads or num_head_channels has to be set"
704
+
705
+ self.image_size = image_size
706
+ self.in_channels = in_channels
707
+ self.model_channels = model_channels
708
+ self.out_channels = out_channels
709
+ if isinstance(num_res_blocks, int):
710
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
711
+ else:
712
+ if len(num_res_blocks) != len(channel_mult):
713
+ raise ValueError(
714
+ "provide num_res_blocks either as an int (globally constant) or "
715
+ "as a list/tuple (per-level) with the same length as channel_mult"
716
+ )
717
+ self.num_res_blocks = num_res_blocks
718
+
719
+ if num_attention_blocks is not None:
720
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
721
+ assert all(
722
+ map(
723
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
724
+ range(len(num_attention_blocks)),
725
+ )
726
+ )
727
+ print(
728
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
729
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
730
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
731
+ f"attention will still not be set."
732
+ )
733
+
734
+ self.attention_resolutions = attention_resolutions
735
+ self.dropout = dropout
736
+ self.channel_mult = channel_mult
737
+ self.conv_resample = conv_resample
738
+ self.num_classes = num_classes
739
+ self.num_heads = num_heads
740
+ self.num_head_channels = num_head_channels
741
+ self.num_heads_upsample = num_heads_upsample
742
+ self.predict_codebook_ids = n_embed is not None
743
+
744
+ self.ip_dim = ip_dim
745
+ self.ip_weight = ip_weight
746
+
747
+ if self.ip_dim > 0:
748
+ self.image_embed = Resampler(
749
+ dim=context_dim,
750
+ depth=4,
751
+ dim_head=64,
752
+ heads=12,
753
+ num_queries=ip_dim, # num token
754
+ embedding_dim=1280,
755
+ output_dim=context_dim,
756
+ ff_mult=4,
757
+ )
758
+
759
+ time_embed_dim = model_channels * 4
760
+ self.time_embed = nn.Sequential(
761
+ nn.Linear(model_channels, time_embed_dim),
762
+ nn.SiLU(),
763
+ nn.Linear(time_embed_dim, time_embed_dim),
764
+ )
765
+
766
+ if camera_dim is not None:
767
+ time_embed_dim = model_channels * 4
768
+ self.camera_embed = nn.Sequential(
769
+ nn.Linear(camera_dim, time_embed_dim),
770
+ nn.SiLU(),
771
+ nn.Linear(time_embed_dim, time_embed_dim),
772
+ )
773
+
774
+ if self.num_classes is not None:
775
+ if isinstance(self.num_classes, int):
776
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
777
+ elif self.num_classes == "continuous":
778
+ # print("setting up linear c_adm embedding layer")
779
+ self.label_emb = nn.Linear(1, time_embed_dim)
780
+ elif self.num_classes == "sequential":
781
+ assert adm_in_channels is not None
782
+ self.label_emb = nn.Sequential(
783
+ nn.Sequential(
784
+ nn.Linear(adm_in_channels, time_embed_dim),
785
+ nn.SiLU(),
786
+ nn.Linear(time_embed_dim, time_embed_dim),
787
+ )
788
+ )
789
+ else:
790
+ raise ValueError()
791
+
792
+ self.input_blocks = nn.ModuleList(
793
+ [CondSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
794
+ )
795
+ self._feature_size = model_channels
796
+ input_block_chans = [model_channels]
797
+ ch = model_channels
798
+ ds = 1
799
+ for level, mult in enumerate(channel_mult):
800
+ for nr in range(self.num_res_blocks[level]):
801
+ layers: List[Any] = [
802
+ ResBlock(
803
+ ch,
804
+ time_embed_dim,
805
+ dropout,
806
+ out_channels=mult * model_channels,
807
+ dims=dims,
808
+ use_scale_shift_norm=use_scale_shift_norm,
809
+ )
810
+ ]
811
+ ch = mult * model_channels
812
+ if ds in attention_resolutions:
813
+ if num_head_channels == -1:
814
+ dim_head = ch // num_heads
815
+ else:
816
+ num_heads = ch // num_head_channels
817
+ dim_head = num_head_channels
818
+
819
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
820
+ layers.append(
821
+ SpatialTransformer3D(
822
+ ch,
823
+ num_heads,
824
+ dim_head,
825
+ context_dim=context_dim,
826
+ depth=transformer_depth,
827
+ ip_dim=self.ip_dim,
828
+ ip_weight=self.ip_weight,
829
+ )
830
+ )
831
+ self.input_blocks.append(CondSequential(*layers))
832
+ self._feature_size += ch
833
+ input_block_chans.append(ch)
834
+ if level != len(channel_mult) - 1:
835
+ out_ch = ch
836
+ self.input_blocks.append(
837
+ CondSequential(
838
+ ResBlock(
839
+ ch,
840
+ time_embed_dim,
841
+ dropout,
842
+ out_channels=out_ch,
843
+ dims=dims,
844
+ use_scale_shift_norm=use_scale_shift_norm,
845
+ down=True,
846
+ )
847
+ if resblock_updown
848
+ else Downsample(
849
+ ch, conv_resample, dims=dims, out_channels=out_ch
850
+ )
851
+ )
852
+ )
853
+ ch = out_ch
854
+ input_block_chans.append(ch)
855
+ ds *= 2
856
+ self._feature_size += ch
857
+
858
+ if num_head_channels == -1:
859
+ dim_head = ch // num_heads
860
+ else:
861
+ num_heads = ch // num_head_channels
862
+ dim_head = num_head_channels
863
+
864
+ self.middle_block = CondSequential(
865
+ ResBlock(
866
+ ch,
867
+ time_embed_dim,
868
+ dropout,
869
+ dims=dims,
870
+ use_scale_shift_norm=use_scale_shift_norm,
871
+ ),
872
+ SpatialTransformer3D(
873
+ ch,
874
+ num_heads,
875
+ dim_head,
876
+ context_dim=context_dim,
877
+ depth=transformer_depth,
878
+ ip_dim=self.ip_dim,
879
+ ip_weight=self.ip_weight,
880
+ ),
881
+ ResBlock(
882
+ ch,
883
+ time_embed_dim,
884
+ dropout,
885
+ dims=dims,
886
+ use_scale_shift_norm=use_scale_shift_norm,
887
+ ),
888
+ )
889
+ self._feature_size += ch
890
+
891
+ self.output_blocks = nn.ModuleList([])
892
+ for level, mult in list(enumerate(channel_mult))[::-1]:
893
+ for i in range(self.num_res_blocks[level] + 1):
894
+ ich = input_block_chans.pop()
895
+ layers = [
896
+ ResBlock(
897
+ ch + ich,
898
+ time_embed_dim,
899
+ dropout,
900
+ out_channels=model_channels * mult,
901
+ dims=dims,
902
+ use_scale_shift_norm=use_scale_shift_norm,
903
+ )
904
+ ]
905
+ ch = model_channels * mult
906
+ if ds in attention_resolutions:
907
+ if num_head_channels == -1:
908
+ dim_head = ch // num_heads
909
+ else:
910
+ num_heads = ch // num_head_channels
911
+ dim_head = num_head_channels
912
+
913
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
914
+ layers.append(
915
+ SpatialTransformer3D(
916
+ ch,
917
+ num_heads,
918
+ dim_head,
919
+ context_dim=context_dim,
920
+ depth=transformer_depth,
921
+ ip_dim=self.ip_dim,
922
+ ip_weight=self.ip_weight,
923
+ )
924
+ )
925
+ if level and i == self.num_res_blocks[level]:
926
+ out_ch = ch
927
+ layers.append(
928
+ ResBlock(
929
+ ch,
930
+ time_embed_dim,
931
+ dropout,
932
+ out_channels=out_ch,
933
+ dims=dims,
934
+ use_scale_shift_norm=use_scale_shift_norm,
935
+ up=True,
936
+ )
937
+ if resblock_updown
938
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
939
+ )
940
+ ds //= 2
941
+ self.output_blocks.append(CondSequential(*layers))
942
+ self._feature_size += ch
943
+
944
+ self.out = nn.Sequential(
945
+ nn.GroupNorm(32, ch),
946
+ nn.SiLU(),
947
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
948
+ )
949
+ if self.predict_codebook_ids:
950
+ self.id_predictor = nn.Sequential(
951
+ nn.GroupNorm(32, ch),
952
+ conv_nd(dims, model_channels, n_embed, 1),
953
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
954
+ )
955
+
956
+ def forward(
957
+ self,
958
+ x,
959
+ timesteps=None,
960
+ context=None,
961
+ y=None,
962
+ camera=None,
963
+ num_frames=1,
964
+ ip=None,
965
+ ip_img=None,
966
+ **kwargs,
967
+ ):
968
+ """
969
+ Apply the model to an input batch.
970
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
971
+ :param timesteps: a 1-D batch of timesteps.
972
+ :param context: conditioning plugged in via crossattn
973
+ :param y: an [N] Tensor of labels, if class-conditional.
974
+ :param num_frames: a integer indicating number of frames for tensor reshaping.
975
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
976
+ """
977
+ assert (
978
+ x.shape[0] % num_frames == 0
979
+ ), "input batch size must be dividable by num_frames!"
980
+ assert (y is not None) == (
981
+ self.num_classes is not None
982
+ ), "must specify y if and only if the model is class-conditional"
983
+
984
+ hs = []
985
+
986
+ t_emb = timestep_embedding(
987
+ timesteps, self.model_channels, repeat_only=False
988
+ ).to(x.dtype)
989
+
990
+ emb = self.time_embed(t_emb)
991
+
992
+ if self.num_classes is not None:
993
+ assert y is not None
994
+ assert y.shape[0] == x.shape[0]
995
+ emb = emb + self.label_emb(y)
996
+
997
+ # Add camera embeddings
998
+ if camera is not None:
999
+ emb = emb + self.camera_embed(camera)
1000
+
1001
+ # imagedream variant
1002
+ if self.ip_dim > 0:
1003
+ x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
1004
+ ip_emb = self.image_embed(ip)
1005
+ context = torch.cat((context, ip_emb), 1)
1006
+
1007
+ h = x
1008
+ for module in self.input_blocks:
1009
+ h = module(h, emb, context, num_frames=num_frames)
1010
+ hs.append(h)
1011
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
1012
+ for module in self.output_blocks:
1013
+ h = torch.cat([h, hs.pop()], dim=1)
1014
+ h = module(h, emb, context, num_frames=num_frames)
1015
+ h = h.type(x.dtype)
1016
+ if self.predict_codebook_ids:
1017
+ return self.id_predictor(h)
1018
+ else:
1019
+ return self.out(h)
1020
+
1021
+
1022
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
1023
+
1024
+
1025
+ class MVDreamPipeline(DiffusionPipeline):
1026
+
1027
+ _optional_components = ["feature_extractor", "image_encoder"]
1028
+
1029
+ def __init__(
1030
+ self,
1031
+ vae: AutoencoderKL,
1032
+ unet: MultiViewUNetModel,
1033
+ tokenizer: CLIPTokenizer,
1034
+ text_encoder: CLIPTextModel,
1035
+ scheduler: DDIMScheduler,
1036
+ # imagedream variant
1037
+ feature_extractor: CLIPImageProcessor,
1038
+ image_encoder: CLIPVisionModel,
1039
+ requires_safety_checker: bool = False,
1040
+ ):
1041
+ super().__init__()
1042
+
1043
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore
1044
+ deprecation_message = (
1045
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
1046
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore
1047
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
1048
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
1049
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
1050
+ " file"
1051
+ )
1052
+ deprecate(
1053
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
1054
+ )
1055
+ new_config = dict(scheduler.config)
1056
+ new_config["steps_offset"] = 1
1057
+ scheduler._internal_dict = FrozenDict(new_config)
1058
+
1059
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore
1060
+ deprecation_message = (
1061
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
1062
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
1063
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
1064
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
1065
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
1066
+ )
1067
+ deprecate(
1068
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
1069
+ )
1070
+ new_config = dict(scheduler.config)
1071
+ new_config["clip_sample"] = False
1072
+ scheduler._internal_dict = FrozenDict(new_config)
1073
+
1074
+ self.register_modules(
1075
+ vae=vae,
1076
+ unet=unet,
1077
+ scheduler=scheduler,
1078
+ tokenizer=tokenizer,
1079
+ text_encoder=text_encoder,
1080
+ feature_extractor=feature_extractor,
1081
+ image_encoder=image_encoder,
1082
+ )
1083
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
1084
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
1085
+
1086
+ def enable_vae_slicing(self):
1087
+ r"""
1088
+ Enable sliced VAE decoding.
1089
+
1090
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
1091
+ steps. This is useful to save some memory and allow larger batch sizes.
1092
+ """
1093
+ self.vae.enable_slicing()
1094
+
1095
+ def disable_vae_slicing(self):
1096
+ r"""
1097
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
1098
+ computing decoding in one step.
1099
+ """
1100
+ self.vae.disable_slicing()
1101
+
1102
+ def enable_vae_tiling(self):
1103
+ r"""
1104
+ Enable tiled VAE decoding.
1105
+
1106
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
1107
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
1108
+ """
1109
+ self.vae.enable_tiling()
1110
+
1111
+ def disable_vae_tiling(self):
1112
+ r"""
1113
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
1114
+ computing decoding in one step.
1115
+ """
1116
+ self.vae.disable_tiling()
1117
+
1118
+ def enable_sequential_cpu_offload(self, gpu_id=0):
1119
+ r"""
1120
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
1121
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
1122
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
1123
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
1124
+ `enable_model_cpu_offload`, but performance is lower.
1125
+ """
1126
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
1127
+ from accelerate import cpu_offload
1128
+ else:
1129
+ raise ImportError(
1130
+ "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
1131
+ )
1132
+
1133
+ device = torch.device(f"cuda:{gpu_id}")
1134
+
1135
+ if self.device.type != "cpu":
1136
+ self.to("cpu", silence_dtype_warnings=True)
1137
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
1138
+
1139
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
1140
+ cpu_offload(cpu_offloaded_model, device)
1141
+
1142
+ def enable_model_cpu_offload(self, gpu_id=0):
1143
+ r"""
1144
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
1145
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
1146
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
1147
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
1148
+ """
1149
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
1150
+ from accelerate import cpu_offload_with_hook
1151
+ else:
1152
+ raise ImportError(
1153
+ "`enable_model_offload` requires `accelerate v0.17.0` or higher."
1154
+ )
1155
+
1156
+ device = torch.device(f"cuda:{gpu_id}")
1157
+
1158
+ if self.device.type != "cpu":
1159
+ self.to("cpu", silence_dtype_warnings=True)
1160
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
1161
+
1162
+ hook = None
1163
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
1164
+ _, hook = cpu_offload_with_hook(
1165
+ cpu_offloaded_model, device, prev_module_hook=hook
1166
+ )
1167
+
1168
+ # We'll offload the last model manually.
1169
+ self.final_offload_hook = hook
1170
+
1171
+ @property
1172
+ def _execution_device(self):
1173
+ r"""
1174
+ Returns the device on which the pipeline's models will be executed. After calling
1175
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
1176
+ hooks.
1177
+ """
1178
+ if not hasattr(self.unet, "_hf_hook"):
1179
+ return self.device
1180
+ for module in self.unet.modules():
1181
+ if (
1182
+ hasattr(module, "_hf_hook")
1183
+ and hasattr(module._hf_hook, "execution_device")
1184
+ and module._hf_hook.execution_device is not None
1185
+ ):
1186
+ return torch.device(module._hf_hook.execution_device)
1187
+ return self.device
1188
+
1189
+ def _encode_prompt(
1190
+ self,
1191
+ prompt,
1192
+ device,
1193
+ num_images_per_prompt,
1194
+ do_classifier_free_guidance: bool,
1195
+ negative_prompt=None,
1196
+ ):
1197
+ r"""
1198
+ Encodes the prompt into text encoder hidden states.
1199
+
1200
+ Args:
1201
+ prompt (`str` or `List[str]`, *optional*):
1202
+ prompt to be encoded
1203
+ device: (`torch.device`):
1204
+ torch device
1205
+ num_images_per_prompt (`int`):
1206
+ number of images that should be generated per prompt
1207
+ do_classifier_free_guidance (`bool`):
1208
+ whether to use classifier free guidance or not
1209
+ negative_prompt (`str` or `List[str]`, *optional*):
1210
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1211
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
1212
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
1213
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1214
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1215
+ provided, text embeddings will be generated from `prompt` input argument.
1216
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1217
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1218
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1219
+ argument.
1220
+ """
1221
+ if prompt is not None and isinstance(prompt, str):
1222
+ batch_size = 1
1223
+ elif prompt is not None and isinstance(prompt, list):
1224
+ batch_size = len(prompt)
1225
+ else:
1226
+ raise ValueError(
1227
+ f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
1228
+ )
1229
+
1230
+ text_inputs = self.tokenizer(
1231
+ prompt,
1232
+ padding="max_length",
1233
+ max_length=self.tokenizer.model_max_length,
1234
+ truncation=True,
1235
+ return_tensors="pt",
1236
+ )
1237
+ text_input_ids = text_inputs.input_ids
1238
+ untruncated_ids = self.tokenizer(
1239
+ prompt, padding="longest", return_tensors="pt"
1240
+ ).input_ids
1241
+
1242
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
1243
+ text_input_ids, untruncated_ids
1244
+ ):
1245
+ removed_text = self.tokenizer.batch_decode(
1246
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
1247
+ )
1248
+ logger.warning(
1249
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
1250
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
1251
+ )
1252
+
1253
+ if (
1254
+ hasattr(self.text_encoder.config, "use_attention_mask")
1255
+ and self.text_encoder.config.use_attention_mask
1256
+ ):
1257
+ attention_mask = text_inputs.attention_mask.to(device)
1258
+ else:
1259
+ attention_mask = None
1260
+
1261
+ prompt_embeds = self.text_encoder(
1262
+ text_input_ids.to(device),
1263
+ attention_mask=attention_mask,
1264
+ )
1265
+ prompt_embeds = prompt_embeds[0]
1266
+
1267
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
1268
+
1269
+ bs_embed, seq_len, _ = prompt_embeds.shape
1270
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
1271
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
1272
+ prompt_embeds = prompt_embeds.view(
1273
+ bs_embed * num_images_per_prompt, seq_len, -1
1274
+ )
1275
+
1276
+ # get unconditional embeddings for classifier free guidance
1277
+ if do_classifier_free_guidance:
1278
+ uncond_tokens: List[str]
1279
+ if negative_prompt is None:
1280
+ uncond_tokens = [""] * batch_size
1281
+ elif type(prompt) is not type(negative_prompt):
1282
+ raise TypeError(
1283
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
1284
+ f" {type(prompt)}."
1285
+ )
1286
+ elif isinstance(negative_prompt, str):
1287
+ uncond_tokens = [negative_prompt]
1288
+ elif batch_size != len(negative_prompt):
1289
+ raise ValueError(
1290
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
1291
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
1292
+ " the batch size of `prompt`."
1293
+ )
1294
+ else:
1295
+ uncond_tokens = negative_prompt
1296
+
1297
+ max_length = prompt_embeds.shape[1]
1298
+ uncond_input = self.tokenizer(
1299
+ uncond_tokens,
1300
+ padding="max_length",
1301
+ max_length=max_length,
1302
+ truncation=True,
1303
+ return_tensors="pt",
1304
+ )
1305
+
1306
+ if (
1307
+ hasattr(self.text_encoder.config, "use_attention_mask")
1308
+ and self.text_encoder.config.use_attention_mask
1309
+ ):
1310
+ attention_mask = uncond_input.attention_mask.to(device)
1311
+ else:
1312
+ attention_mask = None
1313
+
1314
+ negative_prompt_embeds = self.text_encoder(
1315
+ uncond_input.input_ids.to(device),
1316
+ attention_mask=attention_mask,
1317
+ )
1318
+ negative_prompt_embeds = negative_prompt_embeds[0]
1319
+
1320
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
1321
+ seq_len = negative_prompt_embeds.shape[1]
1322
+
1323
+ negative_prompt_embeds = negative_prompt_embeds.to(
1324
+ dtype=self.text_encoder.dtype, device=device
1325
+ )
1326
+
1327
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
1328
+ 1, num_images_per_prompt, 1
1329
+ )
1330
+ negative_prompt_embeds = negative_prompt_embeds.view(
1331
+ batch_size * num_images_per_prompt, seq_len, -1
1332
+ )
1333
+
1334
+ # For classifier free guidance, we need to do two forward passes.
1335
+ # Here we concatenate the unconditional and text embeddings into a single batch
1336
+ # to avoid doing two forward passes
1337
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1338
+
1339
+ return prompt_embeds
1340
+
1341
+ def decode_latents(self, latents):
1342
+ latents = 1 / self.vae.config.scaling_factor * latents
1343
+ image = self.vae.decode(latents).sample
1344
+ image = (image / 2 + 0.5).clamp(0, 1)
1345
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1346
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
1347
+ return image
1348
+
1349
+ def prepare_extra_step_kwargs(self, generator, eta):
1350
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
1351
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
1352
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
1353
+ # and should be between [0, 1]
1354
+
1355
+ accepts_eta = "eta" in set(
1356
+ inspect.signature(self.scheduler.step).parameters.keys()
1357
+ )
1358
+ extra_step_kwargs = {}
1359
+ if accepts_eta:
1360
+ extra_step_kwargs["eta"] = eta
1361
+
1362
+ # check if the scheduler accepts generator
1363
+ accepts_generator = "generator" in set(
1364
+ inspect.signature(self.scheduler.step).parameters.keys()
1365
+ )
1366
+ if accepts_generator:
1367
+ extra_step_kwargs["generator"] = generator
1368
+ return extra_step_kwargs
1369
+
1370
+ def prepare_latents(
1371
+ self,
1372
+ batch_size,
1373
+ num_channels_latents,
1374
+ height,
1375
+ width,
1376
+ dtype,
1377
+ device,
1378
+ generator,
1379
+ latents=None,
1380
+ ):
1381
+ shape = (
1382
+ batch_size,
1383
+ num_channels_latents,
1384
+ height // self.vae_scale_factor,
1385
+ width // self.vae_scale_factor,
1386
+ )
1387
+ if isinstance(generator, list) and len(generator) != batch_size:
1388
+ raise ValueError(
1389
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
1390
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
1391
+ )
1392
+
1393
+ if latents is None:
1394
+ latents = randn_tensor(
1395
+ shape, generator=generator, device=device, dtype=dtype
1396
+ )
1397
+ else:
1398
+ latents = latents.to(device)
1399
+
1400
+ # scale the initial noise by the standard deviation required by the scheduler
1401
+ latents = latents * self.scheduler.init_noise_sigma
1402
+ return latents
1403
+
1404
+ def encode_image(self, image, device, num_images_per_prompt):
1405
+ dtype = next(self.image_encoder.parameters()).dtype
1406
+
1407
+ if image.dtype == np.float32:
1408
+ image = (image * 255).astype(np.uint8)
1409
+
1410
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
1411
+ image = image.to(device=device, dtype=dtype)
1412
+
1413
+ image_embeds = self.image_encoder(
1414
+ image, output_hidden_states=True
1415
+ ).hidden_states[-2]
1416
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
1417
+
1418
+ return torch.zeros_like(image_embeds), image_embeds
1419
+
1420
+ def encode_image_latents(self, image, device, num_images_per_prompt):
1421
+
1422
+ dtype = next(self.image_encoder.parameters()).dtype
1423
+
1424
+ image = (
1425
+ torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
1426
+ ) # [1, 3, H, W]
1427
+ image = 2 * image - 1
1428
+ image = F.interpolate(image, (256, 256), mode="bilinear", align_corners=False)
1429
+ image = image.to(dtype=dtype)
1430
+
1431
+ posterior = self.vae.encode(image).latent_dist
1432
+ latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
1433
+ latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
1434
+
1435
+ return torch.zeros_like(latents), latents
1436
+
1437
+ @torch.no_grad()
1438
+ def __call__(
1439
+ self,
1440
+ prompt: str = "",
1441
+ image: Optional[np.ndarray] = None,
1442
+ height: int = 256,
1443
+ width: int = 256,
1444
+ elevation: float = 0,
1445
+ num_inference_steps: int = 50,
1446
+ guidance_scale: float = 7.0,
1447
+ negative_prompt: str = "",
1448
+ num_images_per_prompt: int = 1,
1449
+ eta: float = 0.0,
1450
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1451
+ output_type: Optional[str] = "numpy", # pil, numpy, latents
1452
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1453
+ callback_steps: int = 1,
1454
+ num_frames: int = 4,
1455
+ device=torch.device("cuda:0"),
1456
+ ):
1457
+ self.unet = self.unet.to(device=device)
1458
+ self.vae = self.vae.to(device=device)
1459
+ self.text_encoder = self.text_encoder.to(device=device)
1460
+
1461
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1462
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1463
+ # corresponds to doing no classifier free guidance.
1464
+ do_classifier_free_guidance = guidance_scale > 1.0
1465
+
1466
+ # Prepare timesteps
1467
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1468
+ timesteps = self.scheduler.timesteps
1469
+
1470
+ # imagedream variant
1471
+ if image is not None:
1472
+ assert isinstance(image, np.ndarray) and image.dtype == np.float32
1473
+ self.image_encoder = self.image_encoder.to(device=device)
1474
+ image_embeds_neg, image_embeds_pos = self.encode_image(
1475
+ image, device, num_images_per_prompt
1476
+ )
1477
+ image_latents_neg, image_latents_pos = self.encode_image_latents(
1478
+ image, device, num_images_per_prompt
1479
+ )
1480
+
1481
+ _prompt_embeds = self._encode_prompt(
1482
+ prompt=prompt,
1483
+ device=device,
1484
+ num_images_per_prompt=num_images_per_prompt,
1485
+ do_classifier_free_guidance=do_classifier_free_guidance,
1486
+ negative_prompt=negative_prompt,
1487
+ ) # type: ignore
1488
+ prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
1489
+
1490
+ # Prepare latent variables
1491
+ actual_num_frames = num_frames if image is None else num_frames + 1
1492
+ latents: torch.Tensor = self.prepare_latents(
1493
+ actual_num_frames * num_images_per_prompt,
1494
+ 4,
1495
+ height,
1496
+ width,
1497
+ prompt_embeds_pos.dtype,
1498
+ device,
1499
+ generator,
1500
+ None,
1501
+ )
1502
+
1503
+ # Get camera
1504
+ camera = get_camera(
1505
+ num_frames, elevation=elevation, extra_view=(image is not None)
1506
+ ).to(dtype=latents.dtype, device=device)
1507
+ camera = camera.repeat_interleave(num_images_per_prompt, dim=0)
1508
+
1509
+ # Prepare extra step kwargs.
1510
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1511
+
1512
+ # Denoising loop
1513
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1514
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1515
+ for i, t in enumerate(timesteps):
1516
+ # expand the latents if we are doing classifier free guidance
1517
+ multiplier = 2 if do_classifier_free_guidance else 1
1518
+ latent_model_input = torch.cat([latents] * multiplier)
1519
+ latent_model_input = self.scheduler.scale_model_input(
1520
+ latent_model_input, t
1521
+ )
1522
+
1523
+ unet_inputs = {
1524
+ "x": latent_model_input,
1525
+ "timesteps": torch.tensor(
1526
+ [t] * actual_num_frames * multiplier,
1527
+ dtype=latent_model_input.dtype,
1528
+ device=device,
1529
+ ),
1530
+ "context": torch.cat(
1531
+ [prompt_embeds_neg] * actual_num_frames
1532
+ + [prompt_embeds_pos] * actual_num_frames
1533
+ ),
1534
+ "num_frames": actual_num_frames,
1535
+ "camera": torch.cat([camera] * multiplier),
1536
+ }
1537
+
1538
+ if image is not None:
1539
+ unet_inputs["ip"] = torch.cat(
1540
+ [image_embeds_neg] * actual_num_frames
1541
+ + [image_embeds_pos] * actual_num_frames
1542
+ )
1543
+ unet_inputs["ip_img"] = torch.cat(
1544
+ [image_latents_neg] + [image_latents_pos]
1545
+ ) # no repeat
1546
+
1547
+ # predict the noise residual
1548
+ noise_pred = self.unet.forward(**unet_inputs)
1549
+
1550
+ # perform guidance
1551
+ if do_classifier_free_guidance:
1552
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1553
+ noise_pred = noise_pred_uncond + guidance_scale * (
1554
+ noise_pred_text - noise_pred_uncond
1555
+ )
1556
+
1557
+ # compute the previous noisy sample x_t -> x_t-1
1558
+ latents: torch.Tensor = self.scheduler.step(
1559
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1560
+ )[0]
1561
+
1562
+ # call the callback, if provided
1563
+ if i == len(timesteps) - 1 or (
1564
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1565
+ ):
1566
+ progress_bar.update()
1567
+ if callback is not None and i % callback_steps == 0:
1568
+ callback(i, t, latents) # type: ignore
1569
+
1570
+ # Post-processing
1571
+ if output_type == "latent":
1572
+ image = latents
1573
+ elif output_type == "pil":
1574
+ image = self.decode_latents(latents)
1575
+ image = self.numpy_to_pil(image)
1576
+ else: # numpy
1577
+ image = self.decode_latents(latents)
1578
+
1579
+ # Offload last model to CPU
1580
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1581
+ self.final_offload_hook.offload()
1582
+
1583
+ return image
requirements.txt ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.28.0
2
+ certifi==2024.2.2
3
+ charset-normalizer==3.3.2
4
+ diffusers==0.27.2
5
+ einops==0.7.0
6
+ executing==2.0.1
7
+ filelock==3.13.3
8
+ fsspec==2024.3.1
9
+ huggingface-hub==0.22.2
10
+ idna==3.6
11
+ importlib_metadata==7.1.0
12
+ Jinja2==3.1.3
13
+ kiui==0.2.7
14
+ lazy_loader==0.3
15
+ markdown-it-py==3.0.0
16
+ MarkupSafe==2.1.5
17
+ mdurl==0.1.2
18
+ mpmath==1.3.0
19
+ networkx==3.2.1
20
+ numpy==1.26.4
21
+ nvidia-cublas-cu12==12.1.3.1
22
+ nvidia-cuda-cupti-cu12==12.1.105
23
+ nvidia-cuda-nvrtc-cu12==12.1.105
24
+ nvidia-cuda-runtime-cu12==12.1.105
25
+ nvidia-cudnn-cu12==8.9.2.26
26
+ nvidia-cufft-cu12==11.0.2.54
27
+ nvidia-curand-cu12==10.3.2.106
28
+ nvidia-cusolver-cu12==11.4.5.107
29
+ nvidia-cusparse-cu12==12.1.0.106
30
+ nvidia-nccl-cu12==2.19.3
31
+ nvidia-nvjitlink-cu12==12.4.99
32
+ nvidia-nvtx-cu12==12.1.105
33
+ objprint==0.2.3
34
+ opencv-python==4.9.0.80
35
+ packaging==24.0
36
+ pillow==10.2.0
37
+ psutil==5.9.8
38
+ Pygments==2.17.2
39
+ PyYAML==6.0.1
40
+ regex==2023.12.25
41
+ requests==2.31.0
42
+ rich==13.7.1
43
+ safetensors==0.4.2
44
+ scipy==1.12.0
45
+ sympy==1.12
46
+ tokenizers==0.15.2
47
+ torch==2.2.2
48
+ tqdm==4.66.2
49
+ transformers==4.39.2
50
+ triton==2.2.0
51
+ typing_extensions==4.10.0
52
+ urllib3==2.2.1
53
+ varname==0.13.0
54
+ xformers==0.0.25.post1
55
+ zipp==3.18.1
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.25.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "steps_offset": 1,
16
+ "thresholding": false,
17
+ "timestep_spacing": "leading",
18
+ "trained_betas": null
19
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "stabilityai/stable-diffusion-2-1",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_size": 1024,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 23,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 512,
22
+ "torch_dtype": "float16",
23
+ "transformers_version": "4.35.2",
24
+ "vocab_size": 49408
25
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc1827c465450322616f06dea41596eac7d493f4e95904dcb51f0fc745c4e13f
3
+ size 680820392
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "!",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "!",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49406": {
13
+ "content": "<|startoftext|>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "49407": {
21
+ "content": "<|endoftext|>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ }
28
+ },
29
+ "bos_token": "<|startoftext|>",
30
+ "clean_up_tokenization_spaces": true,
31
+ "do_lower_case": true,
32
+ "eos_token": "<|endoftext|>",
33
+ "errors": "replace",
34
+ "model_max_length": 77,
35
+ "pad_token": "!",
36
+ "tokenizer_class": "CLIPTokenizer",
37
+ "unk_token": "<|endoftext|>"
38
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MultiViewUNetModel",
3
+ "_diffusers_version": "0.25.0",
4
+ "attention_resolutions": [
5
+ 4,
6
+ 2,
7
+ 1
8
+ ],
9
+ "camera_dim": 16,
10
+ "channel_mult": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 4
15
+ ],
16
+ "context_dim": 1024,
17
+ "image_size": 32,
18
+ "in_channels": 4,
19
+ "ip_dim": 16,
20
+ "model_channels": 320,
21
+ "num_head_channels": 64,
22
+ "num_res_blocks": 2,
23
+ "out_channels": 4,
24
+ "transformer_depth": 1
25
+ }
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28d8b241a54125fa0a041c1818a5dcdb717e6f5270eea1268172acd3ab0238e0
3
+ size 1883435904
unet/mv_unet.py ADDED
@@ -0,0 +1,1005 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ from inspect import isfunction
4
+ from typing import Optional, Any, List
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from diffusers.configuration_utils import ConfigMixin
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+
14
+ # require xformers!
15
+ import xformers
16
+ import xformers.ops
17
+
18
+ from kiui.cam import orbit_camera
19
+
20
+ def get_camera(
21
+ num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
22
+ ):
23
+ angle_gap = azimuth_span / num_frames
24
+ cameras = []
25
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
26
+
27
+ pose = orbit_camera(-elevation, azimuth, radius=1) # kiui's elevation is negated, [4, 4]
28
+
29
+ # opengl to blender
30
+ if blender_coord:
31
+ pose[2] *= -1
32
+ pose[[1, 2]] = pose[[2, 1]]
33
+
34
+ cameras.append(pose.flatten())
35
+
36
+ if extra_view:
37
+ cameras.append(np.zeros_like(cameras[0]))
38
+
39
+ return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
40
+
41
+
42
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
43
+ """
44
+ Create sinusoidal timestep embeddings.
45
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
46
+ These may be fractional.
47
+ :param dim: the dimension of the output.
48
+ :param max_period: controls the minimum frequency of the embeddings.
49
+ :return: an [N x dim] Tensor of positional embeddings.
50
+ """
51
+ if not repeat_only:
52
+ half = dim // 2
53
+ freqs = torch.exp(
54
+ -math.log(max_period)
55
+ * torch.arange(start=0, end=half, dtype=torch.float32)
56
+ / half
57
+ ).to(device=timesteps.device)
58
+ args = timesteps[:, None] * freqs[None]
59
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
60
+ if dim % 2:
61
+ embedding = torch.cat(
62
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
63
+ )
64
+ else:
65
+ embedding = repeat(timesteps, "b -> b d", d=dim)
66
+ # import pdb; pdb.set_trace()
67
+ return embedding
68
+
69
+
70
+ def zero_module(module):
71
+ """
72
+ Zero out the parameters of a module and return it.
73
+ """
74
+ for p in module.parameters():
75
+ p.detach().zero_()
76
+ return module
77
+
78
+
79
+ def conv_nd(dims, *args, **kwargs):
80
+ """
81
+ Create a 1D, 2D, or 3D convolution module.
82
+ """
83
+ if dims == 1:
84
+ return nn.Conv1d(*args, **kwargs)
85
+ elif dims == 2:
86
+ return nn.Conv2d(*args, **kwargs)
87
+ elif dims == 3:
88
+ return nn.Conv3d(*args, **kwargs)
89
+ raise ValueError(f"unsupported dimensions: {dims}")
90
+
91
+
92
+ def avg_pool_nd(dims, *args, **kwargs):
93
+ """
94
+ Create a 1D, 2D, or 3D average pooling module.
95
+ """
96
+ if dims == 1:
97
+ return nn.AvgPool1d(*args, **kwargs)
98
+ elif dims == 2:
99
+ return nn.AvgPool2d(*args, **kwargs)
100
+ elif dims == 3:
101
+ return nn.AvgPool3d(*args, **kwargs)
102
+ raise ValueError(f"unsupported dimensions: {dims}")
103
+
104
+
105
+ def default(val, d):
106
+ if val is not None:
107
+ return val
108
+ return d() if isfunction(d) else d
109
+
110
+
111
+ class GEGLU(nn.Module):
112
+ def __init__(self, dim_in, dim_out):
113
+ super().__init__()
114
+ self.proj = nn.Linear(dim_in, dim_out * 2)
115
+
116
+ def forward(self, x):
117
+ x, gate = self.proj(x).chunk(2, dim=-1)
118
+ return x * F.gelu(gate)
119
+
120
+
121
+ class FeedForward(nn.Module):
122
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
123
+ super().__init__()
124
+ inner_dim = int(dim * mult)
125
+ dim_out = default(dim_out, dim)
126
+ project_in = (
127
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
128
+ if not glu
129
+ else GEGLU(dim, inner_dim)
130
+ )
131
+
132
+ self.net = nn.Sequential(
133
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
134
+ )
135
+
136
+ def forward(self, x):
137
+ return self.net(x)
138
+
139
+
140
+ class MemoryEfficientCrossAttention(nn.Module):
141
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
142
+ def __init__(
143
+ self,
144
+ query_dim,
145
+ context_dim=None,
146
+ heads=8,
147
+ dim_head=64,
148
+ dropout=0.0,
149
+ ip_dim=0,
150
+ ip_weight=1,
151
+ ):
152
+ super().__init__()
153
+
154
+ inner_dim = dim_head * heads
155
+ context_dim = default(context_dim, query_dim)
156
+
157
+ self.heads = heads
158
+ self.dim_head = dim_head
159
+
160
+ self.ip_dim = ip_dim
161
+ self.ip_weight = ip_weight
162
+
163
+ if self.ip_dim > 0:
164
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
165
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
166
+
167
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
168
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
169
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
170
+
171
+ self.to_out = nn.Sequential(
172
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
173
+ )
174
+ self.attention_op: Optional[Any] = None
175
+
176
+ def forward(self, x, context=None):
177
+ q = self.to_q(x)
178
+ context = default(context, x)
179
+
180
+ if self.ip_dim > 0:
181
+ # context: [B, 77 + 16(ip), 1024]
182
+ token_len = context.shape[1]
183
+ context_ip = context[:, -self.ip_dim :, :]
184
+ k_ip = self.to_k_ip(context_ip)
185
+ v_ip = self.to_v_ip(context_ip)
186
+ context = context[:, : (token_len - self.ip_dim), :]
187
+
188
+ k = self.to_k(context)
189
+ v = self.to_v(context)
190
+
191
+ b, _, _ = q.shape
192
+ q, k, v = map(
193
+ lambda t: t.unsqueeze(3)
194
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
195
+ .permute(0, 2, 1, 3)
196
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
197
+ .contiguous(),
198
+ (q, k, v),
199
+ )
200
+
201
+ # actually compute the attention, what we cannot get enough of
202
+ out = xformers.ops.memory_efficient_attention(
203
+ q, k, v, attn_bias=None, op=self.attention_op
204
+ )
205
+
206
+ if self.ip_dim > 0:
207
+ k_ip, v_ip = map(
208
+ lambda t: t.unsqueeze(3)
209
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
210
+ .permute(0, 2, 1, 3)
211
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
212
+ .contiguous(),
213
+ (k_ip, v_ip),
214
+ )
215
+ # actually compute the attention, what we cannot get enough of
216
+ out_ip = xformers.ops.memory_efficient_attention(
217
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
218
+ )
219
+ out = out + self.ip_weight * out_ip
220
+
221
+ out = (
222
+ out.unsqueeze(0)
223
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
224
+ .permute(0, 2, 1, 3)
225
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
226
+ )
227
+ return self.to_out(out)
228
+
229
+
230
+ class BasicTransformerBlock3D(nn.Module):
231
+
232
+ def __init__(
233
+ self,
234
+ dim,
235
+ n_heads,
236
+ d_head,
237
+ context_dim,
238
+ dropout=0.0,
239
+ gated_ff=True,
240
+ ip_dim=0,
241
+ ip_weight=1,
242
+ ):
243
+ super().__init__()
244
+
245
+ self.attn1 = MemoryEfficientCrossAttention(
246
+ query_dim=dim,
247
+ context_dim=None, # self-attention
248
+ heads=n_heads,
249
+ dim_head=d_head,
250
+ dropout=dropout,
251
+ )
252
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
253
+ self.attn2 = MemoryEfficientCrossAttention(
254
+ query_dim=dim,
255
+ context_dim=context_dim,
256
+ heads=n_heads,
257
+ dim_head=d_head,
258
+ dropout=dropout,
259
+ # ip only applies to cross-attention
260
+ ip_dim=ip_dim,
261
+ ip_weight=ip_weight,
262
+ )
263
+ self.norm1 = nn.LayerNorm(dim)
264
+ self.norm2 = nn.LayerNorm(dim)
265
+ self.norm3 = nn.LayerNorm(dim)
266
+
267
+ def forward(self, x, context=None, num_frames=1):
268
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
269
+ x = self.attn1(self.norm1(x), context=None) + x
270
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
271
+ x = self.attn2(self.norm2(x), context=context) + x
272
+ x = self.ff(self.norm3(x)) + x
273
+ return x
274
+
275
+
276
+ class SpatialTransformer3D(nn.Module):
277
+
278
+ def __init__(
279
+ self,
280
+ in_channels,
281
+ n_heads,
282
+ d_head,
283
+ context_dim, # cross attention input dim
284
+ depth=1,
285
+ dropout=0.0,
286
+ ip_dim=0,
287
+ ip_weight=1,
288
+ ):
289
+ super().__init__()
290
+
291
+ if not isinstance(context_dim, list):
292
+ context_dim = [context_dim]
293
+
294
+ self.in_channels = in_channels
295
+
296
+ inner_dim = n_heads * d_head
297
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
298
+ self.proj_in = nn.Linear(in_channels, inner_dim)
299
+
300
+ self.transformer_blocks = nn.ModuleList(
301
+ [
302
+ BasicTransformerBlock3D(
303
+ inner_dim,
304
+ n_heads,
305
+ d_head,
306
+ context_dim=context_dim[d],
307
+ dropout=dropout,
308
+ ip_dim=ip_dim,
309
+ ip_weight=ip_weight,
310
+ )
311
+ for d in range(depth)
312
+ ]
313
+ )
314
+
315
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
316
+
317
+
318
+ def forward(self, x, context=None, num_frames=1):
319
+ # note: if no context is given, cross-attention defaults to self-attention
320
+ if not isinstance(context, list):
321
+ context = [context]
322
+ b, c, h, w = x.shape
323
+ x_in = x
324
+ x = self.norm(x)
325
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
326
+ x = self.proj_in(x)
327
+ for i, block in enumerate(self.transformer_blocks):
328
+ x = block(x, context=context[i], num_frames=num_frames)
329
+ x = self.proj_out(x)
330
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
331
+
332
+ return x + x_in
333
+
334
+
335
+ class PerceiverAttention(nn.Module):
336
+ def __init__(self, *, dim, dim_head=64, heads=8):
337
+ super().__init__()
338
+ self.scale = dim_head ** -0.5
339
+ self.dim_head = dim_head
340
+ self.heads = heads
341
+ inner_dim = dim_head * heads
342
+
343
+ self.norm1 = nn.LayerNorm(dim)
344
+ self.norm2 = nn.LayerNorm(dim)
345
+
346
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
347
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
348
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
349
+
350
+ def forward(self, x, latents):
351
+ """
352
+ Args:
353
+ x (torch.Tensor): image features
354
+ shape (b, n1, D)
355
+ latent (torch.Tensor): latent features
356
+ shape (b, n2, D)
357
+ """
358
+ x = self.norm1(x)
359
+ latents = self.norm2(latents)
360
+
361
+ b, l, _ = latents.shape
362
+
363
+ q = self.to_q(latents)
364
+ kv_input = torch.cat((x, latents), dim=-2)
365
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
366
+
367
+ q, k, v = map(
368
+ lambda t: t.reshape(b, t.shape[1], self.heads, -1)
369
+ .transpose(1, 2)
370
+ .reshape(b, self.heads, t.shape[1], -1)
371
+ .contiguous(),
372
+ (q, k, v),
373
+ )
374
+
375
+ # attention
376
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
377
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
378
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
379
+ out = weight @ v
380
+
381
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
382
+
383
+ return self.to_out(out)
384
+
385
+
386
+ class Resampler(nn.Module):
387
+ def __init__(
388
+ self,
389
+ dim=1024,
390
+ depth=8,
391
+ dim_head=64,
392
+ heads=16,
393
+ num_queries=8,
394
+ embedding_dim=768,
395
+ output_dim=1024,
396
+ ff_mult=4,
397
+ ):
398
+ super().__init__()
399
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
400
+ self.proj_in = nn.Linear(embedding_dim, dim)
401
+ self.proj_out = nn.Linear(dim, output_dim)
402
+ self.norm_out = nn.LayerNorm(output_dim)
403
+
404
+ self.layers = nn.ModuleList([])
405
+ for _ in range(depth):
406
+ self.layers.append(
407
+ nn.ModuleList(
408
+ [
409
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
410
+ nn.Sequential(
411
+ nn.LayerNorm(dim),
412
+ nn.Linear(dim, dim * ff_mult, bias=False),
413
+ nn.GELU(),
414
+ nn.Linear(dim * ff_mult, dim, bias=False),
415
+ )
416
+ ]
417
+ )
418
+ )
419
+
420
+ def forward(self, x):
421
+ latents = self.latents.repeat(x.size(0), 1, 1)
422
+ x = self.proj_in(x)
423
+ for attn, ff in self.layers:
424
+ latents = attn(x, latents) + latents
425
+ latents = ff(latents) + latents
426
+
427
+ latents = self.proj_out(latents)
428
+ return self.norm_out(latents)
429
+
430
+
431
+ class CondSequential(nn.Sequential):
432
+ """
433
+ A sequential module that passes timestep embeddings to the children that
434
+ support it as an extra input.
435
+ """
436
+
437
+ def forward(self, x, emb, context=None, num_frames=1):
438
+ for layer in self:
439
+ if isinstance(layer, ResBlock):
440
+ x = layer(x, emb)
441
+ elif isinstance(layer, SpatialTransformer3D):
442
+ x = layer(x, context, num_frames=num_frames)
443
+ else:
444
+ x = layer(x)
445
+ return x
446
+
447
+
448
+ class Upsample(nn.Module):
449
+ """
450
+ An upsampling layer with an optional convolution.
451
+ :param channels: channels in the inputs and outputs.
452
+ :param use_conv: a bool determining if a convolution is applied.
453
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
454
+ upsampling occurs in the inner-two dimensions.
455
+ """
456
+
457
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
458
+ super().__init__()
459
+ self.channels = channels
460
+ self.out_channels = out_channels or channels
461
+ self.use_conv = use_conv
462
+ self.dims = dims
463
+ if use_conv:
464
+ self.conv = conv_nd(
465
+ dims, self.channels, self.out_channels, 3, padding=padding
466
+ )
467
+
468
+ def forward(self, x):
469
+ assert x.shape[1] == self.channels
470
+ if self.dims == 3:
471
+ x = F.interpolate(
472
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
473
+ )
474
+ else:
475
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
476
+ if self.use_conv:
477
+ x = self.conv(x)
478
+ return x
479
+
480
+
481
+ class Downsample(nn.Module):
482
+ """
483
+ A downsampling layer with an optional convolution.
484
+ :param channels: channels in the inputs and outputs.
485
+ :param use_conv: a bool determining if a convolution is applied.
486
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
487
+ downsampling occurs in the inner-two dimensions.
488
+ """
489
+
490
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
491
+ super().__init__()
492
+ self.channels = channels
493
+ self.out_channels = out_channels or channels
494
+ self.use_conv = use_conv
495
+ self.dims = dims
496
+ stride = 2 if dims != 3 else (1, 2, 2)
497
+ if use_conv:
498
+ self.op = conv_nd(
499
+ dims,
500
+ self.channels,
501
+ self.out_channels,
502
+ 3,
503
+ stride=stride,
504
+ padding=padding,
505
+ )
506
+ else:
507
+ assert self.channels == self.out_channels
508
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
509
+
510
+ def forward(self, x):
511
+ assert x.shape[1] == self.channels
512
+ return self.op(x)
513
+
514
+
515
+ class ResBlock(nn.Module):
516
+ """
517
+ A residual block that can optionally change the number of channels.
518
+ :param channels: the number of input channels.
519
+ :param emb_channels: the number of timestep embedding channels.
520
+ :param dropout: the rate of dropout.
521
+ :param out_channels: if specified, the number of out channels.
522
+ :param use_conv: if True and out_channels is specified, use a spatial
523
+ convolution instead of a smaller 1x1 convolution to change the
524
+ channels in the skip connection.
525
+ :param dims: determines if the signal is 1D, 2D, or 3D.
526
+ :param up: if True, use this block for upsampling.
527
+ :param down: if True, use this block for downsampling.
528
+ """
529
+
530
+ def __init__(
531
+ self,
532
+ channels,
533
+ emb_channels,
534
+ dropout,
535
+ out_channels=None,
536
+ use_conv=False,
537
+ use_scale_shift_norm=False,
538
+ dims=2,
539
+ up=False,
540
+ down=False,
541
+ ):
542
+ super().__init__()
543
+ self.channels = channels
544
+ self.emb_channels = emb_channels
545
+ self.dropout = dropout
546
+ self.out_channels = out_channels or channels
547
+ self.use_conv = use_conv
548
+ self.use_scale_shift_norm = use_scale_shift_norm
549
+
550
+ self.in_layers = nn.Sequential(
551
+ nn.GroupNorm(32, channels),
552
+ nn.SiLU(),
553
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
554
+ )
555
+
556
+ self.updown = up or down
557
+
558
+ if up:
559
+ self.h_upd = Upsample(channels, False, dims)
560
+ self.x_upd = Upsample(channels, False, dims)
561
+ elif down:
562
+ self.h_upd = Downsample(channels, False, dims)
563
+ self.x_upd = Downsample(channels, False, dims)
564
+ else:
565
+ self.h_upd = self.x_upd = nn.Identity()
566
+
567
+ self.emb_layers = nn.Sequential(
568
+ nn.SiLU(),
569
+ nn.Linear(
570
+ emb_channels,
571
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
572
+ ),
573
+ )
574
+ self.out_layers = nn.Sequential(
575
+ nn.GroupNorm(32, self.out_channels),
576
+ nn.SiLU(),
577
+ nn.Dropout(p=dropout),
578
+ zero_module(
579
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
580
+ ),
581
+ )
582
+
583
+ if self.out_channels == channels:
584
+ self.skip_connection = nn.Identity()
585
+ elif use_conv:
586
+ self.skip_connection = conv_nd(
587
+ dims, channels, self.out_channels, 3, padding=1
588
+ )
589
+ else:
590
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
591
+
592
+ def forward(self, x, emb):
593
+ if self.updown:
594
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
595
+ h = in_rest(x)
596
+ h = self.h_upd(h)
597
+ x = self.x_upd(x)
598
+ h = in_conv(h)
599
+ else:
600
+ h = self.in_layers(x)
601
+ emb_out = self.emb_layers(emb).type(h.dtype)
602
+ while len(emb_out.shape) < len(h.shape):
603
+ emb_out = emb_out[..., None]
604
+ if self.use_scale_shift_norm:
605
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
606
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
607
+ h = out_norm(h) * (1 + scale) + shift
608
+ h = out_rest(h)
609
+ else:
610
+ h = h + emb_out
611
+ h = self.out_layers(h)
612
+ return self.skip_connection(x) + h
613
+
614
+
615
+ class MultiViewUNetModel(ModelMixin, ConfigMixin):
616
+ """
617
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
618
+ :param in_channels: channels in the input Tensor.
619
+ :param model_channels: base channel count for the model.
620
+ :param out_channels: channels in the output Tensor.
621
+ :param num_res_blocks: number of residual blocks per downsample.
622
+ :param attention_resolutions: a collection of downsample rates at which
623
+ attention will take place. May be a set, list, or tuple.
624
+ For example, if this contains 4, then at 4x downsampling, attention
625
+ will be used.
626
+ :param dropout: the dropout probability.
627
+ :param channel_mult: channel multiplier for each level of the UNet.
628
+ :param conv_resample: if True, use learned convolutions for upsampling and
629
+ downsampling.
630
+ :param dims: determines if the signal is 1D, 2D, or 3D.
631
+ :param num_classes: if specified (as an int), then this model will be
632
+ class-conditional with `num_classes` classes.
633
+ :param num_heads: the number of attention heads in each attention layer.
634
+ :param num_heads_channels: if specified, ignore num_heads and instead use
635
+ a fixed channel width per attention head.
636
+ :param num_heads_upsample: works with num_heads to set a different number
637
+ of heads for upsampling. Deprecated.
638
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
639
+ :param resblock_updown: use residual blocks for up/downsampling.
640
+ :param use_new_attention_order: use a different attention pattern for potentially
641
+ increased efficiency.
642
+ :param camera_dim: dimensionality of camera input.
643
+ """
644
+
645
+ def __init__(
646
+ self,
647
+ image_size,
648
+ in_channels,
649
+ model_channels,
650
+ out_channels,
651
+ num_res_blocks,
652
+ attention_resolutions,
653
+ dropout=0,
654
+ channel_mult=(1, 2, 4, 8),
655
+ conv_resample=True,
656
+ dims=2,
657
+ num_classes=None,
658
+ num_heads=-1,
659
+ num_head_channels=-1,
660
+ num_heads_upsample=-1,
661
+ use_scale_shift_norm=False,
662
+ resblock_updown=False,
663
+ transformer_depth=1,
664
+ context_dim=None,
665
+ n_embed=None,
666
+ num_attention_blocks=None,
667
+ adm_in_channels=None,
668
+ camera_dim=None,
669
+ ip_dim=0, # imagedream uses ip_dim > 0
670
+ ip_weight=1.0,
671
+ **kwargs,
672
+ ):
673
+ super().__init__()
674
+ assert context_dim is not None
675
+
676
+ if num_heads_upsample == -1:
677
+ num_heads_upsample = num_heads
678
+
679
+ if num_heads == -1:
680
+ assert (
681
+ num_head_channels != -1
682
+ ), "Either num_heads or num_head_channels has to be set"
683
+
684
+ if num_head_channels == -1:
685
+ assert (
686
+ num_heads != -1
687
+ ), "Either num_heads or num_head_channels has to be set"
688
+
689
+ self.image_size = image_size
690
+ self.in_channels = in_channels
691
+ self.model_channels = model_channels
692
+ self.out_channels = out_channels
693
+ if isinstance(num_res_blocks, int):
694
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
695
+ else:
696
+ if len(num_res_blocks) != len(channel_mult):
697
+ raise ValueError(
698
+ "provide num_res_blocks either as an int (globally constant) or "
699
+ "as a list/tuple (per-level) with the same length as channel_mult"
700
+ )
701
+ self.num_res_blocks = num_res_blocks
702
+
703
+ if num_attention_blocks is not None:
704
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
705
+ assert all(
706
+ map(
707
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
708
+ range(len(num_attention_blocks)),
709
+ )
710
+ )
711
+ print(
712
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
713
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
714
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
715
+ f"attention will still not be set."
716
+ )
717
+
718
+ self.attention_resolutions = attention_resolutions
719
+ self.dropout = dropout
720
+ self.channel_mult = channel_mult
721
+ self.conv_resample = conv_resample
722
+ self.num_classes = num_classes
723
+ self.num_heads = num_heads
724
+ self.num_head_channels = num_head_channels
725
+ self.num_heads_upsample = num_heads_upsample
726
+ self.predict_codebook_ids = n_embed is not None
727
+
728
+ self.ip_dim = ip_dim
729
+ self.ip_weight = ip_weight
730
+
731
+ if self.ip_dim > 0:
732
+ self.image_embed = Resampler(
733
+ dim=context_dim,
734
+ depth=4,
735
+ dim_head=64,
736
+ heads=12,
737
+ num_queries=ip_dim, # num token
738
+ embedding_dim=1280,
739
+ output_dim=context_dim,
740
+ ff_mult=4,
741
+ )
742
+
743
+ time_embed_dim = model_channels * 4
744
+ self.time_embed = nn.Sequential(
745
+ nn.Linear(model_channels, time_embed_dim),
746
+ nn.SiLU(),
747
+ nn.Linear(time_embed_dim, time_embed_dim),
748
+ )
749
+
750
+ if camera_dim is not None:
751
+ time_embed_dim = model_channels * 4
752
+ self.camera_embed = nn.Sequential(
753
+ nn.Linear(camera_dim, time_embed_dim),
754
+ nn.SiLU(),
755
+ nn.Linear(time_embed_dim, time_embed_dim),
756
+ )
757
+
758
+ if self.num_classes is not None:
759
+ if isinstance(self.num_classes, int):
760
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
761
+ elif self.num_classes == "continuous":
762
+ # print("setting up linear c_adm embedding layer")
763
+ self.label_emb = nn.Linear(1, time_embed_dim)
764
+ elif self.num_classes == "sequential":
765
+ assert adm_in_channels is not None
766
+ self.label_emb = nn.Sequential(
767
+ nn.Sequential(
768
+ nn.Linear(adm_in_channels, time_embed_dim),
769
+ nn.SiLU(),
770
+ nn.Linear(time_embed_dim, time_embed_dim),
771
+ )
772
+ )
773
+ else:
774
+ raise ValueError()
775
+
776
+ self.input_blocks = nn.ModuleList(
777
+ [
778
+ CondSequential(
779
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
780
+ )
781
+ ]
782
+ )
783
+ self._feature_size = model_channels
784
+ input_block_chans = [model_channels]
785
+ ch = model_channels
786
+ ds = 1
787
+ for level, mult in enumerate(channel_mult):
788
+ for nr in range(self.num_res_blocks[level]):
789
+ layers: List[Any] = [
790
+ ResBlock(
791
+ ch,
792
+ time_embed_dim,
793
+ dropout,
794
+ out_channels=mult * model_channels,
795
+ dims=dims,
796
+ use_scale_shift_norm=use_scale_shift_norm,
797
+ )
798
+ ]
799
+ ch = mult * model_channels
800
+ if ds in attention_resolutions:
801
+ if num_head_channels == -1:
802
+ dim_head = ch // num_heads
803
+ else:
804
+ num_heads = ch // num_head_channels
805
+ dim_head = num_head_channels
806
+
807
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
808
+ layers.append(
809
+ SpatialTransformer3D(
810
+ ch,
811
+ num_heads,
812
+ dim_head,
813
+ context_dim=context_dim,
814
+ depth=transformer_depth,
815
+ ip_dim=self.ip_dim,
816
+ ip_weight=self.ip_weight,
817
+ )
818
+ )
819
+ self.input_blocks.append(CondSequential(*layers))
820
+ self._feature_size += ch
821
+ input_block_chans.append(ch)
822
+ if level != len(channel_mult) - 1:
823
+ out_ch = ch
824
+ self.input_blocks.append(
825
+ CondSequential(
826
+ ResBlock(
827
+ ch,
828
+ time_embed_dim,
829
+ dropout,
830
+ out_channels=out_ch,
831
+ dims=dims,
832
+ use_scale_shift_norm=use_scale_shift_norm,
833
+ down=True,
834
+ )
835
+ if resblock_updown
836
+ else Downsample(
837
+ ch, conv_resample, dims=dims, out_channels=out_ch
838
+ )
839
+ )
840
+ )
841
+ ch = out_ch
842
+ input_block_chans.append(ch)
843
+ ds *= 2
844
+ self._feature_size += ch
845
+
846
+ if num_head_channels == -1:
847
+ dim_head = ch // num_heads
848
+ else:
849
+ num_heads = ch // num_head_channels
850
+ dim_head = num_head_channels
851
+
852
+ self.middle_block = CondSequential(
853
+ ResBlock(
854
+ ch,
855
+ time_embed_dim,
856
+ dropout,
857
+ dims=dims,
858
+ use_scale_shift_norm=use_scale_shift_norm,
859
+ ),
860
+ SpatialTransformer3D(
861
+ ch,
862
+ num_heads,
863
+ dim_head,
864
+ context_dim=context_dim,
865
+ depth=transformer_depth,
866
+ ip_dim=self.ip_dim,
867
+ ip_weight=self.ip_weight,
868
+ ),
869
+ ResBlock(
870
+ ch,
871
+ time_embed_dim,
872
+ dropout,
873
+ dims=dims,
874
+ use_scale_shift_norm=use_scale_shift_norm,
875
+ ),
876
+ )
877
+ self._feature_size += ch
878
+
879
+ self.output_blocks = nn.ModuleList([])
880
+ for level, mult in list(enumerate(channel_mult))[::-1]:
881
+ for i in range(self.num_res_blocks[level] + 1):
882
+ ich = input_block_chans.pop()
883
+ layers = [
884
+ ResBlock(
885
+ ch + ich,
886
+ time_embed_dim,
887
+ dropout,
888
+ out_channels=model_channels * mult,
889
+ dims=dims,
890
+ use_scale_shift_norm=use_scale_shift_norm,
891
+ )
892
+ ]
893
+ ch = model_channels * mult
894
+ if ds in attention_resolutions:
895
+ if num_head_channels == -1:
896
+ dim_head = ch // num_heads
897
+ else:
898
+ num_heads = ch // num_head_channels
899
+ dim_head = num_head_channels
900
+
901
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
902
+ layers.append(
903
+ SpatialTransformer3D(
904
+ ch,
905
+ num_heads,
906
+ dim_head,
907
+ context_dim=context_dim,
908
+ depth=transformer_depth,
909
+ ip_dim=self.ip_dim,
910
+ ip_weight=self.ip_weight,
911
+ )
912
+ )
913
+ if level and i == self.num_res_blocks[level]:
914
+ out_ch = ch
915
+ layers.append(
916
+ ResBlock(
917
+ ch,
918
+ time_embed_dim,
919
+ dropout,
920
+ out_channels=out_ch,
921
+ dims=dims,
922
+ use_scale_shift_norm=use_scale_shift_norm,
923
+ up=True,
924
+ )
925
+ if resblock_updown
926
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
927
+ )
928
+ ds //= 2
929
+ self.output_blocks.append(CondSequential(*layers))
930
+ self._feature_size += ch
931
+
932
+ self.out = nn.Sequential(
933
+ nn.GroupNorm(32, ch),
934
+ nn.SiLU(),
935
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
936
+ )
937
+ if self.predict_codebook_ids:
938
+ self.id_predictor = nn.Sequential(
939
+ nn.GroupNorm(32, ch),
940
+ conv_nd(dims, model_channels, n_embed, 1),
941
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
942
+ )
943
+
944
+ def forward(
945
+ self,
946
+ x,
947
+ timesteps=None,
948
+ context=None,
949
+ y=None,
950
+ camera=None,
951
+ num_frames=1,
952
+ ip=None,
953
+ ip_img=None,
954
+ **kwargs,
955
+ ):
956
+ """
957
+ Apply the model to an input batch.
958
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
959
+ :param timesteps: a 1-D batch of timesteps.
960
+ :param context: conditioning plugged in via crossattn
961
+ :param y: an [N] Tensor of labels, if class-conditional.
962
+ :param num_frames: a integer indicating number of frames for tensor reshaping.
963
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
964
+ """
965
+ assert (
966
+ x.shape[0] % num_frames == 0
967
+ ), "input batch size must be dividable by num_frames!"
968
+ assert (y is not None) == (
969
+ self.num_classes is not None
970
+ ), "must specify y if and only if the model is class-conditional"
971
+
972
+ hs = []
973
+
974
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
975
+
976
+ emb = self.time_embed(t_emb)
977
+
978
+ if self.num_classes is not None:
979
+ assert y is not None
980
+ assert y.shape[0] == x.shape[0]
981
+ emb = emb + self.label_emb(y)
982
+
983
+ # Add camera embeddings
984
+ if camera is not None:
985
+ emb = emb + self.camera_embed(camera)
986
+
987
+ # imagedream variant
988
+ if self.ip_dim > 0:
989
+ x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
990
+ ip_emb = self.image_embed(ip)
991
+ context = torch.cat((context, ip_emb), 1)
992
+
993
+ h = x
994
+ for module in self.input_blocks:
995
+ h = module(h, emb, context, num_frames=num_frames)
996
+ hs.append(h)
997
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
998
+ for module in self.output_blocks:
999
+ h = torch.cat([h, hs.pop()], dim=1)
1000
+ h = module(h, emb, context, num_frames=num_frames)
1001
+ h = h.type(x.dtype)
1002
+ if self.predict_codebook_ids:
1003
+ return self.id_predictor(h)
1004
+ else:
1005
+ return self.out(h)
vae/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.25.0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 4,
20
+ "layers_per_block": 2,
21
+ "norm_num_groups": 32,
22
+ "out_channels": 3,
23
+ "sample_size": 256,
24
+ "scaling_factor": 0.18215,
25
+ "up_block_types": [
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D",
29
+ "UpDecoderBlock2D"
30
+ ]
31
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e4c08995484ee61270175e9e7a072b66a6e4eeb5f0c266667fe1f45b90daf9a
3
+ size 167335342