Ryukijano committed (verified)
Commit 2a9d94f · 1 Parent(s): 1ab78fa

Upload 38 files

Files changed (39)
  1. .gitattributes +1 -0
  2. README.md +12 -0
  3. app.py +411 -347
  4. configs/test_unclip-512-6view.yaml +56 -0
  5. examples/3968940-PH.png +0 -0
  6. examples/A_beautiful_cyborg_with_brown_hair_rgba.png +3 -0
  7. examples/A_bulldog_with_a_black_pirate_hat_rgba.png +0 -0
  8. examples/A_pig_wearing_a_backpack_rgba.png +0 -0
  9. examples/Elon-Musk.jpg +0 -0
  10. examples/Ghost_eating_burger_rgba.png +0 -0
  11. examples/cleanrot_armor_rgba.png +0 -0
  12. examples/cute_demon_combination_angel_figure_rgba.png +0 -0
  13. examples/dslr.png +0 -0
  14. examples/duola.png +0 -0
  15. examples/k2.png +0 -0
  16. examples/kind_cartoon_lion_in_costume_of_astronaut_rgba.png +0 -0
  17. examples/kunkun.png +0 -0
  18. examples/lantern.png +0 -0
  19. examples/lewd_statue_of_an_angel_texting_on_a_cell_phone_rgba.png +0 -0
  20. examples/monkey.png +0 -0
  21. examples/yann_kecun.jpg +0 -0
  22. mvdiffusion/data/dataset.py +138 -0
  23. mvdiffusion/data/dataset_nc.py +178 -0
  24. mvdiffusion/data/dreamdata.py +355 -0
  25. mvdiffusion/data/fixed_prompt_embeds_6view/clr_embeds.pt +3 -0
  26. mvdiffusion/data/fixed_prompt_embeds_6view/normal_embeds.pt +3 -0
  27. mvdiffusion/data/generate_fixed_text_embeds.py +78 -0
  28. mvdiffusion/data/normal_utils.py +78 -0
  29. mvdiffusion/data/single_image_dataset.py +249 -0
  30. mvdiffusion/models/transformer_mv2d_image.py +1029 -0
  31. mvdiffusion/models/transformer_mv2d_rowwise.py +978 -0
  32. mvdiffusion/models/transformer_mv2d_self_rowwise.py +1038 -0
  33. mvdiffusion/models/unet_mv2d_blocks.py +971 -0
  34. mvdiffusion/models/unet_mv2d_condition.py +1686 -0
  35. mvdiffusion/pipelines/pipeline_mvdiffusion_unclip.py +633 -0
  36. requirements.txt +35 -15
  37. sam_pt/sam_vit_h_4b8939.pth +3 -0
  38. utils/misc.py +54 -0
  39. utils/utils.py +27 -0
.gitattributes CHANGED
@@ -48,3 +48,4 @@ assets/basic/img4.png filter=lfs diff=lfs merge=lfs -text
48
  assets/basic/img5.png filter=lfs diff=lfs merge=lfs -text
49
  assets/basic/img6.png filter=lfs diff=lfs merge=lfs -text
50
  assets/basic/img7.png filter=lfs diff=lfs merge=lfs -text
51
+ examples/A_beautiful_cyborg_with_brown_hair_rgba.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,12 @@
1
+ ---
2
+ title: Era3D MV Demo
3
+ emoji: 🐠
4
+ colorFrom: purple
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.31.5
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,347 +1,411 @@
1
- import torch
2
- import torch.nn as nn
3
- import yaml
4
- import cv2
5
- import numpy as np
6
- from PIL import Image
7
- import gradio as gr
8
- from functools import partial
9
- import lib.Equirec2Perspec as E2P
10
- import lib.Perspec2Equirec as P2E
11
- import lib.multi_Perspec2Equirec as m_P2E
12
- import openai
13
- from model import Model
14
-
15
- def get_K_R(FOV, THETA, PHI, height, width):
16
- f = 0.5 * width * 1 / np.tan(0.5 * FOV / 180.0 * np.pi)
17
- cx = (width - 1) / 2.0
18
- cy = (height - 1) / 2.0
19
- K = np.array([
20
- [f, 0, cx],
21
- [0, f, cy],
22
- [0, 0, 1],
23
- ], np.float32)
24
-
25
- y_axis = np.array([0.0, 1.0, 0.0], np.float32)
26
- x_axis = np.array([1.0, 0.0, 0.0], np.float32)
27
- R1, _ = cv2.Rodrigues(y_axis * np.radians(THETA))
28
- R2, _ = cv2.Rodrigues(np.dot(R1, x_axis) * np.radians(PHI))
29
- R = R2 @ R1
30
- return K, R
31
-
32
-
33
- if __name__=='__main__':
34
- cfg_path='configs/train_mv.yaml'
35
- config = yaml.load(open(cfg_path, 'rb'), Loader=yaml.SafeLoader)
36
- config['height']=512
37
- config['width']=512
38
- config['length']=8
39
- config['model_path']='weights/last.ckpt'
40
-
41
- demo_model=Model(config)
42
- state_dict=torch.load(config['model_path'])['state_dict']
43
- demo_model.load_state_dict(state_dict, strict=False)
44
- demo_model=demo_model.cuda()
45
-
46
- batch=torch.load('batch.pth')
47
-
48
-
49
- example1=[
50
- "A room with a sofa and coffee table for relaxing.",
51
- "A corner sofa is surrounded by plants.",
52
- "A comfy sofa, bookshelf, and lamp for reading.",
53
- "A bright room with a sofa, TV, and games.",
54
- "A stylish sofa and desk setup for work.",
55
- "A sofa, dining table, and chairs for gatherings.",
56
- "A colorful sofa, art, and music fill the room.",
57
- "A sofa, yoga mat, and meditation corner for calm."
58
- ]
59
- example2=[
60
- "A room with a sofa and coffee table for relaxing, cartoon style",
61
- "A corner sofa is surrounded by plants, cartoon style",
62
- "A comfy sofa, bookshelf, and lamp for reading, cartoon style",
63
- "A bright room with a sofa, TV, and games, cartoon style",
64
- "A stylish sofa and desk setup for work, cartoon style",
65
- "A sofa, dining table, and chairs for gatherings, cartoon style",
66
- "A colorful sofa, art, and music fill the room, cartoon style",
67
- "A sofa, yoga mat, and meditation corner for calm, cartoon style"
68
- ]
69
-
70
- example3=[
71
- "A room with a sofa and coffee table for relaxing, oil painting style",
72
- "A corner sofa is surrounded by plants, oil painting style",
73
- "A comfy sofa, bookshelf, and lamp for reading, oil painting style",
74
- "A bright room with a sofa, TV, and games, oil painting style",
75
- "A stylish sofa and desk setup for work, oil painting style",
76
- "A sofa, dining table, and chairs for gatherings, oil painting style",
77
- "A colorful sofa, art, and music fill the room, oil painting style",
78
- "A sofa, yoga mat, and meditation corner for calm, oil painting style"
79
- ]
80
-
81
- example4=[
82
- "A Japanese room with muted-colored tatami mats.",
83
- "A Japanese room with a simple, folded futon sits to one side.",
84
- "A Japanese room with a low table rests in the room's center.",
85
- "A Japanese room with Shoji screens divide the room softly.",
86
- "A Japanese room with An alcove holds an elegant scroll and flowers.",
87
- "A Japanese room with a tea set rests on a bamboo tray.",
88
- "A Japanese room with a carved wooden cupboard stands against a wall.",
89
- "A Japanese room with a traditional lamp gently lights the room."
90
- ]
91
- example6=[
92
- 'This kitchen is a charming blend of rustic and modern, featuring a large reclaimed wood island with marble countertop',
93
- 'This kitchen is a charming blend of rustic and modern, featuring a large reclaimed wood island with marble countertop',
94
- 'This kitchen is a charming blend of rustic and modern, featuring a large reclaimed wood island with marble countertop',
95
- 'To the left of the island, a stainless-steel refrigerator stands tall. ',
96
- 'To the left of the island, a stainless-steel refrigerator stands tall. ',
97
- 'a sink surrounded by cabinets',
98
- 'a sink surrounded by cabinets',
99
- 'To the right of the sink, built-in wooden cabinets painted in a muted.'
100
- ]
101
-
102
- example7= [
103
- "Cobblestone streets curl between old buildings.",
104
- "Shops and cafes display signs and emit pleasant smells.",
105
- "A fruit market scents the air with fresh citrus.",
106
- "A fountain adds calm to one side of the scene.",
107
- "Bicycles rest against walls and posts.",
108
- "Flowers in boxes color the windows.",
109
- "Flowers in boxes color the windows.",
110
- "Cobblestone streets curl between old buildings."
111
- ]
112
-
113
- example8=[
114
- "The patio is open and airy.",
115
- "A table and chairs sit in the middle.",
116
- "Next the table is flowers.",
117
- "Colorful flowers fill the planters.",
118
- "A grill stands ready for barbecues.",
119
- "A grill stands ready for barbecues.",
120
- "The patio overlooks a lush garden.",
121
- "The patio overlooks a lush garden."
122
- ]
123
-
124
- example9=[
125
- "A Chinese palace with roofs curve.",
126
- "A Chinese palace, Red and gold accents gleam in the sun.",
127
- "A Chinese palace with a view of mountain in the front.",
128
- "A view of mountain in the front.",
129
- "A Chinese palace with a view of mountain in the front.",
130
- "A Chinese palace with a tree beside.",
131
- "A Chinese palace with a tree beside.",
132
- "A Chinese palace, with a tree beside."
133
- ]
134
-
135
-
136
-
137
- example_b1="This kitchen is a charming blend of rustic and modern, featuring a large reclaimed wood island with marble countertop, a sink surrounded by cabinets. To the left of the island, a stainless-steel refrigerator stands tall. To the right of the sink, built-in wooden cabinets painted in a muted."
138
- example_b2="Bursting with vibrant hues and exaggerated proportions, the cartoon-styled room sparkled with whimsy and cheer, with floating shelves crammed with oddly shaped trinkets, a comically oversized polka-dot armchair perched near a gravity-defying, tilted lamp, and the candy-striped wallpaper creating a playful backdrop to the merry chaos, exuding a sense of fun and boundless imagination."
139
- example_b3="Bathed in the pulsating glow of neon lights that painted stark contrasts of shadow and color, the cyberpunk room was a high-tech, low-life sanctuary, where sleek, metallic surfaces met jagged, improvised tech; a wall of glitchy monitors flickered with unending streams of data, and the buzz of electric current and the low hum of cooling fans formed a dystopian symphony, adding to the room's relentless, gritty energy."
140
- example_b4="Majestically rising towards the heavens, the snow-capped mountain stood, its jagged peaks cloaked in a shroud of ethereal clouds, its rugged slopes a stark contrast against the serene azure sky, and its silent grandeur exuding an air of ancient wisdom and timeless solitude, commanding awe and reverence from all who beheld it."
141
- example_b5='Bathed in the soft, dappled light of the setting sun, the silent street lay undisturbed, revealing the grandeur of its cobblestone texture, the rusted lampposts bearing witness to forgotten stories, and the ancient, ivy-clad houses standing stoically, their shuttered windows and weather-beaten doors speaking volumes about their passage through time.'
142
- example_b6='Awash with the soothing hues of an array of blossoms, the tranquil garden was a symphony of life and color, where the soft murmur of the babbling brook intertwined with the whispering willows, and the iridescent petals danced in the gentle breeze, creating an enchanting sanctuary of beauty and serenity.'
143
- example_b7="Canopied by a patchwork quilt of sunlight and shadows, the sprawling park was a panorama of lush green grass, meandering trails etched through vibrant wildflowers, towering oaks reaching towards the sky, and tranquil ponds mirroring the clear, blue expanse above, offering a serene retreat in the heart of nature's splendor."
144
-
145
- examples_basic=[example_b1, example_b2, example_b3, example_b4, example_b5, example_b6]
146
- examples_advanced=[example1, example2, example3, example4, example6, example7, example8, example9]
147
-
148
- description="The demo generates 8 perspective images, with FOV of 90 and rotation angle of 45. Please type 8 sentences corresponding to each perspective image."
149
-
150
- outputs=[gr.Image(shape=(484, 2048))]
151
- outputs.extend([gr.Image(shape=(1, 1)) for i in range(8)])
152
-
153
- def load_example_img(path):
154
- img=Image.open(path)
155
- img.resize((1024, 242))
156
- return img
157
-
158
- def copy(text):
159
- return [text]*8
160
-
161
- def clear():
162
- return None, None, None, None, None, None, None, None, None
163
-
164
- def load_basic(example):
165
- return example
166
-
167
- def generate_advanced(acc, text1, text2, text3, text4, text5, text6, text7, text8):
168
- texts=[text1, text2, text3, text4, text5, text6, text7, text8]
169
- for text in texts:
170
- if text is None or text=='':
171
- raise gr.Error('Text cannot be empty')
172
- images_low_res_pred=demo_model(texts, batch)[0]
173
- imgs=[]
174
- degrees = [[90, 0, 0],[90, 45, 0],[90, 90, 0],[90, 135, 0],[90, 180, 0],[90, 225, 0],[90, 270, 0],[90, 315, 0]]
175
- width = 2048
176
- height = 1024
177
- for i in range(8):
178
- imgs.append(images_low_res_pred[i])
179
- equ = m_P2E.Perspective(imgs,
180
- degrees)
181
-
182
-
183
- img = equ.GetEquirec(height,width).astype(np.uint8)
184
- img=img[270:-270]
185
- imgs=[img]+imgs
186
- return [acc.update(open=False)]+imgs
187
-
188
- def generate_basic(acc, text):
189
- print(text)
190
- if text is None or text=='':
191
- raise gr.Error('Text cannot be empty')
192
- model='gpt-4o-mini'
193
- openai.api_key = "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"
194
-
195
- # Start sending prompts
196
- flag=False
197
- for i in range(20):
198
- try:
199
- response = openai.ChatCompletion.create(
200
- model=model,
201
- messages=[
202
- {"role": "user", "content": "Can you describe the following with 5 or 6 sentences? {}".format(text)}],
203
- max_tokens=193,
204
- temperature=0,
205
- )
206
- text=response.choices[0]['message']['content']
207
- flag=True
208
- break
209
- except:
210
- flag=False
211
- if not flag:
212
- raise gr.Error('Text error')
213
-
214
- texts=[text]*8
215
- if text=='':
216
- raise gr.Error('Text cannot be empty')
217
- images_low_res_pred=demo_model(texts, batch)[0]
218
- imgs=[]
219
- degrees = [[90, 0, 0],[90, 45, 0],[90, 90, 0],[90, 135, 0],[90, 180, 0],[90, 225, 0],[90, 270, 0],[90, 315, 0]]
220
- width = 2048
221
- height = 1024
222
- for i in range(8):
223
- imgs.append(images_low_res_pred[i])
224
- equ = m_P2E.Perspective(imgs,
225
- degrees)
226
-
227
-
228
- img = equ.GetEquirec(height,width).astype(np.uint8)
229
- img=img[270:-270]
230
- imgs=[img]+imgs
231
- return [acc.update(open=False)]+imgs
232
-
233
- default_text='This kitchen is a charming blend of rustic and modern, featuring a large reclaimed wood island with marble countertop, a sink surrounded by cabinets. To the left of the island, a stainless-steel refrigerator stands tall. To the right of the sink, built-in wooden cabinets painted in a muted.'
234
- css = """
235
- #warning {background-color: #000000}
236
- .feedback textarea {font-size: 16px !important}
237
- #foo {}
238
- .text111 textarea {
239
- color: rgba(0, 0, 0, 0.5);
240
- }
241
- """
242
-
243
- inputs=[gr.Textbox(type="text", label='Text{}'.format(i)) for i in range(8)]
244
-
245
- with gr.Blocks(css=css) as demo:
246
- with gr.Row():
247
- gr.Markdown(
248
- """
249
- # <center>Text2Pano with MVDiffusion</center>
250
- """)
251
- with gr.Row():
252
- gr.Markdown(
253
- """
254
- <center>Text2Pano demonstration: Write a scene you want in Text, then click "Generate panorama". Alternatively, you can load the example text prompts below to populate text-boxes. The advanced mode allows to specify text prompts for each perspective image</center>
255
- """)
256
- with gr.Tab("Basic"):
257
- with gr.Row():
258
- textbox1=gr.Textbox(type="text", label='Text', value=default_text, elem_id='warning', elem_classes="feedback")
259
-
260
- with gr.Row():
261
- submit_btn = gr.Button("Generate panorama")
262
- clear_btn = gr.Button("Clear all texts")
263
- clear_btn.click(
264
- clear,
265
- outputs=inputs+[textbox1]
266
- )
267
-
268
- with gr.Accordion("Example expand/hide") as acc:
269
- for i in range(0, len(examples_basic)):
270
- with gr.Row():
271
- gr.Image(load_example_img('assets/basic/img{}.png'.format(i+1)), label='example {}'.format(i+1))
272
- #gr.Image('demo/assets/basic/img{}.png'.format(i+2), label='example {}'.format(i+2))
273
- with gr.Row():
274
- gr.Textbox(type="text", label='Example text {}'.format(i+1), value=examples_basic[i])
275
- #gr.Textbox(type="text", label='Example text {}'.format(i+2), value=examples_basic[i+1])
276
- with gr.Row():
277
- load_btn=gr.Button("Load text to the main box")
278
- load_btn.click(
279
- partial(load_basic, examples_basic[i]),
280
- outputs=[textbox1]
281
- )
282
-
283
- submit_btn.click(
284
- partial(generate_basic, acc),
285
- inputs=textbox1,
286
- outputs=[acc]+outputs
287
- )
288
-
289
- with gr.Tab("Advanced"):
290
- with gr.Row():
291
- for text_bar in inputs[:4]:
292
- text_bar.render()
293
- with gr.Row():
294
- for text_bar in inputs[4:]:
295
- text_bar.render()
296
-
297
- with gr.Row():
298
-
299
- submit_btn = gr.Button("Generate panorama")
300
- clear_btn = gr.Button("Clear all texts")
301
- clear_btn.click(
302
- clear,
303
- outputs=inputs+[textbox1]
304
- )
305
- with gr.Accordion("Example expand/hide") as acc_advanced:
306
- for i, example in enumerate(examples_advanced):
307
- with gr.Row():
308
- gr.Image(load_example_img('assets/advanced/img{}.png'.format(i+1)), label='example {}'.format(i+1))
309
- with gr.Row():
310
- gr.Textbox(type="text", label='Text 1', value=example[0])
311
- gr.Textbox(type="text", label='Text 2', value=example[1])
312
- gr.Textbox(type="text", label='Text 3', value=example[2])
313
- gr.Textbox(type="text", label='Text 4', value=example[3])
314
- with gr.Row():
315
- gr.Textbox(type="text", label='Text 4', value=example[4])
316
- gr.Textbox(type="text", label='Text 5', value=example[5])
317
- gr.Textbox(type="text", label='Text 6', value=example[6])
318
- gr.Textbox(type="text", label='Text 7', value=example[7])
319
- with gr.Row():
320
- load_btn=gr.Button("Load text to other text boxes")
321
- load_btn.click(
322
- partial(load_basic, example),
323
- outputs=inputs
324
- )
325
- submit_btn.click(
326
- partial(generate_advanced, acc_advanced),
327
- inputs=inputs,
328
- outputs=[acc_advanced]+outputs
329
- )
330
-
331
- with gr.Row():
332
- outputs[0].render()
333
- with gr.Row():
334
- outputs[1].render()
335
- outputs[2].render()
336
- with gr.Row():
337
- outputs[3].render()
338
- outputs[4].render()
339
- with gr.Row():
340
- outputs[5].render()
341
- outputs[6].render()
342
- with gr.Row():
343
- outputs[7].render()
344
- outputs[8].render()
345
-
346
- demo.queue(concurrency_count=3)
347
- demo.launch(share=True)
1
+ import os
2
+ import torch
3
+ import fire
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from functools import partial
7
+ import spaces
8
+ import cv2
9
+ import time
10
+ import numpy as np
11
+ from rembg import remove
12
+ from segment_anything import sam_model_registry, SamPredictor
13
+
14
+ import os
15
+ import torch
16
+
17
+ from PIL import Image
18
+ from typing import Dict, Optional, List
19
+ from dataclasses import dataclass
20
+ from mvdiffusion.data.single_image_dataset import SingleImageDataset
21
+ from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
22
+ from einops import rearrange
23
+ import numpy as np
24
+ import subprocess
25
+ from datetime import datetime
26
+ from icecream import ic
27
+ def save_image(tensor):
28
+ ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
29
+ # pdb.set_trace()
30
+ im = Image.fromarray(ndarr)
31
+ return ndarr
32
+
33
+
34
+ def save_image_to_disk(tensor, fp):
35
+ ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
36
+ # pdb.set_trace()
37
+ im = Image.fromarray(ndarr)
38
+ im.save(fp)
39
+ return ndarr
40
+
41
+
42
+ def save_image_numpy(ndarr, fp):
43
+ im = Image.fromarray(ndarr)
44
+ im.save(fp)
45
+
46
+
47
+ weight_dtype = torch.float16
48
+
49
+ _TITLE = '''Era3D: High-Resolution Multiview Diffusion using Efficient Row-wise Attention'''
50
+ _DESCRIPTION = '''
51
+ <div>
52
+ Generate consistent high-resolution multi-view normals maps and color images.
53
+ </div>
54
+ <div>
55
+ The demo does not include the mesh reconstruction part, please visit <a href="https://github.com/pengHTYX/Era3D"><img src='https://img.shields.io/github/stars/pengHTYX/Era3D?style=social' style="display: inline-block; vertical-align: middle;"/></a> to get a textured mesh.
56
+ </div>
57
+ '''
58
+ _GPU_ID = 0
59
+
60
+
61
+ if not hasattr(Image, 'Resampling'):
62
+ Image.Resampling = Image
63
+
64
+
65
+ def sam_init():
66
+ sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_h_4b8939.pth")
67
+ model_type = "vit_h"
68
+
69
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=f"cuda:{_GPU_ID}")
70
+ predictor = SamPredictor(sam)
71
+ return predictor
72
+
73
+ @spaces.GPU
74
+ def sam_segment(predictor, input_image, *bbox_coords):
75
+ bbox = np.array(bbox_coords)
76
+ image = np.asarray(input_image)
77
+
78
+ start_time = time.time()
79
+ predictor.set_image(image)
80
+
81
+ masks_bbox, scores_bbox, logits_bbox = predictor.predict(box=bbox, multimask_output=True)
82
+
83
+ print(f"SAM Time: {time.time() - start_time:.3f}s")
84
+ out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
85
+ out_image[:, :, :3] = image
86
+ out_image_bbox = out_image.copy()
87
+ out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255
88
+ torch.cuda.empty_cache()
89
+ return Image.fromarray(out_image_bbox, mode='RGBA')
90
+
91
+
92
+ def expand2square(pil_img, background_color):
93
+ width, height = pil_img.size
94
+ if width == height:
95
+ return pil_img
96
+ elif width > height:
97
+ result = Image.new(pil_img.mode, (width, width), background_color)
98
+ result.paste(pil_img, (0, (width - height) // 2))
99
+ return result
100
+ else:
101
+ result = Image.new(pil_img.mode, (height, height), background_color)
102
+ result.paste(pil_img, ((height - width) // 2, 0))
103
+ return result
104
+
105
+
106
+ def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=False):
107
+ RES = 1024
108
+ input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
109
+ if chk_group is not None:
110
+ segment = "Background Removal" in chk_group
111
+ rescale = "Rescale" in chk_group
112
+ if segment:
113
+ image_rem = input_image.convert('RGBA')
114
+ image_nobg = remove(image_rem, alpha_matting=True)
115
+ arr = np.asarray(image_nobg)[:, :, -1]
116
+ x_nonzero = np.nonzero(arr.sum(axis=0))
117
+ y_nonzero = np.nonzero(arr.sum(axis=1))
118
+ x_min = int(x_nonzero[0].min())
119
+ y_min = int(y_nonzero[0].min())
120
+ x_max = int(x_nonzero[0].max())
121
+ y_max = int(y_nonzero[0].max())
122
+ input_image = sam_segment(predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max)
123
+ # Rescale and recenter
124
+ if rescale:
125
+ image_arr = np.array(input_image)
126
+ in_w, in_h = image_arr.shape[:2]
127
+ out_res = min(RES, max(in_w, in_h))
128
+ ret, mask = cv2.threshold(np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY)
129
+ x, y, w, h = cv2.boundingRect(mask)
130
+ max_size = max(w, h)
131
+ ratio = 0.75
132
+ side_len = int(max_size / ratio)
133
+ padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
134
+ center = side_len // 2
135
+ padded_image[center - h // 2 : center - h // 2 + h, center - w // 2 : center - w // 2 + w] = image_arr[y : y + h, x : x + w]
136
+ rgba = Image.fromarray(padded_image).resize((out_res, out_res), Image.LANCZOS)
137
+
138
+ rgba_arr = np.array(rgba) / 255.0
139
+ rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
140
+ input_image = Image.fromarray((rgb * 255).astype(np.uint8))
141
+ else:
142
+ input_image = expand2square(input_image, (127, 127, 127, 0))
143
+ return input_image, input_image.resize((320, 320), Image.Resampling.LANCZOS)
144
+
145
+ def load_era3d_pipeline(cfg):
146
+ # Load scheduler, tokenizer and models.
147
+
148
+ pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
149
+ cfg.pretrained_model_name_or_path,
150
+ torch_dtype=weight_dtype
151
+ )
152
+ # sys.main_lock = threading.Lock()
153
+ return pipeline
154
+
155
+
156
+ from mvdiffusion.data.single_image_dataset import SingleImageDataset
157
+
158
+
159
+ def prepare_data(single_image, crop_size, cfg):
160
+ dataset = SingleImageDataset(root_dir='', num_views=6, img_wh=[512, 512], bg_color='white',
161
+ crop_size=crop_size, single_image=single_image, prompt_embeds_path=cfg.validation_dataset.prompt_embeds_path)
162
+ return dataset[0]
163
+
164
+ scene = 'scene'
165
+ def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_size, chk_group=None):
166
+ pipeline.to(device=f'cuda:{_GPU_ID}')
167
+ pipeline.unet.enable_xformers_memory_efficient_attention()
168
+
169
+ global scene
170
+ # pdb.set_trace()
171
+
172
+ if chk_group is not None:
173
+ write_image = "Write Results" in chk_group
174
+
175
+ batch = prepare_data(single_image, crop_size, cfg)
176
+
177
+ pipeline.set_progress_bar_config(disable=True)
178
+ seed = int(seed)
179
+ generator = torch.Generator(device=pipeline.unet.device).manual_seed(seed)
180
+
181
+
182
+ imgs_in = torch.cat([batch['imgs_in']]*2, dim=0)
183
+ num_views = imgs_in.shape[1]
184
+ imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")# (B*Nv, 3, H, W)
185
+
186
+ normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'], batch['color_prompt_embeddings']
187
+ prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
188
+ prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
189
+
190
+
191
+ imgs_in = imgs_in.to(device=f'cuda:{_GPU_ID}', dtype=weight_dtype)
192
+ prompt_embeddings = prompt_embeddings.to(device=f'cuda:{_GPU_ID}', dtype=weight_dtype)
193
+
194
+ out = pipeline(
195
+ imgs_in,
196
+ None,
197
+ prompt_embeds=prompt_embeddings,
198
+ generator=generator,
199
+ guidance_scale=guidance_scale,
200
+ output_type='pt',
201
+ num_images_per_prompt=1,
202
+ # return_elevation_focal=cfg.log_elevation_focal_length,
203
+ **cfg.pipe_validation_kwargs
204
+ ).images
205
+
206
+ bsz = out.shape[0] // 2
207
+ normals_pred = out[:bsz]
208
+ images_pred = out[bsz:]
209
+ num_views = 6
210
+ if write_image:
211
+ VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
212
+ cur_dir = os.path.join(cfg.save_dir, f"cropsize-{int(crop_size)}-cfg{guidance_scale:.1f}")
213
+
214
+ scene = 'scene'+datetime.now().strftime('@%Y%m%d-%H%M%S')
215
+ scene_dir = os.path.join(cur_dir, scene)
216
+ os.makedirs(scene_dir, exist_ok=True)
217
+
218
+ for j in range(num_views):
219
+ view = VIEWS[j]
220
+ normal = normals_pred[j]
221
+ color = images_pred[j]
222
+
223
+ normal_filename = f"normals_{view}_masked.png"
224
+ color_filename = f"color_{view}_masked.png"
225
+ normal = save_image_to_disk(normal, os.path.join(scene_dir, normal_filename))
226
+ color = save_image_to_disk(color, os.path.join(scene_dir, color_filename))
227
+
228
+
229
+ normals_pred = [save_image(normals_pred[i]) for i in range(bsz)]
230
+ images_pred = [save_image(images_pred[i]) for i in range(bsz)]
231
+
232
+ out = images_pred + normals_pred
233
+ return images_pred, normals_pred
234
+
235
+
236
+ def process_3d(mode, data_dir, guidance_scale, crop_size):
237
+ dir = None
238
+ global scene
239
+
240
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
241
+
242
+ subprocess.run(
243
+ f'cd instant-nsr-pl && bash run.sh 0 {scene} exp_demo && cd ..',
244
+ shell=True,
245
+ )
246
+ import glob
247
+
248
+ obj_files = glob.glob(f'{cur_dir}/instant-nsr-pl/exp_demo/{scene}/*/save/*.obj', recursive=True)
249
+ print(obj_files)
250
+ if obj_files:
251
+ dir = obj_files[0]
252
+ return dir
253
+
254
+
255
+ @dataclass
256
+ class TestConfig:
257
+ pretrained_model_name_or_path: str
258
+ pretrained_unet_path:Optional[str]
259
+ revision: Optional[str]
260
+ validation_dataset: Dict
261
+ save_dir: str
262
+ seed: Optional[int]
263
+ validation_batch_size: int
264
+ dataloader_num_workers: int
265
+ # save_single_views: bool
266
+ save_mode: str
267
+ local_rank: int
268
+
269
+ pipe_kwargs: Dict
270
+ pipe_validation_kwargs: Dict
271
+ unet_from_pretrained_kwargs: Dict
272
+ validation_guidance_scales: List[float]
273
+ validation_grid_nrow: int
274
+ camera_embedding_lr_mult: float
275
+
276
+ num_views: int
277
+ camera_embedding_type: str
278
+
279
+ pred_type: str # joint, or ablation
280
+ regress_elevation: bool
281
+ enable_xformers_memory_efficient_attention: bool
282
+
283
+ cond_on_normals: bool
284
+ cond_on_colors: bool
285
+
286
+ regress_elevation: bool
287
+ regress_focal_length: bool
288
+
289
+
290
+
291
+ def run_demo():
292
+ from utils.misc import load_config
293
+ from omegaconf import OmegaConf
294
+
295
+ # parse YAML config to OmegaConf
296
+ cfg = load_config("./configs/test_unclip-512-6view.yaml")
297
+ # print(cfg)
298
+ schema = OmegaConf.structured(TestConfig)
299
+ cfg = OmegaConf.merge(schema, cfg)
300
+
301
+ pipeline = load_era3d_pipeline(cfg)
302
+ torch.set_grad_enabled(False)
303
+
304
+
305
+ predictor = sam_init()
306
+
307
+
308
+ custom_theme = gr.themes.Soft(primary_hue="blue").set(
309
+ button_secondary_background_fill="*neutral_100", button_secondary_background_fill_hover="*neutral_200"
310
+ )
311
+ custom_css = '''#disp_image {
312
+ text-align: center; /* Horizontally center the content */
313
+ }'''
314
+
315
+
316
+ with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
317
+ with gr.Row():
318
+ with gr.Column(scale=1):
319
+ gr.Markdown('# ' + _TITLE)
320
+ gr.Markdown(_DESCRIPTION)
321
+ with gr.Row(variant='panel'):
322
+ with gr.Column(scale=1):
323
+ input_image = gr.Image(type='pil', image_mode='RGBA', height=320, label='Input image')
324
+
325
+ with gr.Column(scale=1):
326
+ processed_image_highres = gr.Image(type='pil', image_mode='RGBA', visible=False)
327
+
328
+ processed_image = gr.Image(
329
+ type='pil',
330
+ label="Processed Image",
331
+ interactive=False,
332
+ # height=320,
333
+ image_mode='RGBA',
334
+ elem_id="disp_image",
335
+ visible=True,
336
+ )
337
+ # with gr.Column(scale=1):
338
+ # ## add 3D Model
339
+ # obj_3d = gr.Model3D(
340
+ # # clear_color=[0.0, 0.0, 0.0, 0.0],
341
+ # label="3D Model", height=320,
342
+ # # camera_position=[0,0,2.0]
343
+ # )
344
+
345
+ with gr.Row(variant='panel'):
346
+ with gr.Column(scale=1):
347
+ example_folder = os.path.join(os.path.dirname(__file__), "./examples")
348
+ example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)]
349
+ gr.Examples(
350
+ examples=example_fns,
351
+ inputs=[input_image],
352
+ outputs=[input_image],
353
+ cache_examples=False,
354
+ label='Examples (click one of the images below to start)',
355
+ examples_per_page=30,
356
+ )
357
+ with gr.Column(scale=1):
358
+ with gr.Row():
359
+ with gr.Column():
360
+ with gr.Accordion('Advanced options', open=True):
361
+ input_processing = gr.CheckboxGroup(
362
+ ['Background Removal'],
363
+ label='Input Image Preprocessing',
364
+ value=['Background Removal'],
365
+ info='untick this, if masked image with alpha channel',
366
+ )
367
+ with gr.Column():
368
+ with gr.Accordion('Advanced options', open=False):
369
+ output_processing = gr.CheckboxGroup(
370
+ ['Write Results'], label='write the results in mv_res folder', value=['Write Results']
371
+ )
372
+ with gr.Row():
373
+ with gr.Column():
374
+ scale_slider = gr.Slider(1, 5, value=3, step=1, label='Classifier Free Guidance Scale')
375
+ with gr.Column():
376
+ steps_slider = gr.Slider(15, 100, value=40, step=1, label='Number of Diffusion Inference Steps')
377
+ with gr.Row():
378
+ with gr.Column():
379
+ seed = gr.Number(600, label='Seed', info='100 for digital portraits')
380
+ with gr.Column():
381
+ crop_size = gr.Number(420, label='Crop size', info='380 for digital portraits')
382
+
383
+ mode = gr.Textbox('train', visible=False)
384
+ data_dir = gr.Textbox('outputs', visible=False)
385
+ # with gr.Row():
386
+ # method = gr.Radio(choices=['instant-nsr-pl', 'NeuS'], label='Method (Default: instant-nsr-pl)', value='instant-nsr-pl')
387
+ run_btn = gr.Button('Generate Normals and Colors', variant='primary', interactive=True)
388
+ # recon_btn = gr.Button('Reconstruct 3D model', variant='primary', interactive=True)
389
+ # gr.Markdown("<span style='color:red'>First click Generate button, then click Reconstruct button. Reconstruction may cost several minutes.</span>")
390
+
391
+ with gr.Row():
392
+ view_gallery = gr.Gallery(label='Multiview Images')
393
+ normal_gallery = gr.Gallery(label='Multiview Normals')
394
+
395
+ print('Launching...')
396
+ run_btn.click(
397
+ fn=partial(preprocess, predictor), inputs=[input_image, input_processing], outputs=[processed_image_highres, processed_image], queue=True
398
+ ).success(
399
+ fn=partial(run_pipeline, pipeline, cfg),
400
+ inputs=[processed_image_highres, scale_slider, steps_slider, seed, crop_size, output_processing],
401
+ outputs=[view_gallery, normal_gallery],
402
+ )
403
+ # recon_btn.click(
404
+ # process_3d, inputs=[mode, data_dir, scale_slider, crop_size], outputs=[obj_3d]
405
+ # )
406
+
407
+ demo.queue().launch(share=True, max_threads=80)
408
+
409
+
410
+ if __name__ == '__main__':
411
+ fire.Fire(run_demo)
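
For reference, run_demo() above wires these helpers together through Gradio callbacks. A minimal headless sketch of the same flow (not part of the commit; it assumes the helpers are imported from app.py, that the repository root is the working directory, and that a CUDA device is available):

    import torch
    from omegaconf import OmegaConf
    from PIL import Image

    from app import TestConfig, load_era3d_pipeline, sam_init, preprocess, run_pipeline
    from utils.misc import load_config

    # Parse the YAML config into the structured TestConfig, as run_demo() does.
    cfg = OmegaConf.merge(OmegaConf.structured(TestConfig),
                          load_config("./configs/test_unclip-512-6view.yaml"))

    pipeline = load_era3d_pipeline(cfg)   # StableUnCLIPImg2ImgPipeline in fp16
    predictor = sam_init()                # SAM ViT-H predictor for foreground masking
    torch.set_grad_enabled(False)

    # Segment the foreground and pad to a square, then run the 6-view diffusion.
    image = Image.open("examples/k2.png")  # any of the bundled examples
    processed, _preview = preprocess(predictor, image, chk_group=["Background Removal"])
    colors, normals = run_pipeline(pipeline, cfg, processed,
                                   guidance_scale=3.0, steps=40, seed=600,
                                   crop_size=420, chk_group=[])  # empty list: skip writing to mv_res
    print(len(colors), len(normals))  # 6 color views and 6 normal maps as uint8 arrays
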
configs/test_unclip-512-6view.yaml ADDED
@@ -0,0 +1,56 @@
1
+ pretrained_model_name_or_path: './MacLab-Era3D-512-6view'
2
+ revision: null
3
+
4
+ num_views: 6
5
+ validation_dataset:
6
+ prompt_embeds_path: mvdiffusion/data/fixed_prompt_embeds_6view
7
+ root_dir: 'examples'
8
+ num_views: ${num_views}
9
+ bg_color: 'white'
10
+ img_wh: [512, 512]
11
+ num_validation_samples: 1000
12
+ crop_size: 420
13
+
14
+ pred_type: 'joint'
15
+ save_dir: 'mv_res'
16
+ save_mode: 'rgba' # 'concat', 'rgba', 'rgb'
17
+ seed: 42
18
+ validation_batch_size: 1
19
+ dataloader_num_workers: 1
20
+ local_rank: -1
21
+
22
+ pipe_kwargs:
23
+ num_views: ${num_views}
24
+
25
+ validation_guidance_scales: [3.0]
26
+ pipe_validation_kwargs:
27
+ num_inference_steps: 40
28
+ eta: 1.0
29
+
30
+ validation_grid_nrow: ${num_views}
31
+ regress_elevation: true
32
+ regress_focal_length: true
33
+ unet_from_pretrained_kwargs:
34
+ unclip: true
35
+ sdxl: false
36
+ num_views: ${num_views}
37
+ sample_size: 64
38
+ zero_init_conv_in: false # modify
39
+
40
+ regress_elevation: ${regress_elevation}
41
+ regress_focal_length: ${regress_focal_length}
42
+ camera_embedding_type: e_de_da_sincos
43
+ projection_camera_embeddings_input_dim: 4 # 2 for elevation and 6 for focal_length
44
+ zero_init_camera_projection: false
45
+ num_regress_blocks: 3
46
+
47
+ cd_attention_last: false
48
+ cd_attention_mid: false
49
+ multiview_attention: true
50
+ sparse_mv_attention: true
51
+ selfattn_block: self_rowwise
52
+ mvcd_attention: true
53
+
54
+ use_dino: false
55
+
56
+ enable_xformers_memory_efficient_attention: true
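
The ${num_views}, ${regress_elevation} and ${regress_focal_length} entries are OmegaConf interpolations that resolve against the top-level keys of this file. A quick sketch of how they resolve (assuming it is run from the repository root):

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/test_unclip-512-6view.yaml")
    print(cfg.validation_dataset.num_views)                   # 6, pulled from the top-level num_views
    print(cfg.unet_from_pretrained_kwargs.regress_elevation)  # True, from the top-level flag
    print(cfg.pipe_validation_kwargs.num_inference_steps)     # 40
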
examples/3968940-PH.png ADDED
examples/A_beautiful_cyborg_with_brown_hair_rgba.png ADDED

Git LFS Details

  • SHA256: 3dd8d815ba5bc0a7e17587f8a4d2cec64d196ba5b5f44fff3fed13e1783de366
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
examples/A_bulldog_with_a_black_pirate_hat_rgba.png ADDED
examples/A_pig_wearing_a_backpack_rgba.png ADDED
examples/Elon-Musk.jpg ADDED
examples/Ghost_eating_burger_rgba.png ADDED
examples/cleanrot_armor_rgba.png ADDED
examples/cute_demon_combination_angel_figure_rgba.png ADDED
examples/dslr.png ADDED
examples/duola.png ADDED
examples/k2.png ADDED
examples/kind_cartoon_lion_in_costume_of_astronaut_rgba.png ADDED
examples/kunkun.png ADDED
examples/lantern.png ADDED
examples/lewd_statue_of_an_angel_texting_on_a_cell_phone_rgba.png ADDED
examples/monkey.png ADDED
examples/yann_kecun.jpg ADDED
mvdiffusion/data/dataset.py ADDED
@@ -0,0 +1,138 @@
1
+ # import decord
2
+ # decord.bridge.set_bridge('torch')
3
+
4
+ from torch.utils.data import Dataset
5
+ from einops import rearrange
6
+ from typing import Literal, Tuple, Optional, Any
7
+ import glob
8
+ import os
9
+ import json
10
+ import random
11
+ import cv2
12
+ import math
13
+ import numpy as np
14
+ import torch
15
+ from PIL import Image
16
+
17
+
18
+ class MVDiffusionDatasetV1(Dataset):
19
+ def __init__(
20
+ self,
21
+ root_dir: str,
22
+ num_views: int,
23
+ bg_color: Any,
24
+ img_wh: Tuple[int, int],
25
+ validation: bool = False,
26
+ num_validation_samples: int = 64,
27
+ num_samples: Optional[int] = None,
28
+ caption_path: Optional[str] = None,
29
+ elevation_range_deg: Tuple[float,float] = (-90, 90),
30
+ azimuth_range_deg: Tuple[float, float] = (0, 360),
31
+ ):
32
+ self.all_obj_paths = sorted(glob.glob(os.path.join(root_dir, "*/*")))
33
+ if not validation:
34
+ self.all_obj_paths = self.all_obj_paths[:-num_validation_samples]
35
+ else:
36
+ self.all_obj_paths = self.all_obj_paths[-num_validation_samples:]
37
+ if num_samples is not None:
38
+ self.all_obj_paths = self.all_obj_paths[:num_samples]
39
+ self.all_obj_ids = [os.path.basename(path) for path in self.all_obj_paths]
40
+ self.num_views = num_views
41
+ self.bg_color = bg_color
42
+ self.img_wh = img_wh
43
+
44
+ def get_bg_color(self):
45
+ if self.bg_color == 'white':
46
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
47
+ elif self.bg_color == 'black':
48
+ bg_color = np.array([0., 0., 0.], dtype=np.float32)
49
+ elif self.bg_color == 'gray':
50
+ bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
51
+ elif self.bg_color == 'random':
52
+ bg_color = np.random.rand(3)
53
+ elif isinstance(self.bg_color, float):
54
+ bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
55
+ else:
56
+ raise NotImplementedError
57
+ return bg_color
58
+
59
+ def load_image(self, img_path, bg_color, return_type='np'):
60
+ # not using cv2 as may load in uint16 format
61
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
62
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
63
+ # pil always returns uint8
64
+ img = np.array(Image.open(img_path).resize(self.img_wh))
65
+ img = img.astype(np.float32) / 255. # [0, 1]
66
+ assert img.shape[-1] == 4 # RGBA
67
+
68
+ alpha = img[...,3:4]
69
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
70
+
71
+ if return_type == "np":
72
+ pass
73
+ elif return_type == "pt":
74
+ img = torch.from_numpy(img)
75
+ else:
76
+ raise NotImplementedError
77
+
78
+ return img
79
+
80
+ def __len__(self):
81
+ return len(self.all_obj_ids)
82
+
83
+ def __getitem__(self, index):
84
+ obj_path = self.all_obj_paths[index]
85
+ obj_id = self.all_obj_ids[index]
86
+ with open(os.path.join(obj_path, 'meta.json')) as f:
87
+ meta = json.loads(f.read())
88
+
89
+ num_views_all = len(meta['locations'])
90
+ num_groups = num_views_all // self.num_views
91
+
92
+ # random a set of 4 views
93
+ # the data is arranged in ascending order of the azimuth angle
94
+ group_ids = random.sample(range(num_groups), k=2)
95
+ cond_group_id, tgt_group_id = group_ids
96
+ cond_location = meta['locations'][cond_group_id * self.num_views + random.randint(0, self.num_views - 1)]
97
+ tgt_locations = meta['locations'][tgt_group_id * self.num_views : tgt_group_id * self.num_views + self.num_views]
98
+ # random an order
99
+ start_id = random.randint(0, self.num_views - 1)
100
+ tgt_locations = tgt_locations[start_id:] + tgt_locations[:start_id]
101
+
102
+ cond_elevation = cond_location['elevation']
103
+ cond_azimuth = cond_location['azimuth']
104
+ tgt_elevations = [loc['elevation'] for loc in tgt_locations]
105
+ tgt_azimuths = [loc['azimuth'] for loc in tgt_locations]
106
+
107
+ elevations = [ele - cond_elevation for ele in tgt_elevations]
108
+ azimuths = [(azi - cond_azimuth) % (math.pi * 2) for azi in tgt_azimuths]
109
+ elevations = torch.as_tensor(elevations).float()
110
+ azimuths = torch.as_tensor(azimuths).float()
111
+ elevations_cond = torch.as_tensor([cond_elevation] * self.num_views).float()
112
+
113
+ bg_color = self.get_bg_color()
114
+ img_tensors_in = [
115
+ self.load_image(os.path.join(obj_path, cond_location['frames'][0]['name']), bg_color, return_type='pt').permute(2, 0, 1)
116
+ ] * self.num_views
117
+ img_tensors_out = []
118
+ for loc in tgt_locations:
119
+ img_path = os.path.join(obj_path, loc['frames'][0]['name'])
120
+ img_tensor = self.load_image(img_path, bg_color, return_type="pt").permute(2, 0, 1)
121
+ img_tensors_out.append(img_tensor)
122
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
123
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
124
+
125
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
126
+
127
+ return {
128
+ 'elevations_cond': elevations_cond,
129
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
130
+ 'elevations': elevations,
131
+ 'azimuths': azimuths,
132
+ 'elevations_deg': torch.rad2deg(elevations),
133
+ 'azimuths_deg': torch.rad2deg(azimuths),
134
+ 'imgs_in': img_tensors_in,
135
+ 'imgs_out': img_tensors_out,
136
+ 'camera_embeddings': camera_embeddings
137
+ }
138
+
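
A usage sketch for the loader above, assuming a render directory laid out the way __getitem__ expects (root_dir/*/*/meta.json listing per-view RGBA frames); the root path and loader settings below are placeholders:

    from torch.utils.data import DataLoader
    from mvdiffusion.data.dataset import MVDiffusionDatasetV1

    dataset = MVDiffusionDatasetV1(
        root_dir="data/renders",   # placeholder path
        num_views=4,
        bg_color="white",
        img_wh=(256, 256),
        validation=False,
        num_validation_samples=64,
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    print(batch["imgs_in"].shape)            # (B, Nv, 3, H, W): the condition view repeated Nv times
    print(batch["imgs_out"].shape)           # (B, Nv, 3, H, W): the sampled target views
    print(batch["camera_embeddings"].shape)  # (B, Nv, 3): cond elevation, relative elevation, relative azimuth
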
mvdiffusion/data/dataset_nc.py ADDED
@@ -0,0 +1,178 @@
1
+ # import decord
2
+ # decord.bridge.set_bridge('torch')
3
+
4
+ from torch.utils.data import Dataset
5
+ from einops import rearrange
6
+ from typing import Literal, Tuple, Optional, Any
7
+ import glob
8
+ import os
9
+ import json
10
+ import random
11
+ import cv2
12
+ import math
13
+ import numpy as np
14
+ import torch
15
+ from PIL import Image
16
+ from .normal_utils import trans_normal, img2normal, normal2img
17
+
18
+ """
19
+ load normal and color images together
20
+ """
21
+ class MVDiffusionDatasetV2(Dataset):
22
+ def __init__(
23
+ self,
24
+ root_dir: str,
25
+ num_views: int,
26
+ bg_color: Any,
27
+ img_wh: Tuple[int, int],
28
+ validation: bool = False,
29
+ num_validation_samples: int = 64,
30
+ num_samples: Optional[int] = None,
31
+ caption_path: Optional[str] = None,
32
+ elevation_range_deg: Tuple[float,float] = (-90, 90),
33
+ azimuth_range_deg: Tuple[float, float] = (0, 360),
34
+ ):
35
+ self.all_obj_paths = sorted(glob.glob(os.path.join(root_dir, "*/*")))
36
+ if not validation:
37
+ self.all_obj_paths = self.all_obj_paths[:-num_validation_samples]
38
+ else:
39
+ self.all_obj_paths = self.all_obj_paths[-num_validation_samples:]
40
+ if num_samples is not None:
41
+ self.all_obj_paths = self.all_obj_paths[:num_samples]
42
+ self.all_obj_ids = [os.path.basename(path) for path in self.all_obj_paths]
43
+ self.num_views = num_views
44
+ self.bg_color = bg_color
45
+ self.img_wh = img_wh
46
+
47
+ def get_bg_color(self):
48
+ if self.bg_color == 'white':
49
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
50
+ elif self.bg_color == 'black':
51
+ bg_color = np.array([0., 0., 0.], dtype=np.float32)
52
+ elif self.bg_color == 'gray':
53
+ bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
54
+ elif self.bg_color == 'random':
55
+ bg_color = np.random.rand(3)
56
+ elif isinstance(self.bg_color, float):
57
+ bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
58
+ else:
59
+ raise NotImplementedError
60
+ return bg_color
61
+
62
+ def load_image(self, img_path, bg_color, return_type='np'):
63
+ # not using cv2 as may load in uint16 format
64
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
65
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
66
+ # pil always returns uint8
67
+ img = np.array(Image.open(img_path).resize(self.img_wh))
68
+ img = img.astype(np.float32) / 255. # [0, 1]
69
+ assert img.shape[-1] == 4 # RGBA
70
+
71
+ alpha = img[...,3:4]
72
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
73
+
74
+ if return_type == "np":
75
+ pass
76
+ elif return_type == "pt":
77
+ img = torch.from_numpy(img)
78
+ else:
79
+ raise NotImplementedError
80
+
81
+ return img, alpha
82
+
83
+ def load_normal(self, img_path, bg_color, alpha, RT_w2c=None, RT_w2c_cond=None, return_type='np'):
84
+ # not using cv2 as may load in uint16 format
85
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
86
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
87
+ # pil always returns uint8
88
+ normal = np.array(Image.open(img_path).resize(self.img_wh))
89
+
90
+ assert normal.shape[-1] == 3 # RGB
91
+
92
+ normal = trans_normal(img2normal(normal), RT_w2c, RT_w2c_cond)
93
+ img = normal2img(normal)
94
+
95
+ img = img.astype(np.float32) / 255. # [0, 1]
96
+
97
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
98
+
99
+ if return_type == "np":
100
+ pass
101
+ elif return_type == "pt":
102
+ img = torch.from_numpy(img)
103
+ else:
104
+ raise NotImplementedError
105
+
106
+ return img
107
+
108
+ def __len__(self):
109
+ return len(self.all_obj_ids)
110
+
111
+ def __getitem__(self, index):
112
+ obj_path = self.all_obj_paths[index]
113
+ obj_id = self.all_obj_ids[index]
114
+ with open(os.path.join(obj_path, 'meta.json')) as f:
115
+ meta = json.loads(f.read())
116
+
117
+ num_views_all = len(meta['locations'])
118
+ num_groups = num_views_all // self.num_views
119
+
120
+ # random a set of 4 views
121
+ # the data is arranged in ascending order of the azimuth angle
122
+ group_ids = random.sample(range(num_groups), k=2)
123
+ cond_group_id, tgt_group_id = group_ids
124
+ cond_location = meta['locations'][cond_group_id * self.num_views + random.randint(0, self.num_views - 1)]
125
+ tgt_locations = meta['locations'][tgt_group_id * self.num_views : tgt_group_id * self.num_views + self.num_views]
126
+ # random an order
127
+ start_id = random.randint(0, self.num_views - 1)
128
+ tgt_locations = tgt_locations[start_id:] + tgt_locations[:start_id]
129
+
130
+ cond_elevation = cond_location['elevation']
131
+ cond_azimuth = cond_location['azimuth']
132
+ cond_c2w = cond_location['transform_matrix']
133
+ cond_w2c = np.linalg.inv(cond_c2w)
134
+ tgt_elevations = [loc['elevation'] for loc in tgt_locations]
135
+ tgt_azimuths = [loc['azimuth'] for loc in tgt_locations]
136
+ tgt_c2ws = [loc['transform_matrix'] for loc in tgt_locations]
137
+ tgt_w2cs = [np.linalg.inv(loc['transform_matrix']) for loc in tgt_locations]
138
+
139
+ elevations = [ele - cond_elevation for ele in tgt_elevations]
140
+ azimuths = [(azi - cond_azimuth) % (math.pi * 2) for azi in tgt_azimuths]
141
+ elevations = torch.as_tensor(elevations).float()
142
+ azimuths = torch.as_tensor(azimuths).float()
143
+ elevations_cond = torch.as_tensor([cond_elevation] * self.num_views).float()
144
+
145
+ bg_color = self.get_bg_color()
146
+ img_tensors_in = [
147
+ self.load_image(os.path.join(obj_path, cond_location['frames'][0]['name']), bg_color, return_type='pt')[0].permute(2, 0, 1)
148
+ ] * self.num_views
149
+ img_tensors_out = []
150
+ normal_tensors_out = []
151
+ for loc, tgt_w2c in zip(tgt_locations, tgt_w2cs):
152
+ img_path = os.path.join(obj_path, loc['frames'][0]['name'])
153
+ img_tensor, alpha = self.load_image(img_path, bg_color, return_type="pt")
154
+ img_tensor = img_tensor.permute(2, 0, 1)
155
+ img_tensors_out.append(img_tensor)
156
+
157
+ normal_path = os.path.join(obj_path, loc['frames'][1]['name'])
158
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt").permute(2, 0, 1)
159
+ normal_tensors_out.append(normal_tensor)
160
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
161
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
162
+ normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W)
163
+
164
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
165
+
166
+ return {
167
+ 'elevations_cond': elevations_cond,
168
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
169
+ 'elevations': elevations,
170
+ 'azimuths': azimuths,
171
+ 'elevations_deg': torch.rad2deg(elevations),
172
+ 'azimuths_deg': torch.rad2deg(azimuths),
173
+ 'imgs_in': img_tensors_in,
174
+ 'imgs_out': img_tensors_out,
175
+ 'normals_out': normal_tensors_out,
176
+ 'camera_embeddings': camera_embeddings
177
+ }
178
+
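
dataset_nc.py extends the same loader to read a normal map alongside each color frame and rotate it into the conditioning view's camera frame via trans_normal. A minimal sketch of the extra output (placeholder path, same assumed directory layout as above):

    from mvdiffusion.data.dataset_nc import MVDiffusionDatasetV2

    ds = MVDiffusionDatasetV2(
        root_dir="data/renders",   # placeholder path
        num_views=4,
        bg_color="white",
        img_wh=(256, 256),
        validation=True,
        num_validation_samples=8,
    )
    sample = ds[0]
    # Same keys as MVDiffusionDatasetV1, plus per-view normals expressed relative to the condition view.
    print(sample["normals_out"].shape)  # (Nv, 3, H, W)
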
mvdiffusion/data/dreamdata.py ADDED
@@ -0,0 +1,355 @@
1
+ from typing import Dict
2
+ import numpy as np
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from pathlib import Path
7
+ import json
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from einops import rearrange, repeat
11
+ from typing import Literal, Tuple, Optional, Any
12
+ import cv2
13
+ import random
14
+
15
+ import json
16
+ import os, sys
17
+ import math
18
+
19
+ from PIL import Image, ImageOps
20
+ from normal_utils import worldNormal2camNormal, plot_grid_images, img2normal, norm_normalize, deg2rad
21
+
22
+ import pdb
23
+ from icecream import ic
24
+ def shift_list(lst, n):
25
+ length = len(lst)
26
+ n = n % length # Ensure n is within the range of the list length
27
+ return lst[-n:] + lst[:-n]
28
+
29
+
30
+ class ObjaverseDataset(Dataset):
31
+ def __init__(self,
32
+ root_dir: str,
33
+ azi_interval: float,
34
+ random_views: int,
35
+ predict_relative_views: list,
36
+ bg_color: Any,
37
+ object_list: str,
38
+ prompt_embeds_path: str,
39
+ img_wh: Tuple[int, int],
40
+ validation: bool = False,
41
+ num_validation_samples: int = 64,
42
+ num_samples: Optional[int] = None,
43
+ invalid_list: Optional[str] = None,
44
+ trans_norm_system: bool = True, # if True, transform all normals map into the cam system of front view
45
+ # augment_data: bool = False,
46
+ side_views_rate: float = 0.,
47
+ read_normal: bool = True,
48
+ read_color: bool = False,
49
+ read_depth: bool = False,
50
+ mix_color_normal: bool = False,
51
+ random_view_and_domain: bool = False,
52
+ load_cache: bool = False,
53
+ exten: str = '.png',
54
+ elevation_list: Optional[str] = None,
55
+ ) -> None:
56
+ """Create a dataset from a folder of images.
57
+ If you pass in a root directory it will be searched for images
58
+ ending in ext (ext can be a list)
59
+ """
60
+ self.root_dir = root_dir
61
+ self.fixed_views = int(360 // azi_interval)
62
+ self.bg_color = bg_color
63
+ self.validation = validation
64
+ self.num_samples = num_samples
65
+ self.trans_norm_system = trans_norm_system
66
+ # self.augment_data = augment_data
67
+ self.invalid_list = invalid_list
68
+ self.img_wh = img_wh
69
+ self.read_normal = read_normal
70
+ self.read_color = read_color
71
+ self.read_depth = read_depth
72
+ self.mix_color_normal = mix_color_normal # mix load color and normal maps
73
+ self.random_view_and_domain = random_view_and_domain # load normal or rgb of a single view
74
+ self.random_views = random_views
75
+ self.load_cache = load_cache
76
+ self.total_views = int(self.fixed_views * (self.random_views + 1))
77
+ self.predict_relative_views = predict_relative_views
78
+ self.pred_view_nums = len(self.predict_relative_views)
79
+ self.exten = exten
80
+ self.side_views_rate = side_views_rate
81
+
82
+ # ic(self.augment_data)
83
+ ic(self.total_views)
84
+ ic(self.fixed_views)
85
+ ic(self.predict_relative_views)
86
+
87
+ self.objects = []
88
+ if object_list is not None:
89
+ for dataset_list in object_list:
90
+ with open(dataset_list, 'r') as f:
91
+ # objects = f.readlines()
92
+ # objects = [o.strip() for o in objects]
93
+ objects = json.load(f)
94
+ self.objects.extend(objects)
95
+ else:
96
+ self.objects = os.listdir(self.root_dir)
97
+
98
+ # load fixed camera poses
99
+ self.trans_cv2gl_mat = np.linalg.inv(np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]))
100
+ self.fix_cam_poses = []
101
+ camera_path = os.path.join(self.root_dir, self.objects[0], 'camera')
102
+ for vid in range(0, self.total_views, self.random_views+1):
103
+ cam_info = np.load(f'{camera_path}/{vid:03d}.npy', allow_pickle=True).item()
104
+ assert cam_info['camera'] == 'ortho', 'Only support predict ortho camera !!!'
105
+ self.fix_cam_poses.append(cam_info['extrinsic'])
106
+ random.shuffle(self.objects)
107
+
108
+ # import pdb; pdb.set_trace()
109
+ invalid_objects = []
110
+ if self.invalid_list is not None:
111
+ for invalid_list in self.invalid_list:
112
+ if invalid_list[-4:] == '.txt':
113
+ with open(invalid_list, 'r') as f:
114
+ sub_invalid = f.readlines()
115
+ invalid_objects.extend([o.strip() for o in sub_invalid])
116
+ else:
117
+ with open(invalid_list) as f:
118
+ invalid_objects.extend(json.load(f))
119
+ self.invalid_objects = invalid_objects
120
+ ic(len(self.invalid_objects))
121
+
122
+ if elevation_list:
123
+ with open(elevation_list, 'r') as f:
124
+ ele_list = [o.strip() for o in f.readlines()]
125
+ self.objects = set(ele_list) & set(self.objects)
126
+
127
+ self.all_objects = set(self.objects) - (set(self.invalid_objects) & set(self.objects))
128
+ self.all_objects = list(self.all_objects)
129
+
130
+ self.validation = validation
131
+ if not validation:
132
+ self.all_objects = self.all_objects[:-num_validation_samples]
133
+ # print('Warning: you are fitting in small-scale dataset')
134
+ # self.all_objects = self.all_objects
135
+ else:
136
+ self.all_objects = self.all_objects[-num_validation_samples:]
137
+
138
+ if num_samples is not None:
139
+ self.all_objects = self.all_objects[:num_samples]
140
+ ic(len(self.all_objects))
141
+ print("loading ", len(self.all_objects), " objects in the dataset")
142
+
143
+ self.normal_prompt_embedding = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
144
+ self.color_prompt_embedding = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')
145
+
146
+ if self.mix_color_normal:
147
+ self.backup_data = self.__getitem_mix__(0, '8609cf7e67bf413487a7d94c73aeaa3e')
148
+ else:
149
+ self.backup_data = self.__getitem_norm__(0, '8609cf7e67bf413487a7d94c73aeaa3e')
150
+
151
+ def trans_cv2gl(self, rt):
152
+ r, t = rt[:3, :3], rt[:3, -1]
153
+ r = np.matmul(self.trans_cv2gl_mat, r)
154
+ t = np.matmul(self.trans_cv2gl_mat, t)
155
+ return np.concatenate([r, t[:, None]], axis=-1)
156
+
157
+ def get_bg_color(self):
158
+ if self.bg_color == 'white':
159
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
160
+ elif self.bg_color == 'black':
161
+ bg_color = np.array([0., 0., 0.], dtype=np.float32)
162
+ elif self.bg_color == 'gray':
163
+ bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
164
+ elif self.bg_color == 'random':
165
+ bg_color = np.random.rand(3)
166
+ elif self.bg_color == 'three_choices':
167
+ white = np.array([1., 1., 1.], dtype=np.float32)
168
+ black = np.array([0., 0., 0.], dtype=np.float32)
169
+ gray = np.array([0.5, 0.5, 0.5], dtype=np.float32)
170
+ bg_color = random.choice([white, black, gray])
171
+ elif isinstance(self.bg_color, float):
172
+ bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
173
+ else:
174
+ raise NotImplementedError
175
+ return bg_color
176
+
177
+
178
+ def load_image(self, img_path, bg_color, alpha=None, return_type='np'):
179
+ # not using cv2 as may load in uint16 format
180
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
181
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
182
+ # pil always returns uint8
183
+ rgba = np.array(Image.open(img_path).resize(self.img_wh))
184
+ rgba = rgba.astype(np.float32) / 255. # [0, 1]
185
+
186
+ img = rgba[..., :3]
187
+ if alpha is None:
188
+ assert rgba.shape[-1] == 4
189
+ alpha = rgba[..., 3:4]
190
+ assert alpha.sum() > 1e-8, 'w/o foreground'
191
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
192
+
193
+ if return_type == "np":
194
+ pass
195
+ elif return_type == "pt":
196
+ img = torch.from_numpy(img)
197
+ alpha = torch.from_numpy(alpha)
198
+ else:
199
+ raise NotImplementedError
200
+
201
+ return img, alpha
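A quick editor's illustration of the alpha-compositing convention used by `load_image` above (toy values, not data from the dataset):

import numpy as np
rgb = np.full((2, 2, 3), 0.2, dtype=np.float32)                 # foreground colour in [0, 1]
alpha = np.array([[1.0, 0.0], [0.5, 0.0]], dtype=np.float32)[..., None]
bg_color = np.array([1.0, 1.0, 1.0], dtype=np.float32)          # the 'white' option from get_bg_color()
composite = rgb * alpha + bg_color * (1.0 - alpha)               # fully transparent pixels become pure background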
202
+
203
+ def load_depth(self, img_path, bg_color, alpha, input_type='png', return_type='np'):
204
+ # not using cv2, as it may load the image in uint16 format
205
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
206
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
207
+ # PIL always returns uint8
208
+ img = np.array(Image.open(img_path).resize(self.img_wh))
209
+ img = img.astype(np.float32) / 65535. # [0, 1]
210
+
211
+ img[img > 0.4] = 0
212
+ img = img / 0.4
213
+
214
+ assert img.ndim == 2 # depth
215
+ img = np.stack([img]*3, axis=-1)
216
+
217
+ if alpha.shape[-1] != 1:
218
+ alpha = alpha[:, :, None]
219
+
220
+ # print(np.max(img[:, :, 0]))
221
+
222
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
223
+
224
+ if return_type == "np":
225
+ pass
226
+ elif return_type == "pt":
227
+ img = torch.from_numpy(img)
228
+ else:
229
+ raise NotImplementedError
230
+
231
+ return img
232
+
233
+ def load_normal(self, img_path, bg_color, alpha, RT_w2c_cond=None, return_type='np'):
234
+ normal_np = np.array(Image.open(img_path).resize(self.img_wh))[:, :, :3]
235
+ assert np.var(normal_np) > 1e-8, 'pure normal'
236
+ normal_cv = img2normal(normal_np)
237
+
238
+ normal_relative_cv = worldNormal2camNormal(RT_w2c_cond[:3, :3], normal_cv)
239
+ normal_relative_cv = norm_normalize(normal_relative_cv)
240
+ # normal_relative_gl = normal_relative_cv[..., [ 0, 2, 1]]
241
+ # normal_relative_gl[..., 2] = -normal_relative_gl[..., 2]
242
+ normal_relative_gl = normal_relative_cv
243
+ normal_relative_gl[..., 1:] = -normal_relative_gl[..., 1:]
244
+
245
+ img = (normal_relative_cv*0.5 + 0.5).astype(np.float32) # [0, 1]
246
+
247
+ if alpha.shape[-1] != 1:
248
+ alpha = alpha[:, :, None]
249
+
250
+
251
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
252
+
253
+ if return_type == "np":
254
+ pass
255
+ elif return_type == "pt":
256
+ img = torch.from_numpy(img)
257
+ else:
258
+ raise NotImplementedError
259
+
260
+ return img
261
+
262
+ def __len__(self):
263
+ return len(self.all_objects)
264
+
265
+ def __getitem_norm__(self, index, debug_object=None):
266
+ # get the bg color
267
+ bg_color = self.get_bg_color()
268
+ if debug_object is not None:
269
+ object_name = debug_object
270
+ else:
271
+ object_name = self.all_objects[index % len(self.all_objects)]
272
+
273
+ if self.validation:
274
+ cond_ele0_idx = 12
275
+ else:
276
+ rand = random.random()
277
+ if rand < self.side_views_rate: # 0.1
278
+ cond_ele0_idx = random.sample([8, 0], 1)[0]
279
+ elif rand < 3 * self.side_views_rate: # 0.3
280
+ cond_ele0_idx = random.sample([10, 14], 1)[0]
281
+ else:
282
+ cond_ele0_idx = 12 # front view
283
+ cond_random_idx = random.sample(range(self.random_views+1), 1)[0]
284
+
285
+ # condition info
286
+ cond_ele0_vid = cond_ele0_idx * (self.random_views + 1)
287
+ cond_vid = cond_ele0_vid + cond_random_idx
288
+ cond_ele0_w2c = self.fix_cam_poses[cond_ele0_idx]
289
+ cond_info = np.load(f'{self.root_dir}/{object_name}/camera/{cond_vid:03d}.npy', allow_pickle=True).item()
290
+ cond_type = cond_info['camera']
291
+ focal_len = cond_info['focal']
292
+
293
+ cond_eles = np.array([deg2rad(cond_info['elevation'])])
294
+
295
+ img_tensors_in = [
296
+ self.load_image(f"{self.root_dir}/{object_name}/image/{cond_vid:03d}{self.exten}", bg_color, return_type='pt')[0].permute(2, 0, 1)
297
+ ] * self.pred_view_nums
298
+
299
+ # output info
300
+ pred_vids = [(cond_ele0_vid + i * (self.random_views+1)) % self.total_views for i in self.predict_relative_views]
301
+ # pred_w2cs = [self.fix_cam_poses[(cond_ele0_idx + i) % self.fixed_views] for i in self.predict_relative_views]
302
+ img_tensors_out = []
303
+ normal_tensors_out = []
304
+ for i, vid in enumerate(pred_vids):
305
+ try:
306
+ img_tensor, alpha_ = self.load_image(f"{self.root_dir}/{object_name}/image/{vid:03d}{self.exten}", bg_color, return_type='pt')
307
+ except:
308
+ img_tensor, alpha_ = self.load_image(f"{self.root_dir}/{object_name}/image_relit/{vid:03d}{self.exten}", bg_color, return_type='pt')
309
+
310
+ img_tensor = img_tensor.permute(2, 0, 1) # (3, H, W)
311
+ img_tensors_out.append(img_tensor)
312
+
313
+
314
+ normal_tensor = self.load_normal(f"{self.root_dir}/{object_name}/normal/{vid:03d}{self.exten}", bg_color, alpha_.numpy(), RT_w2c_cond=cond_ele0_w2c[:3, :], return_type="pt").permute(2, 0, 1)
315
+ normal_tensors_out.append(normal_tensor)
316
+
317
+
318
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
319
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
320
+ normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W)
321
+
322
+ elevations_cond = torch.as_tensor(cond_eles).float()
323
+ if cond_type == 'ortho':
324
+ focal_embed = torch.tensor([0.])
325
+ else:
326
+ focal_embed = torch.tensor([24./focal_len])
327
+
328
+
329
+ if not self.load_cache:
330
+ return {
331
+ 'elevations_cond': elevations_cond,
332
+ 'focal_cond': focal_embed,
333
+ 'id': object_name,
334
+ 'vid':cond_vid,
335
+ 'imgs_in': img_tensors_in,
336
+ 'imgs_out': img_tensors_out,
337
+ 'normals_out': normal_tensors_out,
338
+ 'normal_prompt_embeddings': self.normal_prompt_embedding,
339
+ 'color_prompt_embeddings': self.color_prompt_embedding
340
+ }
341
+
342
+
343
+
344
+ def __getitem__(self, index):
345
+ try:
346
+ return self.__getitem_norm__(index)
347
+ except:
348
+ print("load error ", self.all_objects[index%len(self.all_objects)] )
349
+ return self.backup_data
350
+
351
+
352
+
353
+
354
+
355
+
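Editor's note: a hedged sketch of how a batch from the training dataset above might be consumed; the constructor call is a placeholder inferred from the attributes used in this file (root_dir, bg_color, img_wh, prompt_embeds_path, ...), not a verified signature.

from torch.utils.data import DataLoader
# dataset = <the dataset class defined in dreamdata.py>(
#     root_dir='/path/to/renders', bg_color='three_choices', img_wh=(512, 512),
#     validation=False, prompt_embeds_path='mvdiffusion/data/fixed_prompt_embeds_6view')
# loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
# batch = next(iter(loader))
# batch['imgs_in']      (B, Nv, 3, H, W)  conditioning image repeated for each predicted view
# batch['imgs_out']     (B, Nv, 3, H, W)  target colour renderings
# batch['normals_out']  (B, Nv, 3, H, W)  target normals, rotated into the conditioning camera frame
# batch['normal_prompt_embeddings'] / batch['color_prompt_embeddings']  fixed per-view text embeddings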
mvdiffusion/data/fixed_prompt_embeds_6view/clr_embeds.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9e51666588d0f075e031262744d371e12076160231aab19a531dbf7ab976e4d
3
+ size 946932
mvdiffusion/data/fixed_prompt_embeds_6view/normal_embeds.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53dfcd17f62fbfd8aeba60b1b05fa7559d72179738fd048e2ac1d53e5be5ed9d
3
+ size 946941
mvdiffusion/data/generate_fixed_text_embeds.py ADDED
@@ -0,0 +1,78 @@
1
+ from transformers import CLIPTokenizer, CLIPTextModel
2
+ import torch
3
+ import os
4
+
5
+ root = '/mnt/data/lipeng/'
6
+ pretrained_model_name_or_path = 'stabilityai/stable-diffusion-2-1-unclip'
7
+
8
+
9
+ weight_dtype = torch.float16
10
+ device = torch.device("cuda:0")
11
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
12
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
13
+ text_encoder = text_encoder.to(device, dtype=weight_dtype)
14
+
15
+ def generate_mv_embeds():
16
+ path = './fixed_prompt_embeds_8view'
17
+ os.makedirs(path, exist_ok=True)
18
+ views = ["front", "front_right", "right", "back_right", "back", " back_left", "left", "front_left"]
19
+ # views = ["front", "front_right", "right", "back", "left", "front_left"]
20
+ # views = ["front", "right", "back", "left"]
21
+ clr_prompt = [f"a rendering image of 3D models, {view} view, color map." for view in views]
22
+ normal_prompt = [f"a rendering image of 3D models, {view} view, normal map." for view in views]
23
+
24
+
25
+ for id, text_prompt in enumerate([clr_prompt, normal_prompt]):
26
+ print(text_prompt)
27
+ text_inputs = tokenizer(text_prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt").to(device)
28
+ text_input_ids = text_inputs.input_ids
29
+ untruncated_ids = tokenizer(text_prompt, padding="longest", return_tensors="pt").input_ids
30
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
31
+ text_input_ids, untruncated_ids):
32
+ removed_text = tokenizer.batch_decode(
33
+ untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
34
+ )
35
+ if hasattr(text_encoder.config, "use_attention_mask") and text_encoder.config.use_attention_mask:
36
+ attention_mask = text_inputs.attention_mask.to(device)
37
+ else:
38
+ attention_mask = None
39
+ prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=attention_mask,)
40
+ prompt_embeds = prompt_embeds[0].detach().cpu()
41
+ print(prompt_embeds.shape)
42
+
43
+
44
+ # print(prompt_embeds.dtype)
45
+ if id == 0:
46
+ torch.save(prompt_embeds, f'./{path}/clr_embeds.pt')
47
+ else:
48
+ torch.save(prompt_embeds, f'./{path}/normal_embeds.pt')
49
+ print('done')
50
+
51
+
52
+ def generate_img_embeds():
53
+ path = './fixed_prompt_embeds_persp2ortho'
54
+ os.makedirs(path, exist_ok=True)
55
+ text_prompt = ["a orthogonal renderining image of 3D models"]
56
+ print(text_prompt)
57
+ text_inputs = tokenizer(text_prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt").to(device)
58
+ text_input_ids = text_inputs.input_ids
59
+ untruncated_ids = tokenizer(text_prompt, padding="longest", return_tensors="pt").input_ids
60
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
61
+ text_input_ids, untruncated_ids):
62
+ removed_text = tokenizer.batch_decode(
63
+ untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
64
+ )
65
+ if hasattr(text_encoder.config, "use_attention_mask") and text_encoder.config.use_attention_mask:
66
+ attention_mask = text_inputs.attention_mask.to(device)
67
+ else:
68
+ attention_mask = None
69
+ prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=attention_mask,)
70
+ prompt_embeds = prompt_embeds[0].detach().cpu()
71
+ print(prompt_embeds.shape)
72
+
73
+ # print(prompt_embeds.dtype)
74
+
75
+ torch.save(prompt_embeds, f'./{path}/embeds.pt')
76
+ print('done')
77
+
78
+ generate_img_embeds()
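Editor's note: a small check of the artefacts written by this script; the shapes are stated as expectations for the SD2.1-unclip CLIP text encoder (77 tokens, hidden size 1024), and the 6-view files shipped in fixed_prompt_embeds_6view/ are presumably produced by the generate_mv_embeds() path with the six-view prompt list.

import torch
clr = torch.load('mvdiffusion/data/fixed_prompt_embeds_6view/clr_embeds.pt')
nrm = torch.load('mvdiffusion/data/fixed_prompt_embeds_6view/normal_embeds.pt')
print(clr.shape, nrm.shape)   # expected: one 77 x 1024 embedding per view, e.g. torch.Size([6, 77, 1024])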
mvdiffusion/data/normal_utils.py ADDED
@@ -0,0 +1,78 @@
1
+ import numpy as np
2
+ def deg2rad(deg):
3
+ return deg*np.pi/180
4
+
5
+ def inv_RT(RT):
6
+ # RT_h = np.concatenate([RT, np.array([[0,0,0,1]])], axis=0)
7
+ RT_inv = np.linalg.inv(RT)
8
+
9
+ return RT_inv[:3, :]
10
+ def camNormal2worldNormal(rot_c2w, camNormal):
11
+ H,W,_ = camNormal.shape
12
+ normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
13
+
14
+ return normal_img
15
+
16
+ def worldNormal2camNormal(rot_w2c, normal_map_world):
17
+ H,W,_ = normal_map_world.shape
18
+ # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
19
+
20
+ # faster version
21
+ # Reshape the normal map into a 2D array where each row represents a normal vector
22
+ normal_map_flat = normal_map_world.reshape(-1, 3)
23
+
24
+ # Transform the normal vectors using the transformation matrix
25
+ normal_map_camera_flat = np.dot(normal_map_flat, rot_w2c.T)
26
+
27
+ # Reshape the transformed normal map back to its original shape
28
+ normal_map_camera = normal_map_camera_flat.reshape(normal_map_world.shape)
29
+
30
+ return normal_map_camera
31
+
32
+ def trans_normal(normal, RT_w2c, RT_w2c_target):
33
+
34
+ # normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal)
35
+ # normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world)
36
+
37
+ relative_RT = np.matmul(RT_w2c_target[:3,:3], np.linalg.inv(RT_w2c[:3,:3]))
38
+ return worldNormal2camNormal(relative_RT[:3,:3], normal)
39
+
40
+ def trans_normal_complex(normal, RT_w2c, RT_w2c_rela_to_cond):
41
+ # camview -> world -> condview
42
+ normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal)
43
+ # debug_normal_world = normal2img(normal_world)
44
+
45
+ # relative_RT = np.matmul(RT_w2c_rela_to_cond[:3,:3], np.linalg.inv(RT_w2c[:3,:3]))
46
+ normal_target_cam = worldNormal2camNormal(RT_w2c_rela_to_cond[:3,:3], normal_world)
47
+ # normal_condview = normal2img(normal_target_cam)
48
+ return normal_target_cam
49
+ def img2normal(img):
50
+ return (img/255.)*2-1
51
+
52
+ def normal2img(normal):
53
+ return np.uint8((normal*0.5+0.5)*255)
54
+
55
+ def norm_normalize(normal, dim=-1):
56
+
57
+ normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6)
58
+
59
+ return normal
60
+
61
+ def plot_grid_images(images, row, col, path=None):
62
+ import cv2
63
+ """
64
+ Args:
65
+ images: torch.Tensor [B, H, W, 3] (moved to CPU and converted to numpy inside the function)
66
+ row:
67
+ col:
68
+ save_path:
69
+
70
+ Returns:
71
+
72
+ """
73
+ images = images.detach().cpu().numpy()
74
+ assert row * col == images.shape[0]
75
+ images = np.vstack([np.hstack(images[r * col:(r + 1) * col]) for r in range(row)])
76
+ if path:
77
+ cv2.imwrite(path, images[:,:,::-1] * 255)
78
+ return images
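Editor's note: a toy round-trip sanity check for the world/camera normal conversions above; the import path is an assumption about how the package resolves from the repo root, and the rotation is random test data.

import numpy as np
from mvdiffusion.data.normal_utils import worldNormal2camNormal, camNormal2worldNormal, norm_normalize

rot_w2c, _ = np.linalg.qr(np.random.default_rng(0).normal(size=(3, 3)))    # toy orthonormal world-to-camera rotation
n_world = norm_normalize(np.random.default_rng(1).normal(size=(4, 4, 3)))  # toy unit normal map
n_cam = worldNormal2camNormal(rot_w2c, n_world)
assert np.allclose(camNormal2worldNormal(rot_w2c.T, n_cam), n_world, atol=1e-6)  # rot_c2w is rot_w2c.T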
mvdiffusion/data/single_image_dataset.py ADDED
@@ -0,0 +1,249 @@
1
+ from typing import Dict
2
+ import numpy as np
3
+ from omegaconf import DictConfig, ListConfig
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from pathlib import Path
7
+ import json
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+ from typing import Literal, Tuple, Optional, Any
12
+ import cv2
13
+ import random
14
+
15
+ import json
16
+ import os, sys
17
+ import math
18
+
19
+ from glob import glob
20
+
21
+ import PIL.Image
22
+ from .normal_utils import trans_normal, normal2img, img2normal
23
+ import pdb
24
+ from icecream import ic
25
+
26
+ import cv2
27
+ import numpy as np
28
+
29
+ def add_margin(pil_img, color=0, size=256):
30
+ width, height = pil_img.size
31
+ result = Image.new(pil_img.mode, (size, size), color)
32
+ result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
33
+ return result
34
+
35
+ def scale_and_place_object(image, scale_factor):
36
+ assert np.shape(image)[-1]==4 # RGBA
37
+
38
+ # Extract the alpha channel (transparency) and the object (RGB channels)
39
+ alpha_channel = image[:, :, 3]
40
+
41
+ # Find the bounding box coordinates of the object
42
+ coords = cv2.findNonZero(alpha_channel)
43
+ x, y, width, height = cv2.boundingRect(coords)
44
+
45
+ # Calculate the scale factor for resizing
46
+ original_height, original_width = image.shape[:2]
47
+
48
+ if width > height:
49
+ size = width
50
+ original_size = original_width
51
+ else:
52
+ size = height
53
+ original_size = original_height
54
+
55
+ scale_factor = min(scale_factor, size / (original_size+0.0))
56
+
57
+ new_size = scale_factor * original_size
58
+ scale_factor = new_size / size
59
+
60
+ # Calculate the new size based on the scale factor
61
+ new_width = int(width * scale_factor)
62
+ new_height = int(height * scale_factor)
63
+
64
+ center_x = original_width // 2
65
+ center_y = original_height // 2
66
+
67
+ paste_x = center_x - (new_width // 2)
68
+ paste_y = center_y - (new_height // 2)
69
+
70
+ # Resize the object (RGB channels) to the new size
71
+ rescaled_object = cv2.resize(image[y:y+height, x:x+width], (new_width, new_height))
72
+
73
+ # Create a new RGBA image with the resized image
74
+ new_image = np.zeros((original_height, original_width, 4), dtype=np.uint8)
75
+
76
+ new_image[paste_y:paste_y + new_height, paste_x:paste_x + new_width] = rescaled_object
77
+
78
+ return new_image
79
+
80
+ class SingleImageDataset(Dataset):
81
+ def __init__(self,
82
+ root_dir: str,
83
+ num_views: int,
84
+ img_wh: Tuple[int, int],
85
+ bg_color: str,
86
+ crop_size: int = 224,
87
+ single_image: Optional[PIL.Image.Image] = None,
88
+ num_validation_samples: Optional[int] = None,
89
+ filepaths: Optional[list] = None,
90
+ cond_type: Optional[str] = None,
91
+ prompt_embeds_path: Optional[str] = None,
92
+ gt_path: Optional[str] = None
93
+ ) -> None:
94
+ """Create a dataset from a folder of images.
95
+ If you pass in a root directory it will be searched for images
96
+ ending in ext (ext can be a list)
97
+ """
98
+ self.root_dir = root_dir
99
+ self.num_views = num_views
100
+ self.img_wh = img_wh
101
+ self.crop_size = crop_size
102
+ self.bg_color = bg_color
103
+ self.cond_type = cond_type
104
+ self.gt_path = gt_path
105
+
106
+
107
+ if single_image is None:
108
+ if filepaths is None:
109
+ # Get a list of all files in the directory
110
+ file_list = os.listdir(self.root_dir)
111
+ else:
112
+ file_list = filepaths
113
+
114
+ # Filter the files that end with .png or .jpg
115
+ self.file_list = [file for file in file_list if file.endswith(('.png', '.jpg', '.webp'))]
116
+ else:
117
+ self.file_list = None
118
+
119
+ # load all images
120
+ self.all_images = []
121
+ self.all_alphas = []
122
+ bg_color = self.get_bg_color()
123
+
124
+ if single_image is not None:
125
+ image, alpha = self.load_image(None, bg_color, return_type='pt', Imagefile=single_image)
126
+ self.all_images.append(image)
127
+ self.all_alphas.append(alpha)
128
+ else:
129
+ for file in self.file_list:
130
+ print(os.path.join(self.root_dir, file))
131
+ image, alpha = self.load_image(os.path.join(self.root_dir, file), bg_color, return_type='pt')
132
+ self.all_images.append(image)
133
+ self.all_alphas.append(alpha)
134
+
135
+
136
+
137
+ self.all_images = self.all_images[:num_validation_samples]
138
+ self.all_alphas = self.all_alphas[:num_validation_samples]
139
+ ic(len(self.all_images))
140
+
141
+ try:
142
+ self.normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
143
+ self.color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt') # 4view
144
+ except:
145
+ self.color_text_embeds = torch.load(f'{prompt_embeds_path}/embeds.pt')
146
+ self.normal_text_embeds = None
147
+
148
+ def __len__(self):
149
+ return len(self.all_images)
150
+
151
+ def get_bg_color(self):
152
+ if self.bg_color == 'white':
153
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
154
+ elif self.bg_color == 'black':
155
+ bg_color = np.array([0., 0., 0.], dtype=np.float32)
156
+ elif self.bg_color == 'gray':
157
+ bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
158
+ elif self.bg_color == 'random':
159
+ bg_color = np.random.rand(3)
160
+ elif isinstance(self.bg_color, float):
161
+ bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
162
+ else:
163
+ raise NotImplementedError
164
+ return bg_color
165
+
166
+
167
+ def load_image(self, img_path, bg_color, return_type='np', Imagefile=None):
168
+ # PIL always returns uint8
169
+ if Imagefile is None:
170
+ image_input = Image.open(img_path)
171
+ else:
172
+ image_input = Imagefile
173
+ image_size = self.img_wh[0]
174
+
175
+ if self.crop_size!=-1:
176
+ alpha_np = np.asarray(image_input)[:, :, 3]
177
+ coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
178
+ min_x, min_y = np.min(coords, 0)
179
+ max_x, max_y = np.max(coords, 0)
180
+ ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
181
+ h, w = ref_img_.height, ref_img_.width
182
+ scale = self.crop_size / max(h, w)
183
+ h_, w_ = int(scale * h), int(scale * w)
184
+ ref_img_ = ref_img_.resize((w_, h_))
185
+ image_input = add_margin(ref_img_, size=image_size)
186
+ else:
187
+ image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
188
+ image_input = image_input.resize((image_size, image_size))
189
+
190
+ # img = scale_and_place_object(img, self.scale_ratio)
191
+ img = np.array(image_input)
192
+ img = img.astype(np.float32) / 255. # [0, 1]
193
+ assert img.shape[-1] == 4 # RGBA
194
+
195
+ alpha = img[...,3:4]
196
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
197
+
198
+ if return_type == "np":
199
+ pass
200
+ elif return_type == "pt":
201
+ img = torch.from_numpy(img)
202
+ alpha = torch.from_numpy(alpha)
203
+ else:
204
+ raise NotImplementedError
205
+
206
+ return img, alpha
207
+
208
+
209
+ def __getitem__(self, index):
210
+ image = self.all_images[index%len(self.all_images)]
211
+ alpha = self.all_alphas[index%len(self.all_images)]
212
+ if self.file_list is not None:
213
+ filename = self.file_list[index%len(self.all_images)].replace(".png", "")
214
+ else:
215
+ filename = 'null'
216
+ img_tensors_in = [
217
+ image.permute(2, 0, 1)
218
+ ] * self.num_views
219
+
220
+ alpha_tensors_in = [
221
+ alpha.permute(2, 0, 1)
222
+ ] * self.num_views
223
+
224
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
225
+ alpha_tensors_in = torch.stack(alpha_tensors_in, dim=0).float() # (Nv, 3, H, W)
226
+
227
+ if self.gt_path is not None:
228
+ gt_image = self.gt_images[index%len(self.all_images)]
229
+ gt_alpha = self.gt_alpha[index%len(self.all_images)]
230
+ gt_img_tensors_in = [gt_image.permute(2, 0, 1) ] * self.num_views
231
+ gt_alpha_tensors_in = [gt_alpha.permute(2, 0, 1) ] * self.num_views
232
+ gt_img_tensors_in = torch.stack(gt_img_tensors_in, dim=0).float()
233
+ gt_alpha_tensors_in = torch.stack(gt_alpha_tensors_in, dim=0).float()
234
+
235
+ normal_prompt_embeddings = self.normal_text_embeds if hasattr(self, 'normal_text_embeds') else None
236
+ color_prompt_embeddings = self.color_text_embeds if hasattr(self, 'color_text_embeds') else None
237
+
238
+ out = {
239
+ 'imgs_in': img_tensors_in.unsqueeze(0),
240
+ 'alphas': alpha_tensors_in.unsqueeze(0),
241
+ 'normal_prompt_embeddings': normal_prompt_embeddings.unsqueeze(0),
242
+ 'color_prompt_embeddings': color_prompt_embeddings.unsqueeze(0),
243
+ 'filename': filename,
244
+ }
245
+
246
+ return out
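Editor's note: illustrative use of SingleImageDataset at inference time; the argument values below are placeholders, not values read from this repo's configs.

dataset = SingleImageDataset(
    root_dir='examples', num_views=6, img_wh=(512, 512), bg_color='white',
    crop_size=420, prompt_embeds_path='mvdiffusion/data/fixed_prompt_embeds_6view')
sample = dataset[0]
# sample['imgs_in']                 -> (1, num_views, 3, H, W): the recentred, background-composited input repeated per view
# sample['color_prompt_embeddings'] -> likely (1, num_views, 77, 1024): the fixed per-view prompt embeddings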
247
+
248
+
249
+
mvdiffusion/models/transformer_mv2d_image.py ADDED
@@ -0,0 +1,1029 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
25
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
26
+ from diffusers.models.embeddings import PatchEmbed
27
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.utils.import_utils import is_xformers_available
30
+
31
+ from einops import rearrange, repeat
32
+ import pdb
33
+ import random
34
+
35
+
36
+ if is_xformers_available():
37
+ import xformers
38
+ import xformers.ops
39
+ else:
40
+ xformers = None
41
+
42
+ def my_repeat(tensor, num_repeats):
43
+ """
44
+ Repeat a tensor along a given dimension
45
+ """
46
+ if len(tensor.shape) == 3:
47
+ return repeat(tensor, "b d c -> (b v) d c", v=num_repeats)
48
+ elif len(tensor.shape) == 4:
49
+ return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats)
50
+
51
+
52
+ @dataclass
53
+ class TransformerMV2DModelOutput(BaseOutput):
54
+ """
55
+ The output of [`Transformer2DModel`].
56
+
57
+ Args:
58
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
59
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
60
+ distributions for the unnoised latent pixels.
61
+ """
62
+
63
+ sample: torch.FloatTensor
64
+
65
+
66
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
67
+ """
68
+ A 2D Transformer model for image-like data.
69
+
70
+ Parameters:
71
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
72
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
73
+ in_channels (`int`, *optional*):
74
+ The number of channels in the input and output (specify if the input is **continuous**).
75
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
76
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
77
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
78
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
79
+ This is fixed during training since it is used to learn a number of position embeddings.
80
+ num_vector_embeds (`int`, *optional*):
81
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
82
+ Includes the class for the masked latent pixel.
83
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
84
+ num_embeds_ada_norm ( `int`, *optional*):
85
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
86
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
87
+ added to the hidden states.
88
+
89
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
90
+ attention_bias (`bool`, *optional*):
91
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
92
+ """
93
+
94
+ @register_to_config
95
+ def __init__(
96
+ self,
97
+ num_attention_heads: int = 16,
98
+ attention_head_dim: int = 88,
99
+ in_channels: Optional[int] = None,
100
+ out_channels: Optional[int] = None,
101
+ num_layers: int = 1,
102
+ dropout: float = 0.0,
103
+ norm_num_groups: int = 32,
104
+ cross_attention_dim: Optional[int] = None,
105
+ attention_bias: bool = False,
106
+ sample_size: Optional[int] = None,
107
+ num_vector_embeds: Optional[int] = None,
108
+ patch_size: Optional[int] = None,
109
+ activation_fn: str = "geglu",
110
+ num_embeds_ada_norm: Optional[int] = None,
111
+ use_linear_projection: bool = False,
112
+ only_cross_attention: bool = False,
113
+ upcast_attention: bool = False,
114
+ norm_type: str = "layer_norm",
115
+ norm_elementwise_affine: bool = True,
116
+ num_views: int = 1,
117
+ cd_attention_last: bool=False,
118
+ cd_attention_mid: bool=False,
119
+ multiview_attention: bool=True,
120
+ sparse_mv_attention: bool = False,
121
+ mvcd_attention: bool=False
122
+ ):
123
+ super().__init__()
124
+ self.use_linear_projection = use_linear_projection
125
+ self.num_attention_heads = num_attention_heads
126
+ self.attention_head_dim = attention_head_dim
127
+ inner_dim = num_attention_heads * attention_head_dim
128
+
129
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
130
+ # Define whether input is continuous or discrete depending on configuration
131
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
132
+ self.is_input_vectorized = num_vector_embeds is not None
133
+ self.is_input_patches = in_channels is not None and patch_size is not None
134
+
135
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
136
+ deprecation_message = (
137
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
138
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
139
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
140
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
141
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
142
+ )
143
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
144
+ norm_type = "ada_norm"
145
+
146
+ if self.is_input_continuous and self.is_input_vectorized:
147
+ raise ValueError(
148
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
149
+ " sure that either `in_channels` or `num_vector_embeds` is None."
150
+ )
151
+ elif self.is_input_vectorized and self.is_input_patches:
152
+ raise ValueError(
153
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
154
+ " sure that either `num_vector_embeds` or `num_patches` is None."
155
+ )
156
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
157
+ raise ValueError(
158
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
159
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
160
+ )
161
+
162
+ # 2. Define input layers
163
+ if self.is_input_continuous:
164
+ self.in_channels = in_channels
165
+
166
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
167
+ if use_linear_projection:
168
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
169
+ else:
170
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
171
+ elif self.is_input_vectorized:
172
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
173
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
174
+
175
+ self.height = sample_size
176
+ self.width = sample_size
177
+ self.num_vector_embeds = num_vector_embeds
178
+ self.num_latent_pixels = self.height * self.width
179
+
180
+ self.latent_image_embedding = ImagePositionalEmbeddings(
181
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
182
+ )
183
+ elif self.is_input_patches:
184
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
185
+
186
+ self.height = sample_size
187
+ self.width = sample_size
188
+
189
+ self.patch_size = patch_size
190
+ self.pos_embed = PatchEmbed(
191
+ height=sample_size,
192
+ width=sample_size,
193
+ patch_size=patch_size,
194
+ in_channels=in_channels,
195
+ embed_dim=inner_dim,
196
+ )
197
+
198
+ # 3. Define transformers blocks
199
+ self.transformer_blocks = nn.ModuleList(
200
+ [
201
+ BasicMVTransformerBlock(
202
+ inner_dim,
203
+ num_attention_heads,
204
+ attention_head_dim,
205
+ dropout=dropout,
206
+ cross_attention_dim=cross_attention_dim,
207
+ activation_fn=activation_fn,
208
+ num_embeds_ada_norm=num_embeds_ada_norm,
209
+ attention_bias=attention_bias,
210
+ only_cross_attention=only_cross_attention,
211
+ upcast_attention=upcast_attention,
212
+ norm_type=norm_type,
213
+ norm_elementwise_affine=norm_elementwise_affine,
214
+ num_views=num_views,
215
+ cd_attention_last=cd_attention_last,
216
+ cd_attention_mid=cd_attention_mid,
217
+ multiview_attention=multiview_attention,
218
+ sparse_mv_attention=sparse_mv_attention,
219
+ mvcd_attention=mvcd_attention
220
+ )
221
+ for d in range(num_layers)
222
+ ]
223
+ )
224
+
225
+ # 4. Define output layers
226
+ self.out_channels = in_channels if out_channels is None else out_channels
227
+ if self.is_input_continuous:
228
+ # TODO: should use out_channels for continuous projections
229
+ if use_linear_projection:
230
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
231
+ else:
232
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
233
+ elif self.is_input_vectorized:
234
+ self.norm_out = nn.LayerNorm(inner_dim)
235
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
236
+ elif self.is_input_patches:
237
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
238
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
239
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
240
+
241
+ def forward(
242
+ self,
243
+ hidden_states: torch.Tensor,
244
+ encoder_hidden_states: Optional[torch.Tensor] = None,
245
+ timestep: Optional[torch.LongTensor] = None,
246
+ class_labels: Optional[torch.LongTensor] = None,
247
+ cross_attention_kwargs: Dict[str, Any] = None,
248
+ attention_mask: Optional[torch.Tensor] = None,
249
+ encoder_attention_mask: Optional[torch.Tensor] = None,
250
+ return_dict: bool = True,
251
+ ):
252
+ """
253
+ The [`Transformer2DModel`] forward method.
254
+
255
+ Args:
256
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
257
+ Input `hidden_states`.
258
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
259
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
260
+ self-attention.
261
+ timestep ( `torch.LongTensor`, *optional*):
262
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
263
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
264
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
265
+ `AdaLayerZeroNorm`.
266
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
267
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
268
+
269
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
270
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
271
+
272
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
273
+ above. This bias will be added to the cross-attention scores.
274
+ return_dict (`bool`, *optional*, defaults to `True`):
275
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
276
+ tuple.
277
+
278
+ Returns:
279
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
280
+ `tuple` where the first element is the sample tensor.
281
+ """
282
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
283
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
284
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
285
+ # expects mask of shape:
286
+ # [batch, key_tokens]
287
+ # adds singleton query_tokens dimension:
288
+ # [batch, 1, key_tokens]
289
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
290
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
291
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
292
+ if attention_mask is not None and attention_mask.ndim == 2:
293
+ # assume that mask is expressed as:
294
+ # (1 = keep, 0 = discard)
295
+ # convert mask into a bias that can be added to attention scores:
296
+ # (keep = +0, discard = -10000.0)
297
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
298
+ attention_mask = attention_mask.unsqueeze(1)
299
+
300
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
301
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
302
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
303
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
304
+
305
+ # 1. Input
306
+ if self.is_input_continuous:
307
+ batch, _, height, width = hidden_states.shape
308
+ residual = hidden_states
309
+
310
+ hidden_states = self.norm(hidden_states)
311
+ if not self.use_linear_projection:
312
+ hidden_states = self.proj_in(hidden_states)
313
+ inner_dim = hidden_states.shape[1]
314
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
315
+ else:
316
+ inner_dim = hidden_states.shape[1]
317
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
318
+ hidden_states = self.proj_in(hidden_states)
319
+ elif self.is_input_vectorized:
320
+ hidden_states = self.latent_image_embedding(hidden_states)
321
+ elif self.is_input_patches:
322
+ hidden_states = self.pos_embed(hidden_states)
323
+
324
+ # 2. Blocks
325
+ for block in self.transformer_blocks:
326
+ hidden_states = block(
327
+ hidden_states,
328
+ attention_mask=attention_mask,
329
+ encoder_hidden_states=encoder_hidden_states,
330
+ encoder_attention_mask=encoder_attention_mask,
331
+ timestep=timestep,
332
+ cross_attention_kwargs=cross_attention_kwargs,
333
+ class_labels=class_labels,
334
+ )
335
+
336
+ # 3. Output
337
+ if self.is_input_continuous:
338
+ if not self.use_linear_projection:
339
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
340
+ hidden_states = self.proj_out(hidden_states)
341
+ else:
342
+ hidden_states = self.proj_out(hidden_states)
343
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
344
+
345
+ output = hidden_states + residual
346
+ elif self.is_input_vectorized:
347
+ hidden_states = self.norm_out(hidden_states)
348
+ logits = self.out(hidden_states)
349
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
350
+ logits = logits.permute(0, 2, 1)
351
+
352
+ # log(p(x_0))
353
+ output = F.log_softmax(logits.double(), dim=1).float()
354
+ elif self.is_input_patches:
355
+ # TODO: cleanup!
356
+ conditioning = self.transformer_blocks[0].norm1.emb(
357
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
358
+ )
359
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
360
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
361
+ hidden_states = self.proj_out_2(hidden_states)
362
+
363
+ # unpatchify
364
+ height = width = int(hidden_states.shape[1] ** 0.5)
365
+ hidden_states = hidden_states.reshape(
366
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
367
+ )
368
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
369
+ output = hidden_states.reshape(
370
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
371
+ )
372
+
373
+ if not return_dict:
374
+ return (output,)
375
+
376
+ return TransformerMV2DModelOutput(sample=output)
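Editor's note: an illustrative construction of the multi-view transformer defined above; the head count and channel sizes are toy values, not the configuration of the released checkpoints.

mv_transformer = TransformerMV2DModel(
    num_attention_heads=8, attention_head_dim=40, in_channels=320,
    num_layers=1, cross_attention_dim=1024, num_views=6)
# Inside the UNet it receives latents of shape (batch * num_views, 320, h, w) together with
# per-view encoder_hidden_states, and its `.sample` output keeps the input's spatial shape.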
377
+
378
+
379
+ @maybe_allow_in_graph
380
+ class BasicMVTransformerBlock(nn.Module):
381
+ r"""
382
+ A basic Transformer block.
383
+
384
+ Parameters:
385
+ dim (`int`): The number of channels in the input and output.
386
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
387
+ attention_head_dim (`int`): The number of channels in each head.
388
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
389
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
390
+ only_cross_attention (`bool`, *optional*):
391
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
392
+ double_self_attention (`bool`, *optional*):
393
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
394
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
395
+ num_embeds_ada_norm (:
396
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
397
+ attention_bias (:
398
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
399
+ """
400
+
401
+ def __init__(
402
+ self,
403
+ dim: int,
404
+ num_attention_heads: int,
405
+ attention_head_dim: int,
406
+ dropout=0.0,
407
+ cross_attention_dim: Optional[int] = None,
408
+ activation_fn: str = "geglu",
409
+ num_embeds_ada_norm: Optional[int] = None,
410
+ attention_bias: bool = False,
411
+ only_cross_attention: bool = False,
412
+ double_self_attention: bool = False,
413
+ upcast_attention: bool = False,
414
+ norm_elementwise_affine: bool = True,
415
+ norm_type: str = "layer_norm",
416
+ final_dropout: bool = False,
417
+ num_views: int = 1,
418
+ cd_attention_last: bool = False,
419
+ cd_attention_mid: bool = False,
420
+ multiview_attention: bool = True,
421
+ sparse_mv_attention: bool = False,
422
+ mvcd_attention: bool = False
423
+ ):
424
+ super().__init__()
425
+ self.only_cross_attention = only_cross_attention
426
+
427
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
428
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
429
+
430
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
431
+ raise ValueError(
432
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
433
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
434
+ )
435
+
436
+ # Define 3 blocks. Each block has its own normalization layer.
437
+ # 1. Self-Attn
438
+ if self.use_ada_layer_norm:
439
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
440
+ elif self.use_ada_layer_norm_zero:
441
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
442
+ else:
443
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
444
+
445
+ self.multiview_attention = multiview_attention
446
+ self.sparse_mv_attention = sparse_mv_attention
447
+ self.mvcd_attention = mvcd_attention
448
+
449
+ self.attn1 = CustomAttention(
450
+ query_dim=dim,
451
+ heads=num_attention_heads,
452
+ dim_head=attention_head_dim,
453
+ dropout=dropout,
454
+ bias=attention_bias,
455
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
456
+ upcast_attention=upcast_attention,
457
+ processor=MVAttnProcessor()
458
+ )
459
+
460
+ # 2. Cross-Attn
461
+ if cross_attention_dim is not None or double_self_attention:
462
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
463
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
464
+ # the second cross attention block.
465
+ self.norm2 = (
466
+ AdaLayerNorm(dim, num_embeds_ada_norm)
467
+ if self.use_ada_layer_norm
468
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
469
+ )
470
+ self.attn2 = Attention(
471
+ query_dim=dim,
472
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
473
+ heads=num_attention_heads,
474
+ dim_head=attention_head_dim,
475
+ dropout=dropout,
476
+ bias=attention_bias,
477
+ upcast_attention=upcast_attention,
478
+ ) # is self-attn if encoder_hidden_states is none
479
+ else:
480
+ self.norm2 = None
481
+ self.attn2 = None
482
+
483
+ # 3. Feed-forward
484
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
485
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
486
+
487
+ # let chunk size default to None
488
+ self._chunk_size = None
489
+ self._chunk_dim = 0
490
+
491
+ self.num_views = num_views
492
+
493
+ self.cd_attention_last = cd_attention_last
494
+
495
+ if self.cd_attention_last:
496
+ # Joint task -Attn
497
+ self.attn_joint_last = CustomJointAttention(
498
+ query_dim=dim,
499
+ heads=num_attention_heads,
500
+ dim_head=attention_head_dim,
501
+ dropout=dropout,
502
+ bias=attention_bias,
503
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
504
+ upcast_attention=upcast_attention,
505
+ processor=JointAttnProcessor()
506
+ )
507
+ nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data)
508
+ self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
509
+
510
+
511
+ self.cd_attention_mid = cd_attention_mid
512
+
513
+ if self.cd_attention_mid:
514
+ print("cross-domain attn in the middle")
515
+ # Joint task -Attn
516
+ self.attn_joint_mid = CustomJointAttention(
517
+ query_dim=dim,
518
+ heads=num_attention_heads,
519
+ dim_head=attention_head_dim,
520
+ dropout=dropout,
521
+ bias=attention_bias,
522
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
523
+ upcast_attention=upcast_attention,
524
+ processor=JointAttnProcessor()
525
+ )
526
+ nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data)
527
+ self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
528
+
529
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
530
+ # Sets chunk feed-forward
531
+ self._chunk_size = chunk_size
532
+ self._chunk_dim = dim
533
+
534
+ def forward(
535
+ self,
536
+ hidden_states: torch.FloatTensor,
537
+ attention_mask: Optional[torch.FloatTensor] = None,
538
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
539
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
540
+ timestep: Optional[torch.LongTensor] = None,
541
+ cross_attention_kwargs: Dict[str, Any] = None,
542
+ class_labels: Optional[torch.LongTensor] = None,
543
+ ):
544
+ assert attention_mask is None # not supported yet
545
+ # Notice that normalization is always applied before the real computation in the following blocks.
546
+ # 1. Self-Attention
547
+ if self.use_ada_layer_norm:
548
+ norm_hidden_states = self.norm1(hidden_states, timestep)
549
+ elif self.use_ada_layer_norm_zero:
550
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
551
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
552
+ )
553
+ else:
554
+ norm_hidden_states = self.norm1(hidden_states)
555
+
556
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
557
+
558
+ attn_output = self.attn1(
559
+ norm_hidden_states,
560
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
561
+ attention_mask=attention_mask,
562
+ num_views=self.num_views,
563
+ multiview_attention=self.multiview_attention,
564
+ sparse_mv_attention=self.sparse_mv_attention,
565
+ mvcd_attention=self.mvcd_attention,
566
+ **cross_attention_kwargs,
567
+ )
568
+
569
+
570
+ if self.use_ada_layer_norm_zero:
571
+ attn_output = gate_msa.unsqueeze(1) * attn_output
572
+ hidden_states = attn_output + hidden_states
573
+
574
+ # joint attention twice
575
+ if self.cd_attention_mid:
576
+ norm_hidden_states = (
577
+ self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states)
578
+ )
579
+ hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states
580
+
581
+ # 2. Cross-Attention
582
+ if self.attn2 is not None:
583
+ norm_hidden_states = (
584
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
585
+ )
586
+
587
+ attn_output = self.attn2(
588
+ norm_hidden_states,
589
+ encoder_hidden_states=encoder_hidden_states,
590
+ attention_mask=encoder_attention_mask,
591
+ **cross_attention_kwargs,
592
+ )
593
+ hidden_states = attn_output + hidden_states
594
+
595
+ # 3. Feed-forward
596
+ norm_hidden_states = self.norm3(hidden_states)
597
+
598
+ if self.use_ada_layer_norm_zero:
599
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
600
+
601
+ if self._chunk_size is not None:
602
+ # "feed_forward_chunk_size" can be used to save memory
603
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
604
+ raise ValueError(
605
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
606
+ )
607
+
608
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
609
+ ff_output = torch.cat(
610
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
611
+ dim=self._chunk_dim,
612
+ )
613
+ else:
614
+ ff_output = self.ff(norm_hidden_states)
615
+
616
+ if self.use_ada_layer_norm_zero:
617
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
618
+
619
+ hidden_states = ff_output + hidden_states
620
+
621
+ if self.cd_attention_last:
622
+ norm_hidden_states = (
623
+ self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states)
624
+ )
625
+ hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states
626
+
627
+ return hidden_states
628
+
629
+
630
+ class CustomAttention(Attention):
631
+ def set_use_memory_efficient_attention_xformers(
632
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
633
+ ):
634
+ processor = XFormersMVAttnProcessor()
635
+ self.set_processor(processor)
636
+ # print("using xformers attention processor")
637
+
638
+
639
+ class CustomJointAttention(Attention):
640
+ def set_use_memory_efficient_attention_xformers(
641
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
642
+ ):
643
+ processor = XFormersJointAttnProcessor()
644
+ self.set_processor(processor)
645
+ # print("using xformers attention processor")
646
+
647
+ class MVAttnProcessor:
648
+ r"""
649
+ Multi-view attention processor: each view's queries attend to keys/values gathered from all `num_views` views of the same object.
650
+ """
651
+
652
+ def __call__(
653
+ self,
654
+ attn: Attention,
655
+ hidden_states,
656
+ encoder_hidden_states=None,
657
+ attention_mask=None,
658
+ temb=None,
659
+ num_views=1,
660
+ multiview_attention=True
661
+ ):
662
+ residual = hidden_states
663
+
664
+ if attn.spatial_norm is not None:
665
+ hidden_states = attn.spatial_norm(hidden_states, temb)
666
+
667
+ input_ndim = hidden_states.ndim
668
+
669
+ if input_ndim == 4:
670
+ batch_size, channel, height, width = hidden_states.shape
671
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
672
+
673
+ batch_size, sequence_length, _ = (
674
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
675
+ )
676
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
677
+
678
+ if attn.group_norm is not None:
679
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
680
+
681
+ query = attn.to_q(hidden_states)
682
+
683
+ if encoder_hidden_states is None:
684
+ encoder_hidden_states = hidden_states
685
+ elif attn.norm_cross:
686
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
687
+
688
+ key = attn.to_k(encoder_hidden_states)
689
+ value = attn.to_v(encoder_hidden_states)
690
+
691
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
692
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
693
+ # pdb.set_trace()
694
+ # multi-view self-attention
695
+ if multiview_attention:
696
+ if num_views <= 6:
697
+ # after use xformer; possible to train with 6 views
698
+ key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
699
+ value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
700
+ else:# apply sparse attention
701
+ pass
702
+ # print("use sparse attention")
703
+ # # seems that the sparse random sampling cause problems
704
+ # # don't use random sampling, just fix the indexes
705
+ # onekey = rearrange(key, "(b t) d c -> b t d c", t=num_views)
706
+ # onevalue = rearrange(value, "(b t) d c -> b t d c", t=num_views)
707
+ # allkeys = []
708
+ # allvalues = []
709
+ # all_indexes = {
710
+ # 0 : [0, 2, 3, 4],
711
+ # 1: [0, 1, 3, 5],
712
+ # 2: [0, 2, 3, 4],
713
+ # 3: [0, 2, 3, 4],
714
+ # 4: [0, 2, 3, 4],
715
+ # 5: [0, 1, 3, 5]
716
+ # }
717
+ # for jj in range(num_views):
718
+ # # valid_index = [x for x in range(0, num_views) if x!= jj]
719
+ # # indexes = random.sample(valid_index, 3) + [jj] + [0]
720
+ # indexes = all_indexes[jj]
721
+
722
+ # indexes = torch.tensor(indexes).long().to(key.device)
723
+ # allkeys.append(onekey[:, indexes])
724
+ # allvalues.append(onevalue[:, indexes])
725
+ # keys = torch.stack(allkeys, dim=1) # checked, should be dim=1
726
+ # values = torch.stack(allvalues, dim=1)
727
+ # key = rearrange(keys, 'b t f d c -> (b t) (f d) c')
728
+ # value = rearrange(values, 'b t f d c -> (b t) (f d) c')
729
+
730
+
731
+ query = attn.head_to_batch_dim(query).contiguous()
732
+ key = attn.head_to_batch_dim(key).contiguous()
733
+ value = attn.head_to_batch_dim(value).contiguous()
734
+
735
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
736
+ hidden_states = torch.bmm(attention_probs, value)
737
+ hidden_states = attn.batch_to_head_dim(hidden_states)
738
+
739
+ # linear proj
740
+ hidden_states = attn.to_out[0](hidden_states)
741
+ # dropout
742
+ hidden_states = attn.to_out[1](hidden_states)
743
+
744
+ if input_ndim == 4:
745
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
746
+
747
+ if attn.residual_connection:
748
+ hidden_states = hidden_states + residual
749
+
750
+ hidden_states = hidden_states / attn.rescale_output_factor
751
+
752
+ return hidden_states
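Editor's note: a toy demonstration of the dense multi-view key/value gathering used above, where each of the t views attends over the tokens of all t views of the same object (sizes are arbitrary).

import torch
from einops import rearrange

b, t, d, c = 2, 6, 1024, 320                    # objects, views, tokens per view, channels
key = torch.randn(b * t, d, c)                  # per-view keys stacked along the batch dimension
key_mv = rearrange(key, "(b t) d c -> b (t d) c", t=t).repeat_interleave(t, dim=0)
print(key_mv.shape)                             # torch.Size([12, 6144, 320])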
753
+
754
+
755
+ class XFormersMVAttnProcessor:
756
+ r"""
757
+ xFormers (memory-efficient) multi-view attention processor, with optional sparse multi-view and cross-domain (colour/normal) attention.
758
+ """
759
+
760
+ def __call__(
761
+ self,
762
+ attn: Attention,
763
+ hidden_states,
764
+ encoder_hidden_states=None,
765
+ attention_mask=None,
766
+ temb=None,
767
+ num_views=1.,
768
+ multiview_attention=True,
769
+ sparse_mv_attention=False,
770
+ mvcd_attention=False,
771
+ ):
772
+ residual = hidden_states
773
+
774
+ if attn.spatial_norm is not None:
775
+ hidden_states = attn.spatial_norm(hidden_states, temb)
776
+
777
+ input_ndim = hidden_states.ndim
778
+
779
+ if input_ndim == 4:
780
+ batch_size, channel, height, width = hidden_states.shape
781
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
782
+
783
+ batch_size, sequence_length, _ = (
784
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
785
+ )
786
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
787
+
788
+ # from yuancheng; here attention_mask is None
789
+ if attention_mask is not None:
790
+ # expand our mask's singleton query_tokens dimension:
791
+ # [batch*heads, 1, key_tokens] ->
792
+ # [batch*heads, query_tokens, key_tokens]
793
+ # so that it can be added as a bias onto the attention scores that xformers computes:
794
+ # [batch*heads, query_tokens, key_tokens]
795
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
796
+ _, query_tokens, _ = hidden_states.shape
797
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
798
+
799
+ if attn.group_norm is not None:
800
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
801
+
802
+ query = attn.to_q(hidden_states)
803
+
804
+ if encoder_hidden_states is None:
805
+ encoder_hidden_states = hidden_states
806
+ elif attn.norm_cross:
807
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
808
+
809
+ key_raw = attn.to_k(encoder_hidden_states)
810
+ value_raw = attn.to_v(encoder_hidden_states)
811
+
812
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
813
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
814
+ # pdb.set_trace()
815
+ # multi-view self-attention
816
+ if multiview_attention:
817
+ if not sparse_mv_attention:
818
+ key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
819
+ value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
820
+ else:
821
+ key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c]
822
+ value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
823
+ key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c
824
+ value = torch.cat([value_front, value_raw], dim=1)
825
+
826
+ if mvcd_attention:
827
+ # memory efficient, cross domain attention
828
+ key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c
829
+ value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
830
+ key_cross = torch.concat([key_1, key_0], dim=0)
831
+ value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c
832
+ key = torch.cat([key, key_cross], dim=1)
833
+ value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c
834
+ else:
835
+ # print("don't use multiview attention.")
836
+ key = key_raw
837
+ value = value_raw
838
+
839
+ query = attn.head_to_batch_dim(query)
840
+ key = attn.head_to_batch_dim(key)
841
+ value = attn.head_to_batch_dim(value)
842
+
843
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
844
+ hidden_states = attn.batch_to_head_dim(hidden_states)
845
+
846
+ # linear proj
847
+ hidden_states = attn.to_out[0](hidden_states)
848
+ # dropout
849
+ hidden_states = attn.to_out[1](hidden_states)
850
+
851
+ if input_ndim == 4:
852
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
853
+
854
+ if attn.residual_connection:
855
+ hidden_states = hidden_states + residual
856
+
857
+ hidden_states = hidden_states / attn.rescale_output_factor
858
+
859
+ return hidden_states
860
+
861
+
862
+
863
+ class XFormersJointAttnProcessor:
864
+ r"""
865
+ Default processor for performing attention-related computations.
866
+ """
867
+
868
+ def __call__(
869
+ self,
870
+ attn: Attention,
871
+ hidden_states,
872
+ encoder_hidden_states=None,
873
+ attention_mask=None,
874
+ temb=None,
875
+ num_tasks=2
876
+ ):
877
+
878
+ residual = hidden_states
879
+
880
+ if attn.spatial_norm is not None:
881
+ hidden_states = attn.spatial_norm(hidden_states, temb)
882
+
883
+ input_ndim = hidden_states.ndim
884
+
885
+ if input_ndim == 4:
886
+ batch_size, channel, height, width = hidden_states.shape
887
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
888
+
889
+ batch_size, sequence_length, _ = (
890
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
891
+ )
892
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
893
+
894
+ # from yuancheng; here attention_mask is None
895
+ if attention_mask is not None:
896
+ # expand our mask's singleton query_tokens dimension:
897
+ # [batch*heads, 1, key_tokens] ->
898
+ # [batch*heads, query_tokens, key_tokens]
899
+ # so that it can be added as a bias onto the attention scores that xformers computes:
900
+ # [batch*heads, query_tokens, key_tokens]
901
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
902
+ _, query_tokens, _ = hidden_states.shape
903
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
904
+
905
+ if attn.group_norm is not None:
906
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
907
+
908
+ query = attn.to_q(hidden_states)
909
+
910
+ if encoder_hidden_states is None:
911
+ encoder_hidden_states = hidden_states
912
+ elif attn.norm_cross:
913
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
914
+
915
+ key = attn.to_k(encoder_hidden_states)
916
+ value = attn.to_v(encoder_hidden_states)
917
+
918
+ assert num_tasks == 2 # only support two tasks now
919
+
920
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
921
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
922
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
923
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
924
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
925
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
926
+
927
+
928
+ query = attn.head_to_batch_dim(query).contiguous()
929
+ key = attn.head_to_batch_dim(key).contiguous()
930
+ value = attn.head_to_batch_dim(value).contiguous()
931
+
932
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
933
+ hidden_states = attn.batch_to_head_dim(hidden_states)
934
+
935
+ # linear proj
936
+ hidden_states = attn.to_out[0](hidden_states)
937
+ # dropout
938
+ hidden_states = attn.to_out[1](hidden_states)
939
+
940
+ if input_ndim == 4:
941
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
942
+
943
+ if attn.residual_connection:
944
+ hidden_states = hidden_states + residual
945
+
946
+ hidden_states = hidden_states / attn.rescale_output_factor
947
+
948
+ return hidden_states
949
+
950
+
951
+ class JointAttnProcessor:
952
+ r"""
953
+ Default processor for performing attention-related computations.
954
+ """
955
+
956
+ def __call__(
957
+ self,
958
+ attn: Attention,
959
+ hidden_states,
960
+ encoder_hidden_states=None,
961
+ attention_mask=None,
962
+ temb=None,
963
+ num_tasks=2
964
+ ):
965
+
966
+ residual = hidden_states
967
+
968
+ if attn.spatial_norm is not None:
969
+ hidden_states = attn.spatial_norm(hidden_states, temb)
970
+
971
+ input_ndim = hidden_states.ndim
972
+
973
+ if input_ndim == 4:
974
+ batch_size, channel, height, width = hidden_states.shape
975
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
976
+
977
+ batch_size, sequence_length, _ = (
978
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
979
+ )
980
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
981
+
982
+
983
+ if attn.group_norm is not None:
984
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
985
+
986
+ query = attn.to_q(hidden_states)
987
+
988
+ if encoder_hidden_states is None:
989
+ encoder_hidden_states = hidden_states
990
+ elif attn.norm_cross:
991
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
992
+
993
+ key = attn.to_k(encoder_hidden_states)
994
+ value = attn.to_v(encoder_hidden_states)
995
+
996
+ assert num_tasks == 2 # only support two tasks now
997
+
998
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
999
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
1000
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
1001
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
1002
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
1003
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
1004
+
1005
+
1006
+ query = attn.head_to_batch_dim(query).contiguous()
1007
+ key = attn.head_to_batch_dim(key).contiguous()
1008
+ value = attn.head_to_batch_dim(value).contiguous()
1009
+
1010
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
1011
+ hidden_states = torch.bmm(attention_probs, value)
1012
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1013
+
1014
+ # linear proj
1015
+ hidden_states = attn.to_out[0](hidden_states)
1016
+ # dropout
1017
+ hidden_states = attn.to_out[1](hidden_states)
1018
+
1019
+ if input_ndim == 4:
1020
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1021
+
1022
+ if attn.residual_connection:
1023
+ hidden_states = hidden_states + residual
1024
+
1025
+ hidden_states = hidden_states / attn.rescale_output_factor
1026
+
1027
+ return hidden_states
1028
+
1029
+
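A minimal sketch (not part of this commit) of the two-task key/value construction used by the XFormersJointAttnProcessor / JointAttnProcessor classes above. It assumes the batch holds two domains (e.g. color and normal) stacked along dim 0, whose tokens are merged so each domain attends to both:

import torch

bt, d, c = 12, 1024, 320                         # (batch * views) per domain, tokens, channels
key = torch.randn(2 * bt, d, c)                  # two domains stacked along the batch axis
key_0, key_1 = torch.chunk(key, chunks=2, dim=0)
key_joint = torch.cat([key_0, key_1], dim=1)     # (bt, 2*d, c): union of both domains' tokens
key_joint = torch.cat([key_joint] * 2, dim=0)    # (2*bt, 2*d, c): same union served to each domain
assert key_joint.shape == (2 * bt, 2 * d, c)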
mvdiffusion/models/transformer_mv2d_rowwise.py ADDED
@@ -0,0 +1,978 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
25
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
26
+ from diffusers.models.embeddings import PatchEmbed
27
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.utils.import_utils import is_xformers_available
30
+
31
+ from einops import rearrange
32
+ import pdb
33
+ import random
34
+ import math
35
+
36
+
37
+ if is_xformers_available():
38
+ import xformers
39
+ import xformers.ops
40
+ else:
41
+ xformers = None
42
+
43
+
44
+ @dataclass
45
+ class TransformerMV2DModelOutput(BaseOutput):
46
+ """
47
+ The output of [`Transformer2DModel`].
48
+
49
+ Args:
50
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
51
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
52
+ distributions for the unnoised latent pixels.
53
+ """
54
+
55
+ sample: torch.FloatTensor
56
+
57
+
58
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
59
+ """
60
+ A 2D Transformer model for image-like data.
61
+
62
+ Parameters:
63
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
64
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
65
+ in_channels (`int`, *optional*):
66
+ The number of channels in the input and output (specify if the input is **continuous**).
67
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
68
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
69
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
70
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
71
+ This is fixed during training since it is used to learn a number of position embeddings.
72
+ num_vector_embeds (`int`, *optional*):
73
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
74
+ Includes the class for the masked latent pixel.
75
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
76
+ num_embeds_ada_norm ( `int`, *optional*):
77
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
78
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
79
+ added to the hidden states.
80
+
81
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
82
+ attention_bias (`bool`, *optional*):
83
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
84
+ """
85
+
86
+ @register_to_config
87
+ def __init__(
88
+ self,
89
+ num_attention_heads: int = 16,
90
+ attention_head_dim: int = 88,
91
+ in_channels: Optional[int] = None,
92
+ out_channels: Optional[int] = None,
93
+ num_layers: int = 1,
94
+ dropout: float = 0.0,
95
+ norm_num_groups: int = 32,
96
+ cross_attention_dim: Optional[int] = None,
97
+ attention_bias: bool = False,
98
+ sample_size: Optional[int] = None,
99
+ num_vector_embeds: Optional[int] = None,
100
+ patch_size: Optional[int] = None,
101
+ activation_fn: str = "geglu",
102
+ num_embeds_ada_norm: Optional[int] = None,
103
+ use_linear_projection: bool = False,
104
+ only_cross_attention: bool = False,
105
+ upcast_attention: bool = False,
106
+ norm_type: str = "layer_norm",
107
+ norm_elementwise_affine: bool = True,
108
+ num_views: int = 1,
109
+ cd_attention_last: bool=False,
110
+ cd_attention_mid: bool=False,
111
+ multiview_attention: bool=True,
112
+ sparse_mv_attention: bool = True, # not used
113
+ mvcd_attention: bool=False
114
+ ):
115
+ super().__init__()
116
+ self.use_linear_projection = use_linear_projection
117
+ self.num_attention_heads = num_attention_heads
118
+ self.attention_head_dim = attention_head_dim
119
+ inner_dim = num_attention_heads * attention_head_dim
120
+
121
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
122
+ # Define whether input is continuous or discrete depending on configuration
123
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
124
+ self.is_input_vectorized = num_vector_embeds is not None
125
+ self.is_input_patches = in_channels is not None and patch_size is not None
126
+
127
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
128
+ deprecation_message = (
129
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
130
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
131
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
132
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
133
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
134
+ )
135
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
136
+ norm_type = "ada_norm"
137
+
138
+ if self.is_input_continuous and self.is_input_vectorized:
139
+ raise ValueError(
140
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
141
+ " sure that either `in_channels` or `num_vector_embeds` is None."
142
+ )
143
+ elif self.is_input_vectorized and self.is_input_patches:
144
+ raise ValueError(
145
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
146
+ " sure that either `num_vector_embeds` or `num_patches` is None."
147
+ )
148
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
149
+ raise ValueError(
150
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
151
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
152
+ )
153
+
154
+ # 2. Define input layers
155
+ if self.is_input_continuous:
156
+ self.in_channels = in_channels
157
+
158
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
159
+ if use_linear_projection:
160
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
161
+ else:
162
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
163
+ elif self.is_input_vectorized:
164
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
165
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
166
+
167
+ self.height = sample_size
168
+ self.width = sample_size
169
+ self.num_vector_embeds = num_vector_embeds
170
+ self.num_latent_pixels = self.height * self.width
171
+
172
+ self.latent_image_embedding = ImagePositionalEmbeddings(
173
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
174
+ )
175
+ elif self.is_input_patches:
176
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
177
+
178
+ self.height = sample_size
179
+ self.width = sample_size
180
+
181
+ self.patch_size = patch_size
182
+ self.pos_embed = PatchEmbed(
183
+ height=sample_size,
184
+ width=sample_size,
185
+ patch_size=patch_size,
186
+ in_channels=in_channels,
187
+ embed_dim=inner_dim,
188
+ )
189
+
190
+ # 3. Define transformers blocks
191
+ self.transformer_blocks = nn.ModuleList(
192
+ [
193
+ BasicMVTransformerBlock(
194
+ inner_dim,
195
+ num_attention_heads,
196
+ attention_head_dim,
197
+ dropout=dropout,
198
+ cross_attention_dim=cross_attention_dim,
199
+ activation_fn=activation_fn,
200
+ num_embeds_ada_norm=num_embeds_ada_norm,
201
+ attention_bias=attention_bias,
202
+ only_cross_attention=only_cross_attention,
203
+ upcast_attention=upcast_attention,
204
+ norm_type=norm_type,
205
+ norm_elementwise_affine=norm_elementwise_affine,
206
+ num_views=num_views,
207
+ cd_attention_last=cd_attention_last,
208
+ cd_attention_mid=cd_attention_mid,
209
+ multiview_attention=multiview_attention,
210
+ mvcd_attention=mvcd_attention
211
+ )
212
+ for d in range(num_layers)
213
+ ]
214
+ )
215
+
216
+ # 4. Define output layers
217
+ self.out_channels = in_channels if out_channels is None else out_channels
218
+ if self.is_input_continuous:
219
+ # TODO: should use out_channels for continuous projections
220
+ if use_linear_projection:
221
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
222
+ else:
223
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
224
+ elif self.is_input_vectorized:
225
+ self.norm_out = nn.LayerNorm(inner_dim)
226
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
227
+ elif self.is_input_patches:
228
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
229
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
230
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
231
+
232
+ def forward(
233
+ self,
234
+ hidden_states: torch.Tensor,
235
+ encoder_hidden_states: Optional[torch.Tensor] = None,
236
+ timestep: Optional[torch.LongTensor] = None,
237
+ class_labels: Optional[torch.LongTensor] = None,
238
+ cross_attention_kwargs: Dict[str, Any] = None,
239
+ attention_mask: Optional[torch.Tensor] = None,
240
+ encoder_attention_mask: Optional[torch.Tensor] = None,
241
+ return_dict: bool = True,
242
+ ):
243
+ """
244
+ The [`Transformer2DModel`] forward method.
245
+
246
+ Args:
247
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
248
+ Input `hidden_states`.
249
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
250
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
251
+ self-attention.
252
+ timestep ( `torch.LongTensor`, *optional*):
253
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
254
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
255
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
256
+ `AdaLayerZeroNorm`.
257
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
258
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
259
+
260
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
261
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
262
+
263
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
264
+ above. This bias will be added to the cross-attention scores.
265
+ return_dict (`bool`, *optional*, defaults to `True`):
266
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
267
+ tuple.
268
+
269
+ Returns:
270
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
271
+ `tuple` where the first element is the sample tensor.
272
+ """
273
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
274
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
275
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
276
+ # expects mask of shape:
277
+ # [batch, key_tokens]
278
+ # adds singleton query_tokens dimension:
279
+ # [batch, 1, key_tokens]
280
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
281
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
282
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
283
+ if attention_mask is not None and attention_mask.ndim == 2:
284
+ # assume that mask is expressed as:
285
+ # (1 = keep, 0 = discard)
286
+ # convert mask into a bias that can be added to attention scores:
287
+ # (keep = +0, discard = -10000.0)
288
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
289
+ attention_mask = attention_mask.unsqueeze(1)
290
+
291
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
292
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
293
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
294
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
295
+
296
+ # 1. Input
297
+ if self.is_input_continuous:
298
+ batch, _, height, width = hidden_states.shape
299
+ residual = hidden_states
300
+
301
+ hidden_states = self.norm(hidden_states)
302
+ if not self.use_linear_projection:
303
+ hidden_states = self.proj_in(hidden_states)
304
+ inner_dim = hidden_states.shape[1]
305
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
306
+ else:
307
+ inner_dim = hidden_states.shape[1]
308
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
309
+ hidden_states = self.proj_in(hidden_states)
310
+ elif self.is_input_vectorized:
311
+ hidden_states = self.latent_image_embedding(hidden_states)
312
+ elif self.is_input_patches:
313
+ hidden_states = self.pos_embed(hidden_states)
314
+
315
+ # 2. Blocks
316
+ for block in self.transformer_blocks:
317
+ hidden_states = block(
318
+ hidden_states,
319
+ attention_mask=attention_mask,
320
+ encoder_hidden_states=encoder_hidden_states,
321
+ encoder_attention_mask=encoder_attention_mask,
322
+ timestep=timestep,
323
+ cross_attention_kwargs=cross_attention_kwargs,
324
+ class_labels=class_labels,
325
+ )
326
+
327
+ # 3. Output
328
+ if self.is_input_continuous:
329
+ if not self.use_linear_projection:
330
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
331
+ hidden_states = self.proj_out(hidden_states)
332
+ else:
333
+ hidden_states = self.proj_out(hidden_states)
334
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
335
+
336
+ output = hidden_states + residual
337
+ elif self.is_input_vectorized:
338
+ hidden_states = self.norm_out(hidden_states)
339
+ logits = self.out(hidden_states)
340
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
341
+ logits = logits.permute(0, 2, 1)
342
+
343
+ # log(p(x_0))
344
+ output = F.log_softmax(logits.double(), dim=1).float()
345
+ elif self.is_input_patches:
346
+ # TODO: cleanup!
347
+ conditioning = self.transformer_blocks[0].norm1.emb(
348
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
349
+ )
350
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
351
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
352
+ hidden_states = self.proj_out_2(hidden_states)
353
+
354
+ # unpatchify
355
+ height = width = int(hidden_states.shape[1] ** 0.5)
356
+ hidden_states = hidden_states.reshape(
357
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
358
+ )
359
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
360
+ output = hidden_states.reshape(
361
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
362
+ )
363
+
364
+ if not return_dict:
365
+ return (output,)
366
+
367
+ return TransformerMV2DModelOutput(sample=output)
368
+
369
+
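# Illustrative sketch (not part of the committed file): the batching convention
# assumed by TransformerMV2DModel above. The views of each object are packed into
# the batch dimension before entering forward(), and the attention processors
# below recover the view axis with einops rearrange.
import torch
from einops import rearrange

B, num_views, C, H, W = 2, 6, 320, 32, 32
per_view = torch.randn(B, num_views, C, H, W)
packed = rearrange(per_view, "b v c h w -> (b v) c h w")              # what forward() receives
unpacked = rearrange(packed, "(b v) c h w -> b v c h w", v=num_views)
assert torch.equal(per_view, unpacked)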
370
+ @maybe_allow_in_graph
371
+ class BasicMVTransformerBlock(nn.Module):
372
+ r"""
373
+ A basic Transformer block.
374
+
375
+ Parameters:
376
+ dim (`int`): The number of channels in the input and output.
377
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
378
+ attention_head_dim (`int`): The number of channels in each head.
379
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
380
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
381
+ only_cross_attention (`bool`, *optional*):
382
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
383
+ double_self_attention (`bool`, *optional*):
384
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
385
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
386
+ num_embeds_ada_norm (:
387
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
388
+ attention_bias (:
389
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
390
+ """
391
+
392
+ def __init__(
393
+ self,
394
+ dim: int,
395
+ num_attention_heads: int,
396
+ attention_head_dim: int,
397
+ dropout=0.0,
398
+ cross_attention_dim: Optional[int] = None,
399
+ activation_fn: str = "geglu",
400
+ num_embeds_ada_norm: Optional[int] = None,
401
+ attention_bias: bool = False,
402
+ only_cross_attention: bool = False,
403
+ double_self_attention: bool = False,
404
+ upcast_attention: bool = False,
405
+ norm_elementwise_affine: bool = True,
406
+ norm_type: str = "layer_norm",
407
+ final_dropout: bool = False,
408
+ num_views: int = 1,
409
+ cd_attention_last: bool = False,
410
+ cd_attention_mid: bool = False,
411
+ multiview_attention: bool = True,
412
+ mvcd_attention: bool = False,
413
+ rowwise_attention: bool = True
414
+ ):
415
+ super().__init__()
416
+ self.only_cross_attention = only_cross_attention
417
+
418
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
419
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
420
+
421
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
422
+ raise ValueError(
423
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
424
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
425
+ )
426
+
427
+ # Define 3 blocks. Each block has its own normalization layer.
428
+ # 1. Self-Attn
429
+ if self.use_ada_layer_norm:
430
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
431
+ elif self.use_ada_layer_norm_zero:
432
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
433
+ else:
434
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
435
+
436
+ self.multiview_attention = multiview_attention
437
+ self.mvcd_attention = mvcd_attention
438
+ self.rowwise_attention = multiview_attention and rowwise_attention
439
+
440
+ # rowwise multiview attention
441
+
442
+ print('INFO: using row-wise attention...')
443
+
444
+ self.attn1 = CustomAttention(
445
+ query_dim=dim,
446
+ heads=num_attention_heads,
447
+ dim_head=attention_head_dim,
448
+ dropout=dropout,
449
+ bias=attention_bias,
450
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
451
+ upcast_attention=upcast_attention,
452
+ processor=MVAttnProcessor()
453
+ )
454
+
455
+ # 2. Cross-Attn
456
+ if cross_attention_dim is not None or double_self_attention:
457
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
458
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
459
+ # the second cross attention block.
460
+ self.norm2 = (
461
+ AdaLayerNorm(dim, num_embeds_ada_norm)
462
+ if self.use_ada_layer_norm
463
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
464
+ )
465
+ self.attn2 = Attention(
466
+ query_dim=dim,
467
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
468
+ heads=num_attention_heads,
469
+ dim_head=attention_head_dim,
470
+ dropout=dropout,
471
+ bias=attention_bias,
472
+ upcast_attention=upcast_attention,
473
+ ) # is self-attn if encoder_hidden_states is none
474
+ else:
475
+ self.norm2 = None
476
+ self.attn2 = None
477
+
478
+ # 3. Feed-forward
479
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
480
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
481
+
482
+ # let chunk size default to None
483
+ self._chunk_size = None
484
+ self._chunk_dim = 0
485
+
486
+ self.num_views = num_views
487
+
488
+ self.cd_attention_last = cd_attention_last
489
+
490
+ if self.cd_attention_last:
491
+ # Joint task -Attn
492
+ self.attn_joint = CustomJointAttention(
493
+ query_dim=dim,
494
+ heads=num_attention_heads,
495
+ dim_head=attention_head_dim,
496
+ dropout=dropout,
497
+ bias=attention_bias,
498
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
499
+ upcast_attention=upcast_attention,
500
+ processor=JointAttnProcessor()
501
+ )
502
+ nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
503
+ self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
504
+
505
+
506
+ self.cd_attention_mid = cd_attention_mid
507
+
508
+ if self.cd_attention_mid:
509
+ print("joint twice")
510
+ # Joint task -Attn
511
+ self.attn_joint_twice = CustomJointAttention(
512
+ query_dim=dim,
513
+ heads=num_attention_heads,
514
+ dim_head=attention_head_dim,
515
+ dropout=dropout,
516
+ bias=attention_bias,
517
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
518
+ upcast_attention=upcast_attention,
519
+ processor=JointAttnProcessor()
520
+ )
521
+ nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data)
522
+ self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
523
+
524
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
525
+ # Sets chunk feed-forward
526
+ self._chunk_size = chunk_size
527
+ self._chunk_dim = dim
528
+
529
+ def forward(
530
+ self,
531
+ hidden_states: torch.FloatTensor,
532
+ attention_mask: Optional[torch.FloatTensor] = None,
533
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
534
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
535
+ timestep: Optional[torch.LongTensor] = None,
536
+ cross_attention_kwargs: Dict[str, Any] = None,
537
+ class_labels: Optional[torch.LongTensor] = None,
538
+ ):
539
+ assert attention_mask is None # not supported yet
540
+ # Notice that normalization is always applied before the real computation in the following blocks.
541
+ # 1. Self-Attention
542
+ if self.use_ada_layer_norm:
543
+ norm_hidden_states = self.norm1(hidden_states, timestep)
544
+ elif self.use_ada_layer_norm_zero:
545
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
546
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
547
+ )
548
+ else:
549
+ norm_hidden_states = self.norm1(hidden_states)
550
+
551
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
552
+
553
+ attn_output = self.attn1(
554
+ norm_hidden_states,
555
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
556
+ attention_mask=attention_mask,
557
+ multiview_attention=self.multiview_attention,
558
+ mvcd_attention=self.mvcd_attention,
559
+ num_views=self.num_views,
560
+ **cross_attention_kwargs,
561
+ )
562
+
563
+ if self.use_ada_layer_norm_zero:
564
+ attn_output = gate_msa.unsqueeze(1) * attn_output
565
+ hidden_states = attn_output + hidden_states
566
+
567
+ # joint attention twice
568
+ if self.cd_attention_mid:
569
+ norm_hidden_states = (
570
+ self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states)
571
+ )
572
+ hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states
573
+
574
+ # 2. Cross-Attention
575
+ if self.attn2 is not None:
576
+ norm_hidden_states = (
577
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
578
+ )
579
+
580
+ attn_output = self.attn2(
581
+ norm_hidden_states,
582
+ encoder_hidden_states=encoder_hidden_states,
583
+ attention_mask=encoder_attention_mask,
584
+ **cross_attention_kwargs,
585
+ )
586
+ hidden_states = attn_output + hidden_states
587
+
588
+ # 3. Feed-forward
589
+ norm_hidden_states = self.norm3(hidden_states)
590
+
591
+ if self.use_ada_layer_norm_zero:
592
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
593
+
594
+ if self._chunk_size is not None:
595
+ # "feed_forward_chunk_size" can be used to save memory
596
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
597
+ raise ValueError(
598
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
599
+ )
600
+
601
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
602
+ ff_output = torch.cat(
603
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
604
+ dim=self._chunk_dim,
605
+ )
606
+ else:
607
+ ff_output = self.ff(norm_hidden_states)
608
+
609
+ if self.use_ada_layer_norm_zero:
610
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
611
+
612
+ hidden_states = ff_output + hidden_states
613
+
614
+ if self.cd_attention_last:
615
+ norm_hidden_states = (
616
+ self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states)
617
+ )
618
+ hidden_states = self.attn_joint(norm_hidden_states) + hidden_states
619
+
620
+ return hidden_states
621
+
622
+
623
+ class CustomAttention(Attention):
624
+ def set_use_memory_efficient_attention_xformers(
625
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
626
+ ):
627
+ processor = XFormersMVAttnProcessor()
628
+ self.set_processor(processor)
629
+ # print("using xformers attention processor")
630
+
631
+
632
+ class CustomJointAttention(Attention):
633
+ def set_use_memory_efficient_attention_xformers(
634
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
635
+ ):
636
+ processor = XFormersJointAttnProcessor()
637
+ self.set_processor(processor)
638
+ # print("using xformers attention processor")
639
+
640
+ class MVAttnProcessor:
641
+ r"""
642
+ Default processor for performing attention-related computations.
643
+ """
644
+
645
+ def __call__(
646
+ self,
647
+ attn: Attention,
648
+ hidden_states,
649
+ encoder_hidden_states=None,
650
+ attention_mask=None,
651
+ temb=None,
652
+ num_views=1,
653
+ multiview_attention=True
654
+ ):
655
+ residual = hidden_states
656
+
657
+ if attn.spatial_norm is not None:
658
+ hidden_states = attn.spatial_norm(hidden_states, temb)
659
+
660
+ input_ndim = hidden_states.ndim
661
+
662
+ if input_ndim == 4:
663
+ batch_size, channel, height, width = hidden_states.shape
664
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
665
+
666
+ batch_size, sequence_length, _ = (
667
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
668
+ )
669
+ height = int(math.sqrt(sequence_length))
670
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
671
+
672
+ if attn.group_norm is not None:
673
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
674
+
675
+ query = attn.to_q(hidden_states)
676
+
677
+ if encoder_hidden_states is None:
678
+ encoder_hidden_states = hidden_states
679
+ elif attn.norm_cross:
680
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
681
+
682
+ key = attn.to_k(encoder_hidden_states)
683
+ value = attn.to_v(encoder_hidden_states)
684
+
685
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
686
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
687
+ # pdb.set_trace()
688
+ # multi-view self-attention
689
+ key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
690
+ value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
691
+ query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
692
+
693
+ query = attn.head_to_batch_dim(query).contiguous()
694
+ key = attn.head_to_batch_dim(key).contiguous()
695
+ value = attn.head_to_batch_dim(value).contiguous()
696
+
697
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
698
+ hidden_states = torch.bmm(attention_probs, value)
699
+ hidden_states = attn.batch_to_head_dim(hidden_states)
700
+
701
+ # linear proj
702
+ hidden_states = attn.to_out[0](hidden_states)
703
+ # dropout
704
+ hidden_states = attn.to_out[1](hidden_states)
705
+ hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
706
+ if input_ndim == 4:
707
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
708
+
709
+ if attn.residual_connection:
710
+ hidden_states = hidden_states + residual
711
+
712
+ hidden_states = hidden_states / attn.rescale_output_factor
713
+
714
+ return hidden_states
715
+
716
+
717
+ class XFormersMVAttnProcessor:
718
+ r"""
719
+ Default processor for performing attention-related computations.
720
+ """
721
+
722
+ def __call__(
723
+ self,
724
+ attn: Attention,
725
+ hidden_states,
726
+ encoder_hidden_states=None,
727
+ attention_mask=None,
728
+ temb=None,
729
+ num_views=1,
730
+ multiview_attention=True,
731
+ mvcd_attention=False,
732
+ ):
733
+ residual = hidden_states
734
+
735
+ if attn.spatial_norm is not None:
736
+ hidden_states = attn.spatial_norm(hidden_states, temb)
737
+
738
+ input_ndim = hidden_states.ndim
739
+
740
+ if input_ndim == 4:
741
+ batch_size, channel, height, width = hidden_states.shape
742
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
743
+
744
+ batch_size, sequence_length, _ = (
745
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
746
+ )
747
+ height = int(math.sqrt(sequence_length))
748
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
749
+ # from yuancheng; here attention_mask is None
750
+ if attention_mask is not None:
751
+ # expand our mask's singleton query_tokens dimension:
752
+ # [batch*heads, 1, key_tokens] ->
753
+ # [batch*heads, query_tokens, key_tokens]
754
+ # so that it can be added as a bias onto the attention scores that xformers computes:
755
+ # [batch*heads, query_tokens, key_tokens]
756
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
757
+ _, query_tokens, _ = hidden_states.shape
758
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
759
+
760
+ if attn.group_norm is not None:
761
+ print('Warning: group norm is enabled; take care when combining it with row-wise attention')
762
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
763
+
764
+ query = attn.to_q(hidden_states)
765
+
766
+ if encoder_hidden_states is None:
767
+ encoder_hidden_states = hidden_states
768
+ elif attn.norm_cross:
769
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
770
+
771
+ key_raw = attn.to_k(encoder_hidden_states)
772
+ value_raw = attn.to_v(encoder_hidden_states)
773
+
774
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
775
+ # pdb.set_trace()
776
+
777
+ key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
778
+ value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
779
+ query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
780
+ if mvcd_attention:
781
+ # memory efficient, cross domain attention
782
+ key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c
783
+ value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
784
+ key_cross = torch.concat([key_1, key_0], dim=0)
785
+ value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c
786
+ key = torch.cat([key, key_cross], dim=1)
787
+ value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c
788
+
789
+
790
+ query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64])
791
+ key = attn.head_to_batch_dim(key)
792
+ value = attn.head_to_batch_dim(value)
793
+
794
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
795
+ hidden_states = attn.batch_to_head_dim(hidden_states)
796
+
797
+ # linear proj
798
+ hidden_states = attn.to_out[0](hidden_states)
799
+ # dropout
800
+ hidden_states = attn.to_out[1](hidden_states)
801
+ # print(hidden_states.shape)
802
+ hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
803
+ if input_ndim == 4:
804
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
805
+
806
+ if attn.residual_connection:
807
+ hidden_states = hidden_states + residual
808
+
809
+ hidden_states = hidden_states / attn.rescale_output_factor
810
+
811
+ return hidden_states
812
+
813
+
814
+ class XFormersJointAttnProcessor:
815
+ r"""
816
+ Default processor for performing attention-related computations.
817
+ """
818
+
819
+ def __call__(
820
+ self,
821
+ attn: Attention,
822
+ hidden_states,
823
+ encoder_hidden_states=None,
824
+ attention_mask=None,
825
+ temb=None,
826
+ num_tasks=2
827
+ ):
828
+
829
+ residual = hidden_states
830
+
831
+ if attn.spatial_norm is not None:
832
+ hidden_states = attn.spatial_norm(hidden_states, temb)
833
+
834
+ input_ndim = hidden_states.ndim
835
+
836
+ if input_ndim == 4:
837
+ batch_size, channel, height, width = hidden_states.shape
838
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
839
+
840
+ batch_size, sequence_length, _ = (
841
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
842
+ )
843
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
844
+
845
+ # from yuancheng; here attention_mask is None
846
+ if attention_mask is not None:
847
+ # expand our mask's singleton query_tokens dimension:
848
+ # [batch*heads, 1, key_tokens] ->
849
+ # [batch*heads, query_tokens, key_tokens]
850
+ # so that it can be added as a bias onto the attention scores that xformers computes:
851
+ # [batch*heads, query_tokens, key_tokens]
852
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
853
+ _, query_tokens, _ = hidden_states.shape
854
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
855
+
856
+ if attn.group_norm is not None:
857
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
858
+
859
+ query = attn.to_q(hidden_states)
860
+
861
+ if encoder_hidden_states is None:
862
+ encoder_hidden_states = hidden_states
863
+ elif attn.norm_cross:
864
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
865
+
866
+ key = attn.to_k(encoder_hidden_states)
867
+ value = attn.to_v(encoder_hidden_states)
868
+
869
+ assert num_tasks == 2 # only support two tasks now
870
+
871
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
872
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
873
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
874
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
875
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
876
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
877
+
878
+
879
+ query = attn.head_to_batch_dim(query).contiguous()
880
+ key = attn.head_to_batch_dim(key).contiguous()
881
+ value = attn.head_to_batch_dim(value).contiguous()
882
+
883
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
884
+ hidden_states = attn.batch_to_head_dim(hidden_states)
885
+
886
+ # linear proj
887
+ hidden_states = attn.to_out[0](hidden_states)
888
+ # dropout
889
+ hidden_states = attn.to_out[1](hidden_states)
890
+
891
+ if input_ndim == 4:
892
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
893
+
894
+ if attn.residual_connection:
895
+ hidden_states = hidden_states + residual
896
+
897
+ hidden_states = hidden_states / attn.rescale_output_factor
898
+
899
+ return hidden_states
900
+
901
+
902
+ class JointAttnProcessor:
903
+ r"""
904
+ Default processor for performing attention-related computations.
905
+ """
906
+
907
+ def __call__(
908
+ self,
909
+ attn: Attention,
910
+ hidden_states,
911
+ encoder_hidden_states=None,
912
+ attention_mask=None,
913
+ temb=None,
914
+ num_tasks=2
915
+ ):
916
+
917
+ residual = hidden_states
918
+
919
+ if attn.spatial_norm is not None:
920
+ hidden_states = attn.spatial_norm(hidden_states, temb)
921
+
922
+ input_ndim = hidden_states.ndim
923
+
924
+ if input_ndim == 4:
925
+ batch_size, channel, height, width = hidden_states.shape
926
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
927
+
928
+ batch_size, sequence_length, _ = (
929
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
930
+ )
931
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
932
+
933
+
934
+ if attn.group_norm is not None:
935
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
936
+
937
+ query = attn.to_q(hidden_states)
938
+
939
+ if encoder_hidden_states is None:
940
+ encoder_hidden_states = hidden_states
941
+ elif attn.norm_cross:
942
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
943
+
944
+ key = attn.to_k(encoder_hidden_states)
945
+ value = attn.to_v(encoder_hidden_states)
946
+
947
+ assert num_tasks == 2 # only support two tasks now
948
+
949
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
950
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
951
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
952
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
953
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
954
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
955
+
956
+
957
+ query = attn.head_to_batch_dim(query).contiguous()
958
+ key = attn.head_to_batch_dim(key).contiguous()
959
+ value = attn.head_to_batch_dim(value).contiguous()
960
+
961
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
962
+ hidden_states = torch.bmm(attention_probs, value)
963
+ hidden_states = attn.batch_to_head_dim(hidden_states)
964
+
965
+ # linear proj
966
+ hidden_states = attn.to_out[0](hidden_states)
967
+ # dropout
968
+ hidden_states = attn.to_out[1](hidden_states)
969
+
970
+ if input_ndim == 4:
971
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
972
+
973
+ if attn.residual_connection:
974
+ hidden_states = hidden_states + residual
975
+
976
+ hidden_states = hidden_states / attn.rescale_output_factor
977
+
978
+ return hidden_states
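A minimal sketch (not part of this commit) of the row-wise regrouping used by the attention processors in this file: tokens are reshaped so that each attention call covers one image row across all views, keeping the per-call sequence length small while still mixing information between views:

import torch
from einops import rearrange

b, v, h, w, c = 2, 6, 32, 32, 320
query = torch.randn(b * v, h * w, c)                                    # (batch*views, tokens, channels)
q_row = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=v, h=h)    # one row, all views per call
assert q_row.shape == (b * h, v * w, c)
q_back = rearrange(q_row, "(b h) (v w) c -> (b v) (h w) c", v=v, h=h)   # fold back to per-view maps
assert q_back.shape == query.shape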
mvdiffusion/models/transformer_mv2d_self_rowwise.py ADDED
@@ -0,0 +1,1038 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
25
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
26
+ from diffusers.models.embeddings import PatchEmbed
27
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.utils.import_utils import is_xformers_available
30
+
31
+ from einops import rearrange
32
+ import pdb
33
+ import random
34
+ import math
35
+
36
+
37
+ if is_xformers_available():
38
+ import xformers
39
+ import xformers.ops
40
+ else:
41
+ xformers = None
42
+
43
+
44
+ @dataclass
45
+ class TransformerMV2DModelOutput(BaseOutput):
46
+ """
47
+ The output of [`Transformer2DModel`].
48
+
49
+ Args:
50
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
51
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
52
+ distributions for the unnoised latent pixels.
53
+ """
54
+
55
+ sample: torch.FloatTensor
56
+
57
+
58
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
59
+ """
60
+ A 2D Transformer model for image-like data.
61
+
62
+ Parameters:
63
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
64
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
65
+ in_channels (`int`, *optional*):
66
+ The number of channels in the input and output (specify if the input is **continuous**).
67
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
68
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
69
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
70
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
71
+ This is fixed during training since it is used to learn a number of position embeddings.
72
+ num_vector_embeds (`int`, *optional*):
73
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
74
+ Includes the class for the masked latent pixel.
75
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
76
+ num_embeds_ada_norm ( `int`, *optional*):
77
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
78
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
79
+ added to the hidden states.
80
+
81
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
82
+ attention_bias (`bool`, *optional*):
83
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
84
+ """
85
+
86
+ @register_to_config
87
+ def __init__(
88
+ self,
89
+ num_attention_heads: int = 16,
90
+ attention_head_dim: int = 88,
91
+ in_channels: Optional[int] = None,
92
+ out_channels: Optional[int] = None,
93
+ num_layers: int = 1,
94
+ dropout: float = 0.0,
95
+ norm_num_groups: int = 32,
96
+ cross_attention_dim: Optional[int] = None,
97
+ attention_bias: bool = False,
98
+ sample_size: Optional[int] = None,
99
+ num_vector_embeds: Optional[int] = None,
100
+ patch_size: Optional[int] = None,
101
+ activation_fn: str = "geglu",
102
+ num_embeds_ada_norm: Optional[int] = None,
103
+ use_linear_projection: bool = False,
104
+ only_cross_attention: bool = False,
105
+ upcast_attention: bool = False,
106
+ norm_type: str = "layer_norm",
107
+ norm_elementwise_affine: bool = True,
108
+ num_views: int = 1,
109
+ cd_attention_mid: bool=False,
110
+ cd_attention_last: bool=False,
111
+ multiview_attention: bool=True,
112
+ sparse_mv_attention: bool = True, # not used
113
+ mvcd_attention: bool=False,
114
+ use_dino: bool=False
115
+ ):
116
+ super().__init__()
117
+ self.use_linear_projection = use_linear_projection
118
+ self.num_attention_heads = num_attention_heads
119
+ self.attention_head_dim = attention_head_dim
120
+ inner_dim = num_attention_heads * attention_head_dim
121
+
122
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
123
+ # Define whether input is continuous or discrete depending on configuration
124
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
125
+ self.is_input_vectorized = num_vector_embeds is not None
126
+ self.is_input_patches = in_channels is not None and patch_size is not None
127
+
128
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
129
+ deprecation_message = (
130
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
131
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
132
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
133
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
134
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
135
+ )
136
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
137
+ norm_type = "ada_norm"
138
+
139
+ if self.is_input_continuous and self.is_input_vectorized:
140
+ raise ValueError(
141
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
142
+ " sure that either `in_channels` or `num_vector_embeds` is None."
143
+ )
144
+ elif self.is_input_vectorized and self.is_input_patches:
145
+ raise ValueError(
146
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
147
+ " sure that either `num_vector_embeds` or `num_patches` is None."
148
+ )
149
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
150
+ raise ValueError(
151
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
152
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
153
+ )
154
+
155
+ # 2. Define input layers
156
+ if self.is_input_continuous:
157
+ self.in_channels = in_channels
158
+
159
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
160
+ if use_linear_projection:
161
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
162
+ else:
163
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
164
+ elif self.is_input_vectorized:
165
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
166
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
167
+
168
+ self.height = sample_size
169
+ self.width = sample_size
170
+ self.num_vector_embeds = num_vector_embeds
171
+ self.num_latent_pixels = self.height * self.width
172
+
173
+ self.latent_image_embedding = ImagePositionalEmbeddings(
174
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
175
+ )
176
+ elif self.is_input_patches:
177
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
178
+
179
+ self.height = sample_size
180
+ self.width = sample_size
181
+
182
+ self.patch_size = patch_size
183
+ self.pos_embed = PatchEmbed(
184
+ height=sample_size,
185
+ width=sample_size,
186
+ patch_size=patch_size,
187
+ in_channels=in_channels,
188
+ embed_dim=inner_dim,
189
+ )
190
+
191
+ # 3. Define transformers blocks
192
+ self.transformer_blocks = nn.ModuleList(
193
+ [
194
+ BasicMVTransformerBlock(
195
+ inner_dim,
196
+ num_attention_heads,
197
+ attention_head_dim,
198
+ dropout=dropout,
199
+ cross_attention_dim=cross_attention_dim,
200
+ activation_fn=activation_fn,
201
+ num_embeds_ada_norm=num_embeds_ada_norm,
202
+ attention_bias=attention_bias,
203
+ only_cross_attention=only_cross_attention,
204
+ upcast_attention=upcast_attention,
205
+ norm_type=norm_type,
206
+ norm_elementwise_affine=norm_elementwise_affine,
207
+ num_views=num_views,
208
+ cd_attention_last=cd_attention_last,
209
+ cd_attention_mid=cd_attention_mid,
210
+ multiview_attention=multiview_attention,
211
+ mvcd_attention=mvcd_attention,
212
+ use_dino=use_dino
213
+ )
214
+ for d in range(num_layers)
215
+ ]
216
+ )
217
+
218
+ # 4. Define output layers
219
+ self.out_channels = in_channels if out_channels is None else out_channels
220
+ if self.is_input_continuous:
221
+ # TODO: should use out_channels for continuous projections
222
+ if use_linear_projection:
223
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
224
+ else:
225
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
226
+ elif self.is_input_vectorized:
227
+ self.norm_out = nn.LayerNorm(inner_dim)
228
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
229
+ elif self.is_input_patches:
230
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
231
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
232
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
233
+
234
+ def forward(
235
+ self,
236
+ hidden_states: torch.Tensor,
237
+ encoder_hidden_states: Optional[torch.Tensor] = None,
238
+ dino_feature: Optional[torch.Tensor] = None,
239
+ timestep: Optional[torch.LongTensor] = None,
240
+ class_labels: Optional[torch.LongTensor] = None,
241
+ cross_attention_kwargs: Dict[str, Any] = None,
242
+ attention_mask: Optional[torch.Tensor] = None,
243
+ encoder_attention_mask: Optional[torch.Tensor] = None,
244
+ return_dict: bool = True,
245
+ ):
246
+ """
247
+ The [`Transformer2DModel`] forward method.
248
+
249
+ Args:
250
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
251
+ Input `hidden_states`.
252
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
253
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
254
+ self-attention.
255
+ timestep ( `torch.LongTensor`, *optional*):
256
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
257
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
258
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
259
+ `AdaLayerNormZero`.
260
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
261
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
262
+
263
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
264
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
265
+
266
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
267
+ above. This bias will be added to the cross-attention scores.
268
+ return_dict (`bool`, *optional*, defaults to `True`):
269
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
270
+ tuple.
271
+
272
+ Returns:
273
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
274
+ `tuple` where the first element is the sample tensor.
275
+ """
276
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
277
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
278
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
279
+ # expects mask of shape:
280
+ # [batch, key_tokens]
281
+ # adds singleton query_tokens dimension:
282
+ # [batch, 1, key_tokens]
283
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
284
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
285
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
286
+ if attention_mask is not None and attention_mask.ndim == 2:
287
+ # assume that mask is expressed as:
288
+ # (1 = keep, 0 = discard)
289
+ # convert mask into a bias that can be added to attention scores:
290
+ # (keep = +0, discard = -10000.0)
291
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
292
+ attention_mask = attention_mask.unsqueeze(1)
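A minimal, standalone sketch of the mask-to-bias conversion performed above (toy values, not taken from the model):

import torch

mask = torch.tensor([[1, 1, 0]])            # (batch, key_tokens): 1 = keep, 0 = discard
bias = (1 - mask.float()) * -10000.0        # keep -> 0.0, discard -> -10000.0
bias = bias.unsqueeze(1)                    # (batch, 1, key_tokens), broadcastable over query tokens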
293
+
294
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
295
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
296
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
297
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
298
+
299
+ # 1. Input
300
+ if self.is_input_continuous:
301
+ batch, _, height, width = hidden_states.shape
302
+ residual = hidden_states
303
+
304
+ hidden_states = self.norm(hidden_states)
305
+ if not self.use_linear_projection:
306
+ hidden_states = self.proj_in(hidden_states)
307
+ inner_dim = hidden_states.shape[1]
308
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
309
+ else:
310
+ inner_dim = hidden_states.shape[1]
311
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
312
+ hidden_states = self.proj_in(hidden_states)
313
+ elif self.is_input_vectorized:
314
+ hidden_states = self.latent_image_embedding(hidden_states)
315
+ elif self.is_input_patches:
316
+ hidden_states = self.pos_embed(hidden_states)
317
+
318
+ # 2. Blocks
319
+ for block in self.transformer_blocks:
320
+ hidden_states = block(
321
+ hidden_states,
322
+ attention_mask=attention_mask,
323
+ encoder_hidden_states=encoder_hidden_states,
324
+ dino_feature=dino_feature,
325
+ encoder_attention_mask=encoder_attention_mask,
326
+ timestep=timestep,
327
+ cross_attention_kwargs=cross_attention_kwargs,
328
+ class_labels=class_labels,
329
+ )
330
+
331
+ # 3. Output
332
+ if self.is_input_continuous:
333
+ if not self.use_linear_projection:
334
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
335
+ hidden_states = self.proj_out(hidden_states)
336
+ else:
337
+ hidden_states = self.proj_out(hidden_states)
338
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
339
+
340
+ output = hidden_states + residual
341
+ elif self.is_input_vectorized:
342
+ hidden_states = self.norm_out(hidden_states)
343
+ logits = self.out(hidden_states)
344
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
345
+ logits = logits.permute(0, 2, 1)
346
+
347
+ # log(p(x_0))
348
+ output = F.log_softmax(logits.double(), dim=1).float()
349
+ elif self.is_input_patches:
350
+ # TODO: cleanup!
351
+ conditioning = self.transformer_blocks[0].norm1.emb(
352
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
353
+ )
354
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
355
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
356
+ hidden_states = self.proj_out_2(hidden_states)
357
+
358
+ # unpatchify
359
+ height = width = int(hidden_states.shape[1] ** 0.5)
360
+ hidden_states = hidden_states.reshape(
361
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
362
+ )
363
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
364
+ output = hidden_states.reshape(
365
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
366
+ )
367
+
368
+ if not return_dict:
369
+ return (output,)
370
+
371
+ return TransformerMV2DModelOutput(sample=output)
372
+
373
+
374
+ @maybe_allow_in_graph
375
+ class BasicMVTransformerBlock(nn.Module):
376
+ r"""
377
+ A basic Transformer block.
378
+
379
+ Parameters:
380
+ dim (`int`): The number of channels in the input and output.
381
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
382
+ attention_head_dim (`int`): The number of channels in each head.
383
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
384
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
385
+ only_cross_attention (`bool`, *optional*):
386
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
387
+ double_self_attention (`bool`, *optional*):
388
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
389
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
390
+ num_embeds_ada_norm (`int`, *optional*):
391
+ The number of diffusion steps used during training. See `Transformer2DModel`.
392
+ attention_bias (`bool`, *optional*, defaults to `False`):
393
+ Configure if the attentions should contain a bias parameter.
394
+ """
395
+
396
+ def __init__(
397
+ self,
398
+ dim: int,
399
+ num_attention_heads: int,
400
+ attention_head_dim: int,
401
+ dropout=0.0,
402
+ cross_attention_dim: Optional[int] = None,
403
+ activation_fn: str = "geglu",
404
+ num_embeds_ada_norm: Optional[int] = None,
405
+ attention_bias: bool = False,
406
+ only_cross_attention: bool = False,
407
+ double_self_attention: bool = False,
408
+ upcast_attention: bool = False,
409
+ norm_elementwise_affine: bool = True,
410
+ norm_type: str = "layer_norm",
411
+ final_dropout: bool = False,
412
+ num_views: int = 1,
413
+ cd_attention_last: bool = False,
414
+ cd_attention_mid: bool = False,
415
+ multiview_attention: bool = True,
416
+ mvcd_attention: bool = False,
417
+ rowwise_attention: bool = True,
418
+ use_dino: bool = False
419
+ ):
420
+ super().__init__()
421
+ self.only_cross_attention = only_cross_attention
422
+
423
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
424
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
425
+
426
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
427
+ raise ValueError(
428
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
429
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
430
+ )
431
+
432
+ # Define 3 blocks. Each block has its own normalization layer.
433
+ # 1. Self-Attn
434
+ if self.use_ada_layer_norm:
435
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
436
+ elif self.use_ada_layer_norm_zero:
437
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
438
+ else:
439
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
440
+
441
+ self.multiview_attention = multiview_attention
442
+ self.mvcd_attention = mvcd_attention
443
+ self.cd_attention_mid = cd_attention_mid
444
+ self.rowwise_attention = multiview_attention and rowwise_attention
445
+
446
+ if mvcd_attention and (not cd_attention_mid):
447
+ # add cross domain attn to self attn
448
+ self.attn1 = CustomJointAttention(
449
+ query_dim=dim,
450
+ heads=num_attention_heads,
451
+ dim_head=attention_head_dim,
452
+ dropout=dropout,
453
+ bias=attention_bias,
454
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
455
+ upcast_attention=upcast_attention,
456
+ processor=JointAttnProcessor()
457
+ )
458
+ else:
459
+ self.attn1 = Attention(
460
+ query_dim=dim,
461
+ heads=num_attention_heads,
462
+ dim_head=attention_head_dim,
463
+ dropout=dropout,
464
+ bias=attention_bias,
465
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
466
+ upcast_attention=upcast_attention
467
+ )
468
+ # 1.1 rowwise multiview attention
469
+ if self.rowwise_attention:
470
+ # print('INFO: using self+row_wise mv attention...')
471
+ self.norm_mv = (
472
+ AdaLayerNorm(dim, num_embeds_ada_norm)
473
+ if self.use_ada_layer_norm
474
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
475
+ )
476
+ self.attn_mv = CustomAttention(
477
+ query_dim=dim,
478
+ heads=num_attention_heads,
479
+ dim_head=attention_head_dim,
480
+ dropout=dropout,
481
+ bias=attention_bias,
482
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
483
+ upcast_attention=upcast_attention,
484
+ processor=MVAttnProcessor()
485
+ )
486
+ nn.init.zeros_(self.attn_mv.to_out[0].weight.data)
487
+ else:
488
+ self.norm_mv = None
489
+ self.attn_mv = None
490
+
491
+ # # 1.2 rowwise cross-domain attn
492
+ # if mvcd_attention:
493
+ # self.attn_joint = CustomJointAttention(
494
+ # query_dim=dim,
495
+ # heads=num_attention_heads,
496
+ # dim_head=attention_head_dim,
497
+ # dropout=dropout,
498
+ # bias=attention_bias,
499
+ # cross_attention_dim=cross_attention_dim if only_cross_attention else None,
500
+ # upcast_attention=upcast_attention,
501
+ # processor=JointAttnProcessor()
502
+ # )
503
+ # nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
504
+ # self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
505
+ # else:
506
+ # self.attn_joint = None
507
+ # self.norm_joint = None
508
+
509
+ # 2. Cross-Attn
510
+ if cross_attention_dim is not None or double_self_attention:
511
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
512
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
513
+ # the second cross attention block.
514
+ self.norm2 = (
515
+ AdaLayerNorm(dim, num_embeds_ada_norm)
516
+ if self.use_ada_layer_norm
517
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
518
+ )
519
+ self.attn2 = Attention(
520
+ query_dim=dim,
521
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
522
+ heads=num_attention_heads,
523
+ dim_head=attention_head_dim,
524
+ dropout=dropout,
525
+ bias=attention_bias,
526
+ upcast_attention=upcast_attention,
527
+ ) # is self-attn if encoder_hidden_states is none
528
+ else:
529
+ self.norm2 = None
530
+ self.attn2 = None
531
+
532
+ # 3. Feed-forward
533
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
534
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
535
+
536
+ # let chunk size default to None
537
+ self._chunk_size = None
538
+ self._chunk_dim = 0
539
+
540
+ self.num_views = num_views
541
+
542
+
543
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
544
+ # Sets chunk feed-forward
545
+ self._chunk_size = chunk_size
546
+ self._chunk_dim = dim
547
+
548
+ def forward(
549
+ self,
550
+ hidden_states: torch.FloatTensor,
551
+ attention_mask: Optional[torch.FloatTensor] = None,
552
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
553
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
554
+ timestep: Optional[torch.LongTensor] = None,
555
+ cross_attention_kwargs: Dict[str, Any] = None,
556
+ class_labels: Optional[torch.LongTensor] = None,
557
+ dino_feature: Optional[torch.FloatTensor] = None
558
+ ):
559
+ assert attention_mask is None # not supported yet
560
+ # Notice that normalization is always applied before the real computation in the following blocks.
561
+ # 1. Self-Attention
562
+ if self.use_ada_layer_norm:
563
+ norm_hidden_states = self.norm1(hidden_states, timestep)
564
+ elif self.use_ada_layer_norm_zero:
565
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
566
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
567
+ )
568
+ else:
569
+ norm_hidden_states = self.norm1(hidden_states)
570
+
571
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
572
+
573
+ attn_output = self.attn1(
574
+ norm_hidden_states,
575
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
576
+ attention_mask=attention_mask,
577
+ # multiview_attention=self.multiview_attention,
578
+ # mvcd_attention=self.mvcd_attention,
579
+ **cross_attention_kwargs,
580
+ )
581
+
582
+
583
+ if self.use_ada_layer_norm_zero:
584
+ attn_output = gate_msa.unsqueeze(1) * attn_output
585
+ hidden_states = attn_output + hidden_states
586
+
587
+ # import pdb;pdb.set_trace()
588
+ # 1.1 row wise multiview attention
589
+ if self.rowwise_attention:
590
+ norm_hidden_states = (
591
+ self.norm_mv(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_mv(hidden_states)
592
+ )
593
+ attn_output = self.attn_mv(
594
+ norm_hidden_states,
595
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
596
+ attention_mask=attention_mask,
597
+ num_views=self.num_views,
598
+ multiview_attention=self.multiview_attention,
599
+ cd_attention_mid=self.cd_attention_mid,
600
+ **cross_attention_kwargs,
601
+ )
602
+ hidden_states = attn_output + hidden_states
603
+
604
+
605
+ # 2. Cross-Attention
606
+ if self.attn2 is not None:
607
+ norm_hidden_states = (
608
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
609
+ )
610
+
611
+ attn_output = self.attn2(
612
+ norm_hidden_states,
613
+ encoder_hidden_states=encoder_hidden_states,
614
+ attention_mask=encoder_attention_mask,
615
+ **cross_attention_kwargs,
616
+ )
617
+ hidden_states = attn_output + hidden_states
618
+
619
+ # 3. Feed-forward
620
+ norm_hidden_states = self.norm3(hidden_states)
621
+
622
+ if self.use_ada_layer_norm_zero:
623
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
624
+
625
+ if self._chunk_size is not None:
626
+ # "feed_forward_chunk_size" can be used to save memory
627
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
628
+ raise ValueError(
629
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
630
+ )
631
+
632
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
633
+ ff_output = torch.cat(
634
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
635
+ dim=self._chunk_dim,
636
+ )
637
+ else:
638
+ ff_output = self.ff(norm_hidden_states)
639
+
640
+ if self.use_ada_layer_norm_zero:
641
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
642
+
643
+ hidden_states = ff_output + hidden_states
644
+
645
+ return hidden_states
646
+
647
+
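The feed-forward chunking in the block above trades peak memory for extra kernel launches without changing the result. A minimal standalone sketch of the same idea (toy `ff` and sizes, not the block's real dimensions):

import torch
import torch.nn as nn

ff = nn.Linear(320, 320)                       # stand-in for the block's FeedForward
x = torch.randn(4, 1024, 320)                  # (batch, tokens, dim)

chunk_size, dim = 256, 1                       # the token dimension must be divisible by chunk_size
chunks = x.shape[dim] // chunk_size
y_chunked = torch.cat([ff(s) for s in x.chunk(chunks, dim=dim)], dim=dim)
assert torch.allclose(y_chunked, ff(x), atol=1e-6)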
648
+ class CustomAttention(Attention):
649
+ def set_use_memory_efficient_attention_xformers(
650
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
651
+ ):
652
+ processor = XFormersMVAttnProcessor()
653
+ self.set_processor(processor)
654
+ # print("using xformers attention processor")
655
+
656
+
657
+ class CustomJointAttention(Attention):
658
+ def set_use_memory_efficient_attention_xformers(
659
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
660
+ ):
661
+ processor = XFormersJointAttnProcessor()
662
+ self.set_processor(processor)
663
+ # print("using xformers attention processor")
664
+
665
+ class MVAttnProcessor:
666
+ r"""
667
+ Processor for row-wise multi-view attention: tokens in the same image row attend across all views.
668
+ """
669
+
670
+ def __call__(
671
+ self,
672
+ attn: Attention,
673
+ hidden_states,
674
+ encoder_hidden_states=None,
675
+ attention_mask=None,
676
+ temb=None,
677
+ num_views=1,
678
+ cd_attention_mid=False
679
+ ):
680
+ residual = hidden_states
681
+
682
+ if attn.spatial_norm is not None:
683
+ hidden_states = attn.spatial_norm(hidden_states, temb)
684
+
685
+ input_ndim = hidden_states.ndim
686
+
687
+ if input_ndim == 4:
688
+ batch_size, channel, height, width = hidden_states.shape
689
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
690
+
691
+ batch_size, sequence_length, _ = (
692
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
693
+ )
694
+ height = int(math.sqrt(sequence_length))
695
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
696
+
697
+ if attn.group_norm is not None:
698
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
699
+
700
+ query = attn.to_q(hidden_states)
701
+
702
+ if encoder_hidden_states is None:
703
+ encoder_hidden_states = hidden_states
704
+ elif attn.norm_cross:
705
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
706
+
707
+ key = attn.to_k(encoder_hidden_states)
708
+ value = attn.to_v(encoder_hidden_states)
709
+
710
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
711
+ # query/key/value each have shape (b*v, h*w, c), e.g. torch.Size([b*4, 1024, 320])
712
+ # pdb.set_trace()
713
+ # multi-view self-attention
714
+ def transpose(tensor):
715
+ tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
716
+ tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c
717
+ tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c
718
+ tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
719
+ return tensor
720
+
721
+ if cd_attention_mid:
722
+ key = transpose(key)
723
+ value = transpose(value)
724
+ query = transpose(query)
725
+ else:
726
+ key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
727
+ value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
728
+ query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
729
+
730
+ query = attn.head_to_batch_dim(query).contiguous()
731
+ key = attn.head_to_batch_dim(key).contiguous()
732
+ value = attn.head_to_batch_dim(value).contiguous()
733
+
734
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
735
+ hidden_states = torch.bmm(attention_probs, value)
736
+ hidden_states = attn.batch_to_head_dim(hidden_states)
737
+
738
+ # linear proj
739
+ hidden_states = attn.to_out[0](hidden_states)
740
+ # dropout
741
+ hidden_states = attn.to_out[1](hidden_states)
742
+ if cd_attention_mid:
743
+ hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
744
+ hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c
745
+ hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c
746
+ hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
747
+ else:
748
+ hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
749
+ if input_ndim == 4:
750
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
751
+
752
+ if attn.residual_connection:
753
+ hidden_states = hidden_states + residual
754
+
755
+ hidden_states = hidden_states / attn.rescale_output_factor
756
+
757
+ return hidden_states
758
+
759
+
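The key step in `MVAttnProcessor` is the rearrange from per-view token sequences to per-row sequences that span all views, so self-attention mixes information across views within each image row. A small shape-only sketch of that pattern (hypothetical sizes; the processor assumes square feature maps, since it recovers `h` from `sqrt(sequence_length)`):

import torch
from einops import rearrange

b, v, h, w, c = 2, 6, 32, 32, 320              # objects, views, feature map height/width, channels
tokens = torch.randn(b * v, h * w, c)          # (batch*views, tokens, channels), as in the processor

rowwise = rearrange(tokens, "(b v) (h w) c -> (b h) (v w) c", v=v, h=h)
print(rowwise.shape)                           # torch.Size([64, 192, 320]): each row attends across all 6 views

back = rearrange(rowwise, "(b h) (v w) c -> (b v) (h w) c", v=v, h=h)
assert torch.equal(back, tokens)               # the mapping is a pure permutation and is exactly invertible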
760
+ class XFormersMVAttnProcessor:
761
+ r"""
762
+ xFormers (memory-efficient) variant of the row-wise multi-view attention processor.
763
+ """
764
+
765
+ def __call__(
766
+ self,
767
+ attn: Attention,
768
+ hidden_states,
769
+ encoder_hidden_states=None,
770
+ attention_mask=None,
771
+ temb=None,
772
+ num_views=1,
773
+ multiview_attention=True,
774
+ cd_attention_mid=False
775
+ ):
776
+ # print(num_views)
777
+ residual = hidden_states
778
+
779
+ if attn.spatial_norm is not None:
780
+ hidden_states = attn.spatial_norm(hidden_states, temb)
781
+
782
+ input_ndim = hidden_states.ndim
783
+
784
+ if input_ndim == 4:
785
+ batch_size, channel, height, width = hidden_states.shape
786
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
787
+
788
+ batch_size, sequence_length, _ = (
789
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
790
+ )
791
+ height = int(math.sqrt(sequence_length))
792
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
793
+ # from yuancheng; here attention_mask is None
794
+ if attention_mask is not None:
795
+ # expand our mask's singleton query_tokens dimension:
796
+ # [batch*heads, 1, key_tokens] ->
797
+ # [batch*heads, query_tokens, key_tokens]
798
+ # so that it can be added as a bias onto the attention scores that xformers computes:
799
+ # [batch*heads, query_tokens, key_tokens]
800
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
801
+ _, query_tokens, _ = hidden_states.shape
802
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
803
+
804
+ if attn.group_norm is not None:
805
+ print('Warning: group_norm is applied before the row-wise reshape; verify its behavior in row-wise attention')
806
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
807
+
808
+ query = attn.to_q(hidden_states)
809
+
810
+ if encoder_hidden_states is None:
811
+ encoder_hidden_states = hidden_states
812
+ elif attn.norm_cross:
813
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
814
+
815
+ key_raw = attn.to_k(encoder_hidden_states)
816
+ value_raw = attn.to_v(encoder_hidden_states)
817
+
818
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
819
+ # pdb.set_trace()
820
+ def transpose(tensor):
821
+ tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
822
+ tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c
823
+ tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c
824
+ tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
825
+ return tensor
826
+ # print(mvcd_attention)
827
+ # import pdb;pdb.set_trace()
828
+ if cd_attention_mid:
829
+ key = transpose(key_raw)
830
+ value = transpose(value_raw)
831
+ query = transpose(query)
832
+ else:
833
+ key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
834
+ value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
835
+ query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
836
+
837
+
838
+ query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64])
839
+ key = attn.head_to_batch_dim(key)
840
+ value = attn.head_to_batch_dim(value)
841
+
842
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
843
+ hidden_states = attn.batch_to_head_dim(hidden_states)
844
+
845
+ # linear proj
846
+ hidden_states = attn.to_out[0](hidden_states)
847
+ # dropout
848
+ hidden_states = attn.to_out[1](hidden_states)
849
+
850
+ if cd_attention_mid:
851
+ hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
852
+ hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c
853
+ hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c
854
+ hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
855
+ else:
856
+ hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
857
+ if input_ndim == 4:
858
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
859
+
860
+ if attn.residual_connection:
861
+ hidden_states = hidden_states + residual
862
+
863
+ hidden_states = hidden_states / attn.rescale_output_factor
864
+
865
+ return hidden_states
866
+
867
+
868
+ class XFormersJointAttnProcessor:
869
+ r"""
870
+ xFormers (memory-efficient) processor for cross-domain (joint) attention across the two task branches.
871
+ """
872
+
873
+ def __call__(
874
+ self,
875
+ attn: Attention,
876
+ hidden_states,
877
+ encoder_hidden_states=None,
878
+ attention_mask=None,
879
+ temb=None,
880
+ num_tasks=2
881
+ ):
882
+ residual = hidden_states
883
+
884
+ if attn.spatial_norm is not None:
885
+ hidden_states = attn.spatial_norm(hidden_states, temb)
886
+
887
+ input_ndim = hidden_states.ndim
888
+
889
+ if input_ndim == 4:
890
+ batch_size, channel, height, width = hidden_states.shape
891
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
892
+
893
+ batch_size, sequence_length, _ = (
894
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
895
+ )
896
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
897
+
898
+ # from yuancheng; here attention_mask is None
899
+ if attention_mask is not None:
900
+ # expand our mask's singleton query_tokens dimension:
901
+ # [batch*heads, 1, key_tokens] ->
902
+ # [batch*heads, query_tokens, key_tokens]
903
+ # so that it can be added as a bias onto the attention scores that xformers computes:
904
+ # [batch*heads, query_tokens, key_tokens]
905
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
906
+ _, query_tokens, _ = hidden_states.shape
907
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
908
+
909
+ if attn.group_norm is not None:
910
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
911
+
912
+ query = attn.to_q(hidden_states)
913
+
914
+ if encoder_hidden_states is None:
915
+ encoder_hidden_states = hidden_states
916
+ elif attn.norm_cross:
917
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
918
+
919
+ key = attn.to_k(encoder_hidden_states)
920
+ value = attn.to_v(encoder_hidden_states)
921
+
922
+ assert num_tasks == 2 # only support two tasks now
923
+
924
+ def transpose(tensor):
925
+ tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c
926
+ tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c
927
+ return tensor
928
+ key = transpose(key)
929
+ value = transpose(value)
930
+ query = transpose(query)
931
+ # from icecream import ic
932
+ # ic(key.shape, value.shape, query.shape)
933
+ # import pdb;pdb.set_trace()
934
+ query = attn.head_to_batch_dim(query).contiguous()
935
+ key = attn.head_to_batch_dim(key).contiguous()
936
+ value = attn.head_to_batch_dim(value).contiguous()
937
+
938
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
939
+ hidden_states = attn.batch_to_head_dim(hidden_states)
940
+
941
+ # linear proj
942
+ hidden_states = attn.to_out[0](hidden_states)
943
+ # dropout
944
+ hidden_states = attn.to_out[1](hidden_states)
945
+ hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2)
946
+ hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0) # 2bv hw c
947
+
948
+ if input_ndim == 4:
949
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
950
+
951
+ if attn.residual_connection:
952
+ hidden_states = hidden_states + residual
953
+
954
+ hidden_states = hidden_states / attn.rescale_output_factor
955
+
956
+ return hidden_states
957
+
958
+
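Both joint-attention processors use the same trick: the two task branches, stacked along the batch dimension, are folded into one longer token sequence so that their tokens attend to each other, then split back afterwards. A shape-only sketch with hypothetical sizes:

import torch

bv, hw, c = 12, 1024, 320                      # (batch*views) per task, tokens, channels
states = torch.randn(2 * bv, hw, c)            # two tasks stacked along dim 0, as in the processors

task_a, task_b = torch.chunk(states, chunks=2, dim=0)       # each (bv, hw, c)
joint = torch.cat([task_a, task_b], dim=1)                  # (bv, 2*hw, c): one joint sequence per sample
# ... attention would run on `joint` here ...
restored = torch.cat(torch.chunk(joint, chunks=2, dim=1), dim=0)
assert torch.equal(restored, states)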
959
+ class JointAttnProcessor:
960
+ r"""
961
+ Processor for cross-domain (joint) attention: the two task branches (e.g. normal and color) attend to each other.
962
+ """
963
+
964
+ def __call__(
965
+ self,
966
+ attn: Attention,
967
+ hidden_states,
968
+ encoder_hidden_states=None,
969
+ attention_mask=None,
970
+ temb=None,
971
+ num_tasks=2
972
+ ):
973
+
974
+ residual = hidden_states
975
+
976
+ if attn.spatial_norm is not None:
977
+ hidden_states = attn.spatial_norm(hidden_states, temb)
978
+
979
+ input_ndim = hidden_states.ndim
980
+
981
+ if input_ndim == 4:
982
+ batch_size, channel, height, width = hidden_states.shape
983
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
984
+
985
+ batch_size, sequence_length, _ = (
986
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
987
+ )
988
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
989
+
990
+
991
+ if attn.group_norm is not None:
992
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
993
+
994
+ query = attn.to_q(hidden_states)
995
+
996
+ if encoder_hidden_states is None:
997
+ encoder_hidden_states = hidden_states
998
+ elif attn.norm_cross:
999
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1000
+
1001
+ key = attn.to_k(encoder_hidden_states)
1002
+ value = attn.to_v(encoder_hidden_states)
1003
+
1004
+ assert num_tasks == 2 # only support two tasks now
1005
+
1006
+ def transpose(tensor):
1007
+ tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c
1008
+ tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c
1009
+ return tensor
1010
+ key = transpose(key)
1011
+ value = transpose(value)
1012
+ query = transpose(query)
1013
+
1014
+
1015
+ query = attn.head_to_batch_dim(query).contiguous()
1016
+ key = attn.head_to_batch_dim(key).contiguous()
1017
+ value = attn.head_to_batch_dim(value).contiguous()
1018
+
1019
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
1020
+ hidden_states = torch.bmm(attention_probs, value)
1021
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1022
+
1023
+
1024
+ # linear proj
1025
+ hidden_states = attn.to_out[0](hidden_states)
1026
+ # dropout
1027
+ hidden_states = attn.to_out[1](hidden_states)
1028
+
1029
+ hidden_states = torch.cat(torch.chunk(hidden_states, dim=1, chunks=2), dim=0) # (2bv, hw, c): undo the sequence-dim concat over the two tasks
1030
+ if input_ndim == 4:
1031
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1032
+
1033
+ if attn.residual_connection:
1034
+ hidden_states = hidden_states + residual
1035
+
1036
+ hidden_states = hidden_states / attn.rescale_output_factor
1037
+
1038
+ return hidden_states
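For reference, a sketch of the batch layout this transformer expects when wired into the multi-view UNet: views of the same object are contiguous along the batch dimension (this is what the `(b v)` groupings in the rearranges imply), and when cross-domain attention is enabled the two domains are stacked on top of that. The sizes below are illustrative assumptions, not values taken from the configs.

import torch

num_objects, num_views = 2, 6
c, h, w = 320, 32, 32

# (num_objects * num_views, C, H, W): view i of object j sits at index j * num_views + i
latents = torch.randn(num_objects * num_views, c, h, w)

# with mvcd_attention, the normal and color branches are stacked along dim 0 -> (2 * B * V, C, H, W)
latents_joint = torch.cat([latents, latents.clone()], dim=0)
print(latents.shape, latents_joint.shape)      # torch.Size([12, 320, 32, 32]) torch.Size([24, 320, 32, 32])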
mvdiffusion/models/unet_mv2d_blocks.py ADDED
@@ -0,0 +1,971 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version, logging
22
+ from diffusers.models.normalization import AdaGroupNorm
23
+ from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
24
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
25
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
26
+
27
+ from diffusers.models.unets.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
28
+ from diffusers.models.unets.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ def get_down_block(
35
+ down_block_type,
36
+ num_layers,
37
+ in_channels,
38
+ out_channels,
39
+ temb_channels,
40
+ add_downsample,
41
+ resnet_eps,
42
+ resnet_act_fn,
43
+ transformer_layers_per_block=1,
44
+ num_attention_heads=None,
45
+ resnet_groups=None,
46
+ cross_attention_dim=None,
47
+ downsample_padding=None,
48
+ dual_cross_attention=False,
49
+ use_linear_projection=False,
50
+ only_cross_attention=False,
51
+ upcast_attention=False,
52
+ resnet_time_scale_shift="default",
53
+ resnet_skip_time_act=False,
54
+ resnet_out_scale_factor=1.0,
55
+ cross_attention_norm=None,
56
+ attention_head_dim=None,
57
+ downsample_type=None,
58
+ num_views=1,
59
+ cd_attention_last: bool = False,
60
+ cd_attention_mid: bool = False,
61
+ multiview_attention: bool = True,
62
+ sparse_mv_attention: bool = False,
63
+ selfattn_block: str = "custom",
64
+ mvcd_attention: bool=False,
65
+ use_dino: bool = False
66
+ ):
67
+ # If attn head dim is not defined, we default it to the number of heads
68
+ if attention_head_dim is None:
69
+ logger.warn(
70
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
71
+ )
72
+ attention_head_dim = num_attention_heads
73
+
74
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
75
+ if down_block_type == "DownBlock2D":
76
+ return DownBlock2D(
77
+ num_layers=num_layers,
78
+ in_channels=in_channels,
79
+ out_channels=out_channels,
80
+ temb_channels=temb_channels,
81
+ add_downsample=add_downsample,
82
+ resnet_eps=resnet_eps,
83
+ resnet_act_fn=resnet_act_fn,
84
+ resnet_groups=resnet_groups,
85
+ downsample_padding=downsample_padding,
86
+ resnet_time_scale_shift=resnet_time_scale_shift,
87
+ )
88
+ elif down_block_type == "ResnetDownsampleBlock2D":
89
+ return ResnetDownsampleBlock2D(
90
+ num_layers=num_layers,
91
+ in_channels=in_channels,
92
+ out_channels=out_channels,
93
+ temb_channels=temb_channels,
94
+ add_downsample=add_downsample,
95
+ resnet_eps=resnet_eps,
96
+ resnet_act_fn=resnet_act_fn,
97
+ resnet_groups=resnet_groups,
98
+ resnet_time_scale_shift=resnet_time_scale_shift,
99
+ skip_time_act=resnet_skip_time_act,
100
+ output_scale_factor=resnet_out_scale_factor,
101
+ )
102
+ elif down_block_type == "AttnDownBlock2D":
103
+ if add_downsample is False:
104
+ downsample_type = None
105
+ else:
106
+ downsample_type = downsample_type or "conv" # default to 'conv'
107
+ return AttnDownBlock2D(
108
+ num_layers=num_layers,
109
+ in_channels=in_channels,
110
+ out_channels=out_channels,
111
+ temb_channels=temb_channels,
112
+ resnet_eps=resnet_eps,
113
+ resnet_act_fn=resnet_act_fn,
114
+ resnet_groups=resnet_groups,
115
+ downsample_padding=downsample_padding,
116
+ attention_head_dim=attention_head_dim,
117
+ resnet_time_scale_shift=resnet_time_scale_shift,
118
+ downsample_type=downsample_type,
119
+ )
120
+ elif down_block_type == "CrossAttnDownBlock2D":
121
+ if cross_attention_dim is None:
122
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
123
+ return CrossAttnDownBlock2D(
124
+ num_layers=num_layers,
125
+ transformer_layers_per_block=transformer_layers_per_block,
126
+ in_channels=in_channels,
127
+ out_channels=out_channels,
128
+ temb_channels=temb_channels,
129
+ add_downsample=add_downsample,
130
+ resnet_eps=resnet_eps,
131
+ resnet_act_fn=resnet_act_fn,
132
+ resnet_groups=resnet_groups,
133
+ downsample_padding=downsample_padding,
134
+ cross_attention_dim=cross_attention_dim,
135
+ num_attention_heads=num_attention_heads,
136
+ dual_cross_attention=dual_cross_attention,
137
+ use_linear_projection=use_linear_projection,
138
+ only_cross_attention=only_cross_attention,
139
+ upcast_attention=upcast_attention,
140
+ resnet_time_scale_shift=resnet_time_scale_shift,
141
+ )
142
+ # custom MV2D attention block
143
+ elif down_block_type == "CrossAttnDownBlockMV2D":
144
+ if cross_attention_dim is None:
145
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
146
+ return CrossAttnDownBlockMV2D(
147
+ num_layers=num_layers,
148
+ transformer_layers_per_block=transformer_layers_per_block,
149
+ in_channels=in_channels,
150
+ out_channels=out_channels,
151
+ temb_channels=temb_channels,
152
+ add_downsample=add_downsample,
153
+ resnet_eps=resnet_eps,
154
+ resnet_act_fn=resnet_act_fn,
155
+ resnet_groups=resnet_groups,
156
+ downsample_padding=downsample_padding,
157
+ cross_attention_dim=cross_attention_dim,
158
+ num_attention_heads=num_attention_heads,
159
+ dual_cross_attention=dual_cross_attention,
160
+ use_linear_projection=use_linear_projection,
161
+ only_cross_attention=only_cross_attention,
162
+ upcast_attention=upcast_attention,
163
+ resnet_time_scale_shift=resnet_time_scale_shift,
164
+ num_views=num_views,
165
+ cd_attention_last=cd_attention_last,
166
+ cd_attention_mid=cd_attention_mid,
167
+ multiview_attention=multiview_attention,
168
+ sparse_mv_attention=sparse_mv_attention,
169
+ selfattn_block=selfattn_block,
170
+ mvcd_attention=mvcd_attention,
171
+ use_dino=use_dino
172
+ )
173
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
174
+ if cross_attention_dim is None:
175
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
176
+ return SimpleCrossAttnDownBlock2D(
177
+ num_layers=num_layers,
178
+ in_channels=in_channels,
179
+ out_channels=out_channels,
180
+ temb_channels=temb_channels,
181
+ add_downsample=add_downsample,
182
+ resnet_eps=resnet_eps,
183
+ resnet_act_fn=resnet_act_fn,
184
+ resnet_groups=resnet_groups,
185
+ cross_attention_dim=cross_attention_dim,
186
+ attention_head_dim=attention_head_dim,
187
+ resnet_time_scale_shift=resnet_time_scale_shift,
188
+ skip_time_act=resnet_skip_time_act,
189
+ output_scale_factor=resnet_out_scale_factor,
190
+ only_cross_attention=only_cross_attention,
191
+ cross_attention_norm=cross_attention_norm,
192
+ )
193
+ elif down_block_type == "SkipDownBlock2D":
194
+ return SkipDownBlock2D(
195
+ num_layers=num_layers,
196
+ in_channels=in_channels,
197
+ out_channels=out_channels,
198
+ temb_channels=temb_channels,
199
+ add_downsample=add_downsample,
200
+ resnet_eps=resnet_eps,
201
+ resnet_act_fn=resnet_act_fn,
202
+ downsample_padding=downsample_padding,
203
+ resnet_time_scale_shift=resnet_time_scale_shift,
204
+ )
205
+ elif down_block_type == "AttnSkipDownBlock2D":
206
+ return AttnSkipDownBlock2D(
207
+ num_layers=num_layers,
208
+ in_channels=in_channels,
209
+ out_channels=out_channels,
210
+ temb_channels=temb_channels,
211
+ add_downsample=add_downsample,
212
+ resnet_eps=resnet_eps,
213
+ resnet_act_fn=resnet_act_fn,
214
+ attention_head_dim=attention_head_dim,
215
+ resnet_time_scale_shift=resnet_time_scale_shift,
216
+ )
217
+ elif down_block_type == "DownEncoderBlock2D":
218
+ return DownEncoderBlock2D(
219
+ num_layers=num_layers,
220
+ in_channels=in_channels,
221
+ out_channels=out_channels,
222
+ add_downsample=add_downsample,
223
+ resnet_eps=resnet_eps,
224
+ resnet_act_fn=resnet_act_fn,
225
+ resnet_groups=resnet_groups,
226
+ downsample_padding=downsample_padding,
227
+ resnet_time_scale_shift=resnet_time_scale_shift,
228
+ )
229
+ elif down_block_type == "AttnDownEncoderBlock2D":
230
+ return AttnDownEncoderBlock2D(
231
+ num_layers=num_layers,
232
+ in_channels=in_channels,
233
+ out_channels=out_channels,
234
+ add_downsample=add_downsample,
235
+ resnet_eps=resnet_eps,
236
+ resnet_act_fn=resnet_act_fn,
237
+ resnet_groups=resnet_groups,
238
+ downsample_padding=downsample_padding,
239
+ attention_head_dim=attention_head_dim,
240
+ resnet_time_scale_shift=resnet_time_scale_shift,
241
+ )
242
+ elif down_block_type == "KDownBlock2D":
243
+ return KDownBlock2D(
244
+ num_layers=num_layers,
245
+ in_channels=in_channels,
246
+ out_channels=out_channels,
247
+ temb_channels=temb_channels,
248
+ add_downsample=add_downsample,
249
+ resnet_eps=resnet_eps,
250
+ resnet_act_fn=resnet_act_fn,
251
+ )
252
+ elif down_block_type == "KCrossAttnDownBlock2D":
253
+ return KCrossAttnDownBlock2D(
254
+ num_layers=num_layers,
255
+ in_channels=in_channels,
256
+ out_channels=out_channels,
257
+ temb_channels=temb_channels,
258
+ add_downsample=add_downsample,
259
+ resnet_eps=resnet_eps,
260
+ resnet_act_fn=resnet_act_fn,
261
+ cross_attention_dim=cross_attention_dim,
262
+ attention_head_dim=attention_head_dim,
263
+ add_self_attention=True if not add_downsample else False,
264
+ )
265
+ raise ValueError(f"{down_block_type} does not exist.")
266
+
267
+
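Both factory functions strip an optional "UNetRes" prefix before dispatching on the block name; a trivial standalone illustration of that normalization:

for name in ("UNetResCrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D"):
    normalized = name[7:] if name.startswith("UNetRes") else name
    print(normalized)   # both print "CrossAttnDownBlockMV2D"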
268
+ def get_up_block(
269
+ up_block_type,
270
+ num_layers,
271
+ in_channels,
272
+ out_channels,
273
+ prev_output_channel,
274
+ temb_channels,
275
+ add_upsample,
276
+ resnet_eps,
277
+ resnet_act_fn,
278
+ transformer_layers_per_block=1,
279
+ num_attention_heads=None,
280
+ resnet_groups=None,
281
+ cross_attention_dim=None,
282
+ dual_cross_attention=False,
283
+ use_linear_projection=False,
284
+ only_cross_attention=False,
285
+ upcast_attention=False,
286
+ resnet_time_scale_shift="default",
287
+ resnet_skip_time_act=False,
288
+ resnet_out_scale_factor=1.0,
289
+ cross_attention_norm=None,
290
+ attention_head_dim=None,
291
+ upsample_type=None,
292
+ num_views=1,
293
+ cd_attention_last: bool = False,
294
+ cd_attention_mid: bool = False,
295
+ multiview_attention: bool = True,
296
+ sparse_mv_attention: bool = False,
297
+ selfattn_block: str = "custom",
298
+ mvcd_attention: bool=False,
299
+ use_dino: bool = False
300
+ ):
301
+ # If attn head dim is not defined, we default it to the number of heads
302
+ if attention_head_dim is None:
303
+ logger.warn(
304
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
305
+ )
306
+ attention_head_dim = num_attention_heads
307
+
308
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
309
+ if up_block_type == "UpBlock2D":
310
+ return UpBlock2D(
311
+ num_layers=num_layers,
312
+ in_channels=in_channels,
313
+ out_channels=out_channels,
314
+ prev_output_channel=prev_output_channel,
315
+ temb_channels=temb_channels,
316
+ add_upsample=add_upsample,
317
+ resnet_eps=resnet_eps,
318
+ resnet_act_fn=resnet_act_fn,
319
+ resnet_groups=resnet_groups,
320
+ resnet_time_scale_shift=resnet_time_scale_shift,
321
+ )
322
+ elif up_block_type == "ResnetUpsampleBlock2D":
323
+ return ResnetUpsampleBlock2D(
324
+ num_layers=num_layers,
325
+ in_channels=in_channels,
326
+ out_channels=out_channels,
327
+ prev_output_channel=prev_output_channel,
328
+ temb_channels=temb_channels,
329
+ add_upsample=add_upsample,
330
+ resnet_eps=resnet_eps,
331
+ resnet_act_fn=resnet_act_fn,
332
+ resnet_groups=resnet_groups,
333
+ resnet_time_scale_shift=resnet_time_scale_shift,
334
+ skip_time_act=resnet_skip_time_act,
335
+ output_scale_factor=resnet_out_scale_factor,
336
+ )
337
+ elif up_block_type == "CrossAttnUpBlock2D":
338
+ if cross_attention_dim is None:
339
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
340
+ return CrossAttnUpBlock2D(
341
+ num_layers=num_layers,
342
+ transformer_layers_per_block=transformer_layers_per_block,
343
+ in_channels=in_channels,
344
+ out_channels=out_channels,
345
+ prev_output_channel=prev_output_channel,
346
+ temb_channels=temb_channels,
347
+ add_upsample=add_upsample,
348
+ resnet_eps=resnet_eps,
349
+ resnet_act_fn=resnet_act_fn,
350
+ resnet_groups=resnet_groups,
351
+ cross_attention_dim=cross_attention_dim,
352
+ num_attention_heads=num_attention_heads,
353
+ dual_cross_attention=dual_cross_attention,
354
+ use_linear_projection=use_linear_projection,
355
+ only_cross_attention=only_cross_attention,
356
+ upcast_attention=upcast_attention,
357
+ resnet_time_scale_shift=resnet_time_scale_shift,
358
+ )
359
+ # custom MV2D attention block
360
+ elif up_block_type == "CrossAttnUpBlockMV2D":
361
+ if cross_attention_dim is None:
362
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
363
+ return CrossAttnUpBlockMV2D(
364
+ num_layers=num_layers,
365
+ transformer_layers_per_block=transformer_layers_per_block,
366
+ in_channels=in_channels,
367
+ out_channels=out_channels,
368
+ prev_output_channel=prev_output_channel,
369
+ temb_channels=temb_channels,
370
+ add_upsample=add_upsample,
371
+ resnet_eps=resnet_eps,
372
+ resnet_act_fn=resnet_act_fn,
373
+ resnet_groups=resnet_groups,
374
+ cross_attention_dim=cross_attention_dim,
375
+ num_attention_heads=num_attention_heads,
376
+ dual_cross_attention=dual_cross_attention,
377
+ use_linear_projection=use_linear_projection,
378
+ only_cross_attention=only_cross_attention,
379
+ upcast_attention=upcast_attention,
380
+ resnet_time_scale_shift=resnet_time_scale_shift,
381
+ num_views=num_views,
382
+ cd_attention_last=cd_attention_last,
383
+ cd_attention_mid=cd_attention_mid,
384
+ multiview_attention=multiview_attention,
385
+ sparse_mv_attention=sparse_mv_attention,
386
+ selfattn_block=selfattn_block,
387
+ mvcd_attention=mvcd_attention,
388
+ use_dino=use_dino
389
+ )
390
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
391
+ if cross_attention_dim is None:
392
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
393
+ return SimpleCrossAttnUpBlock2D(
394
+ num_layers=num_layers,
395
+ in_channels=in_channels,
396
+ out_channels=out_channels,
397
+ prev_output_channel=prev_output_channel,
398
+ temb_channels=temb_channels,
399
+ add_upsample=add_upsample,
400
+ resnet_eps=resnet_eps,
401
+ resnet_act_fn=resnet_act_fn,
402
+ resnet_groups=resnet_groups,
403
+ cross_attention_dim=cross_attention_dim,
404
+ attention_head_dim=attention_head_dim,
405
+ resnet_time_scale_shift=resnet_time_scale_shift,
406
+ skip_time_act=resnet_skip_time_act,
407
+ output_scale_factor=resnet_out_scale_factor,
408
+ only_cross_attention=only_cross_attention,
409
+ cross_attention_norm=cross_attention_norm,
410
+ )
411
+ elif up_block_type == "AttnUpBlock2D":
412
+ if add_upsample is False:
413
+ upsample_type = None
414
+ else:
415
+ upsample_type = upsample_type or "conv" # default to 'conv'
416
+
417
+ return AttnUpBlock2D(
418
+ num_layers=num_layers,
419
+ in_channels=in_channels,
420
+ out_channels=out_channels,
421
+ prev_output_channel=prev_output_channel,
422
+ temb_channels=temb_channels,
423
+ resnet_eps=resnet_eps,
424
+ resnet_act_fn=resnet_act_fn,
425
+ resnet_groups=resnet_groups,
426
+ attention_head_dim=attention_head_dim,
427
+ resnet_time_scale_shift=resnet_time_scale_shift,
428
+ upsample_type=upsample_type,
429
+ )
430
+ elif up_block_type == "SkipUpBlock2D":
431
+ return SkipUpBlock2D(
432
+ num_layers=num_layers,
433
+ in_channels=in_channels,
434
+ out_channels=out_channels,
435
+ prev_output_channel=prev_output_channel,
436
+ temb_channels=temb_channels,
437
+ add_upsample=add_upsample,
438
+ resnet_eps=resnet_eps,
439
+ resnet_act_fn=resnet_act_fn,
440
+ resnet_time_scale_shift=resnet_time_scale_shift,
441
+ )
442
+ elif up_block_type == "AttnSkipUpBlock2D":
443
+ return AttnSkipUpBlock2D(
444
+ num_layers=num_layers,
445
+ in_channels=in_channels,
446
+ out_channels=out_channels,
447
+ prev_output_channel=prev_output_channel,
448
+ temb_channels=temb_channels,
449
+ add_upsample=add_upsample,
450
+ resnet_eps=resnet_eps,
451
+ resnet_act_fn=resnet_act_fn,
452
+ attention_head_dim=attention_head_dim,
453
+ resnet_time_scale_shift=resnet_time_scale_shift,
454
+ )
455
+ elif up_block_type == "UpDecoderBlock2D":
456
+ return UpDecoderBlock2D(
457
+ num_layers=num_layers,
458
+ in_channels=in_channels,
459
+ out_channels=out_channels,
460
+ add_upsample=add_upsample,
461
+ resnet_eps=resnet_eps,
462
+ resnet_act_fn=resnet_act_fn,
463
+ resnet_groups=resnet_groups,
464
+ resnet_time_scale_shift=resnet_time_scale_shift,
465
+ temb_channels=temb_channels,
466
+ )
467
+ elif up_block_type == "AttnUpDecoderBlock2D":
468
+ return AttnUpDecoderBlock2D(
469
+ num_layers=num_layers,
470
+ in_channels=in_channels,
471
+ out_channels=out_channels,
472
+ add_upsample=add_upsample,
473
+ resnet_eps=resnet_eps,
474
+ resnet_act_fn=resnet_act_fn,
475
+ resnet_groups=resnet_groups,
476
+ attention_head_dim=attention_head_dim,
477
+ resnet_time_scale_shift=resnet_time_scale_shift,
478
+ temb_channels=temb_channels,
479
+ )
480
+ elif up_block_type == "KUpBlock2D":
481
+ return KUpBlock2D(
482
+ num_layers=num_layers,
483
+ in_channels=in_channels,
484
+ out_channels=out_channels,
485
+ temb_channels=temb_channels,
486
+ add_upsample=add_upsample,
487
+ resnet_eps=resnet_eps,
488
+ resnet_act_fn=resnet_act_fn,
489
+ )
490
+ elif up_block_type == "KCrossAttnUpBlock2D":
491
+ return KCrossAttnUpBlock2D(
492
+ num_layers=num_layers,
493
+ in_channels=in_channels,
494
+ out_channels=out_channels,
495
+ temb_channels=temb_channels,
496
+ add_upsample=add_upsample,
497
+ resnet_eps=resnet_eps,
498
+ resnet_act_fn=resnet_act_fn,
499
+ cross_attention_dim=cross_attention_dim,
500
+ attention_head_dim=attention_head_dim,
501
+ )
502
+
503
+ raise ValueError(f"{up_block_type} does not exist.")
504
+
505
+
506
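+ # Mid block: same layout as diffusers' UNetMidBlock2DCrossAttn (one resnet, then N pairs of
+ # attention + resnet), but the attention layers are TransformerMV2DModel instances selected
+ # via `selfattn_block`.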
+ class UNetMidBlockMV2DCrossAttn(nn.Module):
507
+ def __init__(
508
+ self,
509
+ in_channels: int,
510
+ temb_channels: int,
511
+ dropout: float = 0.0,
512
+ num_layers: int = 1,
513
+ transformer_layers_per_block: int = 1,
514
+ resnet_eps: float = 1e-6,
515
+ resnet_time_scale_shift: str = "default",
516
+ resnet_act_fn: str = "swish",
517
+ resnet_groups: int = 32,
518
+ resnet_pre_norm: bool = True,
519
+ num_attention_heads=1,
520
+ output_scale_factor=1.0,
521
+ cross_attention_dim=1280,
522
+ dual_cross_attention=False,
523
+ use_linear_projection=False,
524
+ upcast_attention=False,
525
+ num_views: int = 1,
526
+ cd_attention_last: bool = False,
527
+ cd_attention_mid: bool = False,
528
+ multiview_attention: bool = True,
529
+ sparse_mv_attention: bool = False,
530
+ selfattn_block: str = "custom",
531
+ mvcd_attention: bool=False,
532
+ use_dino: bool = False
533
+ ):
534
+ super().__init__()
535
+
536
+ self.has_cross_attention = True
537
+ self.num_attention_heads = num_attention_heads
538
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
539
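+ # Late import: pick which transformer_mv2d_* module supplies TransformerMV2DModel,
+ # according to the `selfattn_block` config string.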
+ if selfattn_block == "custom":
540
+ from .transformer_mv2d import TransformerMV2DModel
541
+ elif selfattn_block == "rowwise":
542
+ from .transformer_mv2d_rowwise import TransformerMV2DModel
543
+ elif selfattn_block == "self_rowwise":
544
+ from .transformer_mv2d_self_rowwise import TransformerMV2DModel
545
+ else:
546
+ raise NotImplementedError
547
+
548
+ # there is always at least one resnet
549
+ resnets = [
550
+ ResnetBlock2D(
551
+ in_channels=in_channels,
552
+ out_channels=in_channels,
553
+ temb_channels=temb_channels,
554
+ eps=resnet_eps,
555
+ groups=resnet_groups,
556
+ dropout=dropout,
557
+ time_embedding_norm=resnet_time_scale_shift,
558
+ non_linearity=resnet_act_fn,
559
+ output_scale_factor=output_scale_factor,
560
+ pre_norm=resnet_pre_norm,
561
+ )
562
+ ]
563
+ attentions = []
564
+
565
+ for _ in range(num_layers):
566
+ if not dual_cross_attention:
567
+ attentions.append(
568
+ TransformerMV2DModel(
569
+ num_attention_heads,
570
+ in_channels // num_attention_heads,
571
+ in_channels=in_channels,
572
+ num_layers=transformer_layers_per_block,
573
+ cross_attention_dim=cross_attention_dim,
574
+ norm_num_groups=resnet_groups,
575
+ use_linear_projection=use_linear_projection,
576
+ upcast_attention=upcast_attention,
577
+ num_views=num_views,
578
+ cd_attention_last=cd_attention_last,
579
+ cd_attention_mid=cd_attention_mid,
580
+ multiview_attention=multiview_attention,
581
+ sparse_mv_attention=sparse_mv_attention,
582
+ mvcd_attention=mvcd_attention,
583
+ use_dino=use_dino
584
+ )
585
+ )
586
+ else:
587
+ raise NotImplementedError
588
+ resnets.append(
589
+ ResnetBlock2D(
590
+ in_channels=in_channels,
591
+ out_channels=in_channels,
592
+ temb_channels=temb_channels,
593
+ eps=resnet_eps,
594
+ groups=resnet_groups,
595
+ dropout=dropout,
596
+ time_embedding_norm=resnet_time_scale_shift,
597
+ non_linearity=resnet_act_fn,
598
+ output_scale_factor=output_scale_factor,
599
+ pre_norm=resnet_pre_norm,
600
+ )
601
+ )
602
+
603
+ self.attentions = nn.ModuleList(attentions)
604
+ self.resnets = nn.ModuleList(resnets)
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.FloatTensor,
609
+ temb: Optional[torch.FloatTensor] = None,
610
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
611
+ attention_mask: Optional[torch.FloatTensor] = None,
612
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
613
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
614
+ dino_feature: Optional[torch.FloatTensor] = None
615
+ ) -> torch.FloatTensor:
616
+ hidden_states = self.resnets[0](hidden_states, temb)
617
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
618
+ hidden_states = attn(
619
+ hidden_states,
620
+ encoder_hidden_states=encoder_hidden_states,
621
+ cross_attention_kwargs=cross_attention_kwargs,
622
+ attention_mask=attention_mask,
623
+ encoder_attention_mask=encoder_attention_mask,
624
+ dino_feature=dino_feature,
625
+ return_dict=False,
626
+ )[0]
627
+ hidden_states = resnet(hidden_states, temb)
628
+
629
+ return hidden_states
630
+
631
+
632
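+ # Decoder block with skip connections: each layer concatenates the matching encoder feature,
+ # runs a ResnetBlock2D and then a multi-view TransformerMV2DModel; an optional Upsample2D
+ # doubles the spatial resolution at the end.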
+ class CrossAttnUpBlockMV2D(nn.Module):
633
+ def __init__(
634
+ self,
635
+ in_channels: int,
636
+ out_channels: int,
637
+ prev_output_channel: int,
638
+ temb_channels: int,
639
+ dropout: float = 0.0,
640
+ num_layers: int = 1,
641
+ transformer_layers_per_block: int = 1,
642
+ resnet_eps: float = 1e-6,
643
+ resnet_time_scale_shift: str = "default",
644
+ resnet_act_fn: str = "swish",
645
+ resnet_groups: int = 32,
646
+ resnet_pre_norm: bool = True,
647
+ num_attention_heads=1,
648
+ cross_attention_dim=1280,
649
+ output_scale_factor=1.0,
650
+ add_upsample=True,
651
+ dual_cross_attention=False,
652
+ use_linear_projection=False,
653
+ only_cross_attention=False,
654
+ upcast_attention=False,
655
+ num_views: int = 1,
656
+ cd_attention_last: bool = False,
657
+ cd_attention_mid: bool = False,
658
+ multiview_attention: bool = True,
659
+ sparse_mv_attention: bool = False,
660
+ selfattn_block: str = "custom",
661
+ mvcd_attention: bool=False,
662
+ use_dino: bool = False
663
+ ):
664
+ super().__init__()
665
+ resnets = []
666
+ attentions = []
667
+
668
+ self.has_cross_attention = True
669
+ self.num_attention_heads = num_attention_heads
670
+
671
+ if selfattn_block == "custom":
672
+ from .transformer_mv2d import TransformerMV2DModel
673
+ elif selfattn_block == "rowwise":
674
+ from .transformer_mv2d_rowwise import TransformerMV2DModel
675
+ elif selfattn_block == "self_rowwise":
676
+ from .transformer_mv2d_self_rowwise import TransformerMV2DModel
677
+ else:
678
+ raise NotImplementedError
679
+
680
+ for i in range(num_layers):
681
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
682
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
683
+
684
+ resnets.append(
685
+ ResnetBlock2D(
686
+ in_channels=resnet_in_channels + res_skip_channels,
687
+ out_channels=out_channels,
688
+ temb_channels=temb_channels,
689
+ eps=resnet_eps,
690
+ groups=resnet_groups,
691
+ dropout=dropout,
692
+ time_embedding_norm=resnet_time_scale_shift,
693
+ non_linearity=resnet_act_fn,
694
+ output_scale_factor=output_scale_factor,
695
+ pre_norm=resnet_pre_norm,
696
+ )
697
+ )
698
+ if not dual_cross_attention:
699
+ attentions.append(
700
+ TransformerMV2DModel(
701
+ num_attention_heads,
702
+ out_channels // num_attention_heads,
703
+ in_channels=out_channels,
704
+ num_layers=transformer_layers_per_block,
705
+ cross_attention_dim=cross_attention_dim,
706
+ norm_num_groups=resnet_groups,
707
+ use_linear_projection=use_linear_projection,
708
+ only_cross_attention=only_cross_attention,
709
+ upcast_attention=upcast_attention,
710
+ num_views=num_views,
711
+ cd_attention_last=cd_attention_last,
712
+ cd_attention_mid=cd_attention_mid,
713
+ multiview_attention=multiview_attention,
714
+ sparse_mv_attention=sparse_mv_attention,
715
+ mvcd_attention=mvcd_attention,
716
+ use_dino=use_dino
717
+ )
718
+ )
719
+ else:
720
+ raise NotImplementedError
721
+ self.attentions = nn.ModuleList(attentions)
722
+ self.resnets = nn.ModuleList(resnets)
723
+
724
+ if add_upsample:
725
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
726
+ else:
727
+ self.upsamplers = None
728
+
729
+ self.gradient_checkpointing = False
730
+
731
+ def forward(
732
+ self,
733
+ hidden_states: torch.FloatTensor,
734
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
735
+ temb: Optional[torch.FloatTensor] = None,
736
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
737
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
738
+ upsample_size: Optional[int] = None,
739
+ attention_mask: Optional[torch.FloatTensor] = None,
740
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
741
+ dino_feature: Optional[torch.FloatTensor] = None
742
+ ):
743
+ for resnet, attn in zip(self.resnets, self.attentions):
744
+ # pop res hidden states
745
+ res_hidden_states = res_hidden_states_tuple[-1]
746
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
747
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
748
+
749
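+ # Gradient checkpointing: recompute the resnet/attention activations in the backward pass
+ # to save memory during training.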
+ if self.training and self.gradient_checkpointing:
750
+
751
+ def create_custom_forward(module, return_dict=None):
752
+ def custom_forward(*inputs):
753
+ if return_dict is not None:
754
+ return module(*inputs, return_dict=return_dict)
755
+ else:
756
+ return module(*inputs)
757
+
758
+ return custom_forward
759
+
760
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
761
+ hidden_states = torch.utils.checkpoint.checkpoint(
762
+ create_custom_forward(resnet),
763
+ hidden_states,
764
+ temb,
765
+ **ckpt_kwargs,
766
+ )
767
+ hidden_states = torch.utils.checkpoint.checkpoint(
768
+ create_custom_forward(attn, return_dict=False),
769
+ hidden_states,
770
+ encoder_hidden_states,
771
+ dino_feature,
772
+ None, # timestep
773
+ None, # class_labels
774
+ cross_attention_kwargs,
775
+ attention_mask,
776
+ encoder_attention_mask,
777
+ **ckpt_kwargs,
778
+ )[0]
779
+ else:
780
+ hidden_states = resnet(hidden_states, temb)
781
+ hidden_states = attn(
782
+ hidden_states,
783
+ encoder_hidden_states=encoder_hidden_states,
784
+ cross_attention_kwargs=cross_attention_kwargs,
785
+ attention_mask=attention_mask,
786
+ encoder_attention_mask=encoder_attention_mask,
787
+ dino_feature=dino_feature,
788
+ return_dict=False,
789
+ )[0]
790
+
791
+ if self.upsamplers is not None:
792
+ for upsampler in self.upsamplers:
793
+ hidden_states = upsampler(hidden_states, upsample_size)
794
+
795
+ return hidden_states
796
+
797
+
798
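+ # Encoder block: pairs of ResnetBlock2D + multi-view TransformerMV2DModel; every intermediate
+ # state is returned for the UNet skip connections, with an optional strided Downsample2D at
+ # the end.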
+ class CrossAttnDownBlockMV2D(nn.Module):
799
+ def __init__(
800
+ self,
801
+ in_channels: int,
802
+ out_channels: int,
803
+ temb_channels: int,
804
+ dropout: float = 0.0,
805
+ num_layers: int = 1,
806
+ transformer_layers_per_block: int = 1,
807
+ resnet_eps: float = 1e-6,
808
+ resnet_time_scale_shift: str = "default",
809
+ resnet_act_fn: str = "swish",
810
+ resnet_groups: int = 32,
811
+ resnet_pre_norm: bool = True,
812
+ num_attention_heads=1,
813
+ cross_attention_dim=1280,
814
+ output_scale_factor=1.0,
815
+ downsample_padding=1,
816
+ add_downsample=True,
817
+ dual_cross_attention=False,
818
+ use_linear_projection=False,
819
+ only_cross_attention=False,
820
+ upcast_attention=False,
821
+ num_views: int = 1,
822
+ cd_attention_last: bool = False,
823
+ cd_attention_mid: bool = False,
824
+ multiview_attention: bool = True,
825
+ sparse_mv_attention: bool = False,
826
+ selfattn_block: str = "custom",
827
+ mvcd_attention: bool=False,
828
+ use_dino: bool = False
829
+ ):
830
+ super().__init__()
831
+ resnets = []
832
+ attentions = []
833
+
834
+ self.has_cross_attention = True
835
+ self.num_attention_heads = num_attention_heads
836
+ if selfattn_block == "custom":
837
+ from .transformer_mv2d import TransformerMV2DModel
838
+ elif selfattn_block == "rowwise":
839
+ from .transformer_mv2d_rowwise import TransformerMV2DModel
840
+ elif selfattn_block == "self_rowwise":
841
+ from .transformer_mv2d_self_rowwise import TransformerMV2DModel
842
+ else:
843
+ raise NotImplementedError
844
+
845
+ for i in range(num_layers):
846
+ in_channels = in_channels if i == 0 else out_channels
847
+ resnets.append(
848
+ ResnetBlock2D(
849
+ in_channels=in_channels,
850
+ out_channels=out_channels,
851
+ temb_channels=temb_channels,
852
+ eps=resnet_eps,
853
+ groups=resnet_groups,
854
+ dropout=dropout,
855
+ time_embedding_norm=resnet_time_scale_shift,
856
+ non_linearity=resnet_act_fn,
857
+ output_scale_factor=output_scale_factor,
858
+ pre_norm=resnet_pre_norm,
859
+ )
860
+ )
861
+ if not dual_cross_attention:
862
+ attentions.append(
863
+ TransformerMV2DModel(
864
+ num_attention_heads,
865
+ out_channels // num_attention_heads,
866
+ in_channels=out_channels,
867
+ num_layers=transformer_layers_per_block,
868
+ cross_attention_dim=cross_attention_dim,
869
+ norm_num_groups=resnet_groups,
870
+ use_linear_projection=use_linear_projection,
871
+ only_cross_attention=only_cross_attention,
872
+ upcast_attention=upcast_attention,
873
+ num_views=num_views,
874
+ cd_attention_last=cd_attention_last,
875
+ cd_attention_mid=cd_attention_mid,
876
+ multiview_attention=multiview_attention,
877
+ sparse_mv_attention=sparse_mv_attention,
878
+ mvcd_attention=mvcd_attention,
879
+ use_dino=use_dino
880
+ )
881
+ )
882
+ else:
883
+ raise NotImplementedError
884
+ self.attentions = nn.ModuleList(attentions)
885
+ self.resnets = nn.ModuleList(resnets)
886
+
887
+ if add_downsample:
888
+ self.downsamplers = nn.ModuleList(
889
+ [
890
+ Downsample2D(
891
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
892
+ )
893
+ ]
894
+ )
895
+ else:
896
+ self.downsamplers = None
897
+
898
+ self.gradient_checkpointing = False
899
+
900
+ def forward(
901
+ self,
902
+ hidden_states: torch.FloatTensor,
903
+ temb: Optional[torch.FloatTensor] = None,
904
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
905
+ dino_feature: Optional[torch.FloatTensor] = None,
906
+ attention_mask: Optional[torch.FloatTensor] = None,
907
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
908
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
909
+ additional_residuals=None,
910
+ ):
911
+ output_states = ()
912
+
913
+ blocks = list(zip(self.resnets, self.attentions))
914
+
915
+ for i, (resnet, attn) in enumerate(blocks):
916
+ if self.training and self.gradient_checkpointing:
917
+
918
+ def create_custom_forward(module, return_dict=None):
919
+ def custom_forward(*inputs):
920
+ if return_dict is not None:
921
+ return module(*inputs, return_dict=return_dict)
922
+ else:
923
+ return module(*inputs)
924
+
925
+ return custom_forward
926
+
927
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
928
+ hidden_states = torch.utils.checkpoint.checkpoint(
929
+ create_custom_forward(resnet),
930
+ hidden_states,
931
+ temb,
932
+ **ckpt_kwargs,
933
+ )
934
+ hidden_states = torch.utils.checkpoint.checkpoint(
935
+ create_custom_forward(attn, return_dict=False),
936
+ hidden_states,
937
+ encoder_hidden_states,
938
+ dino_feature,
939
+ None, # timestep
940
+ None, # class_labels
941
+ cross_attention_kwargs,
942
+ attention_mask,
943
+ encoder_attention_mask,
944
+ **ckpt_kwargs,
945
+ )[0]
946
+ else:
947
+ hidden_states = resnet(hidden_states, temb)
948
+ hidden_states = attn(
949
+ hidden_states,
950
+ encoder_hidden_states=encoder_hidden_states,
951
+ dino_feature=dino_feature,
952
+ cross_attention_kwargs=cross_attention_kwargs,
953
+ attention_mask=attention_mask,
954
+ encoder_attention_mask=encoder_attention_mask,
955
+ return_dict=False,
956
+ )[0]
957
+
958
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
959
+ if i == len(blocks) - 1 and additional_residuals is not None:
960
+ hidden_states = hidden_states + additional_residuals
961
+
962
+ output_states = output_states + (hidden_states,)
963
+
964
+ if self.downsamplers is not None:
965
+ for downsampler in self.downsamplers:
966
+ hidden_states = downsampler(hidden_states)
967
+
968
+ output_states = output_states + (hidden_states,)
969
+
970
+ return hidden_states, output_states
971
+
mvdiffusion/models/unet_mv2d_condition.py ADDED
@@ -0,0 +1,1686 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.activations import get_activation
26
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
27
+ from diffusers.models.embeddings import (
28
+ GaussianFourierProjection,
29
+ ImageHintTimeEmbedding,
30
+ ImageProjection,
31
+ ImageTimeEmbedding,
32
+ TextImageProjection,
33
+ TextImageTimeEmbedding,
34
+ TextTimeEmbedding,
35
+ TimestepEmbedding,
36
+ Timesteps,
37
+ )
38
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
39
+ from diffusers.models.unet_2d_blocks import (
40
+ CrossAttnDownBlock2D,
41
+ CrossAttnUpBlock2D,
42
+ DownBlock2D,
43
+ UNetMidBlock2DCrossAttn,
44
+ UNetMidBlock2DSimpleCrossAttn,
45
+ UpBlock2D,
46
+ )
47
+ from diffusers.utils import (
48
+ CONFIG_NAME,
49
+ FLAX_WEIGHTS_NAME,
50
+ SAFETENSORS_WEIGHTS_NAME,
51
+ WEIGHTS_NAME,
52
+ _add_variant,
53
+ _get_model_file,
54
+ deprecate,
55
+ is_torch_version,
56
+ logging,
57
+ )
58
+ from diffusers.utils.import_utils import is_accelerate_available
59
+ from diffusers.utils.hub_utils import HF_HUB_OFFLINE
60
+ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
61
+ DIFFUSERS_CACHE = HUGGINGFACE_HUB_CACHE
62
+
63
+ from diffusers import __version__
64
+ from .unet_mv2d_blocks import (
65
+ CrossAttnDownBlockMV2D,
66
+ CrossAttnUpBlockMV2D,
67
+ UNetMidBlockMV2DCrossAttn,
68
+ get_down_block,
69
+ get_up_block,
70
+ )
71
+ from einops import rearrange, repeat
72
+
73
+ from diffusers import __version__
74
+ from mvdiffusion.models.unet_mv2d_blocks import (
75
+ CrossAttnDownBlockMV2D,
76
+ CrossAttnUpBlockMV2D,
77
+ UNetMidBlockMV2DCrossAttn,
78
+ get_down_block,
79
+ get_up_block,
80
+ )
81
+
82
+
83
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
84
+
85
+
86
+ @dataclass
87
+ class UNetMV2DConditionOutput(BaseOutput):
88
+ """
89
+ The output of [`UNet2DConditionModel`].
90
+
91
+ Args:
92
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
93
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
94
+ """
95
+
96
+ sample: torch.FloatTensor = None
97
+
98
+
99
+ class ResidualBlock(nn.Module):
+     """Two-layer MLP with a residual connection and SiLU activations."""
+     def __init__(self, dim):
+         super(ResidualBlock, self).__init__()
+         self.linear1 = nn.Linear(dim, dim)
+         self.activation = nn.SiLU()
+         self.linear2 = nn.Linear(dim, dim)
+ 
+     def forward(self, x):
+         identity = x
+         out = self.linear1(x)
+         out = self.activation(out)
+         out = self.linear2(out)
+         out += identity
+         out = self.activation(out)
+         return out
+ 
+ 
+ class ResidualLiner(nn.Module):
+     """Residual MLP head: input projection, `num_block` ResidualBlocks, output projection,
+     and an optional final activation."""
+     def __init__(self, in_features, out_features, dim, act=None, num_block=1):
+         super(ResidualLiner, self).__init__()
+         self.linear_in = nn.Sequential(nn.Linear(in_features, dim), nn.SiLU())
+ 
+         blocks = nn.ModuleList()
+         for _ in range(num_block):
+             blocks.append(ResidualBlock(dim))
+         self.blocks = blocks
+ 
+         self.linear_out = nn.Linear(dim, out_features)
+         self.act = act
+ 
+     def forward(self, x):
+         out = self.linear_in(x)
+         for block in self.blocks:
+             out = block(out)
+         out = self.linear_out(out)
+         if self.act is not None:
+             out = self.act(out)
+         return out
+ 
+ 
+ class BasicConvBlock(nn.Module):
+     """ResNet-style conv block: two 3x3 conv + GroupNorm + SiLU layers, with a projection
+     shortcut when the stride or channel count changes."""
+     def __init__(self, in_channels, out_channels, stride=1):
+         super(BasicConvBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+         # normalize over the conv output channels
+         self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
+         self.act = nn.SiLU()
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+         self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
+         self.downsample = nn.Sequential()
+         if stride != 1 or in_channels != out_channels:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
+                 nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
+             )
+ 
+     def forward(self, x):
+         identity = x
+         out = self.conv1(x)
+         out = self.norm1(out)
+         out = self.act(out)
+         out = self.conv2(out)
+         out = self.norm2(out)
+         out += self.downsample(identity)
+         out = self.act(out)
+         return out
162
+
163
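+ # Multi-view conditional UNet: a UNet2DConditionModel variant that can build its down/mid/up
+ # stages from the MV2D blocks imported above, and optionally adds elevation / focal-length
+ # regression heads and an auxiliary bottleneck conv branch.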
+ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
164
+ r"""
165
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
166
+ shaped output.
167
+
168
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
169
+ for all models (such as downloading or saving).
170
+
171
+ Parameters:
172
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
173
+ Height and width of input/output sample.
174
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
175
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
176
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
177
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
178
+ Whether to flip the sin to cos in the time embedding.
179
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
180
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
181
+ The tuple of downsample blocks to use.
182
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
183
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
184
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
185
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
186
+ The tuple of upsample blocks to use.
187
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
188
+ Whether to include self-attention in the basic transformer blocks, see
189
+ [`~models.attention.BasicTransformerBlock`].
190
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
191
+ The tuple of output channels for each block.
192
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
193
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
194
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
195
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
196
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
197
+ If `None`, normalization and activation layers is skipped in post-processing.
198
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
199
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
200
+ The dimension of the cross attention features.
201
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
202
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
203
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
204
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
205
+ encoder_hid_dim (`int`, *optional*, defaults to None):
206
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
207
+ dimension to `cross_attention_dim`.
208
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
209
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
210
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
211
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
212
+ num_attention_heads (`int`, *optional*):
213
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
214
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
215
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
216
+ class_embed_type (`str`, *optional*, defaults to `None`):
217
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
218
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
219
+ addition_embed_type (`str`, *optional*, defaults to `None`):
220
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
221
+ "text". "text" will use the `TextTimeEmbedding` layer.
222
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
223
+ Dimension for the timestep embeddings.
224
+ num_class_embeds (`int`, *optional*, defaults to `None`):
225
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
226
+ class conditioning with `class_embed_type` equal to `None`.
227
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
228
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
229
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
230
+ An optional override for the dimension of the projected time embedding.
231
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
232
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
233
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
234
+ timestep_post_act (`str`, *optional*, defaults to `None`):
235
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
236
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
237
+ The dimension of `cond_proj` layer in the timestep embedding.
238
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
239
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
240
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
241
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
242
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
243
+ embeddings with the class embeddings.
244
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
245
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
246
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
247
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
248
+ otherwise.
249
+ """
250
+
251
+ _supports_gradient_checkpointing = True
252
+
253
+ @register_to_config
254
+ def __init__(
255
+ self,
256
+ sample_size: Optional[int] = None,
257
+ in_channels: int = 4,
258
+ out_channels: int = 4,
259
+ center_input_sample: bool = False,
260
+ flip_sin_to_cos: bool = True,
261
+ freq_shift: int = 0,
262
+ down_block_types: Tuple[str] = (
263
+ "CrossAttnDownBlockMV2D",
264
+ "CrossAttnDownBlockMV2D",
265
+ "CrossAttnDownBlockMV2D",
266
+ "DownBlock2D",
267
+ ),
268
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
269
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
270
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
271
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
272
+ layers_per_block: Union[int, Tuple[int]] = 2,
273
+ downsample_padding: int = 1,
274
+ mid_block_scale_factor: float = 1,
275
+ act_fn: str = "silu",
276
+ norm_num_groups: Optional[int] = 32,
277
+ norm_eps: float = 1e-5,
278
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
279
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
280
+ encoder_hid_dim: Optional[int] = None,
281
+ encoder_hid_dim_type: Optional[str] = None,
282
+ attention_head_dim: Union[int, Tuple[int]] = 8,
283
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
284
+ dual_cross_attention: bool = False,
285
+ use_linear_projection: bool = False,
286
+ class_embed_type: Optional[str] = None,
287
+ addition_embed_type: Optional[str] = None,
288
+ addition_time_embed_dim: Optional[int] = None,
289
+ num_class_embeds: Optional[int] = None,
290
+ upcast_attention: bool = False,
291
+ resnet_time_scale_shift: str = "default",
292
+ resnet_skip_time_act: bool = False,
293
+ resnet_out_scale_factor: int = 1.0,
294
+ time_embedding_type: str = "positional",
295
+ time_embedding_dim: Optional[int] = None,
296
+ time_embedding_act_fn: Optional[str] = None,
297
+ timestep_post_act: Optional[str] = None,
298
+ time_cond_proj_dim: Optional[int] = None,
299
+ conv_in_kernel: int = 3,
300
+ conv_out_kernel: int = 3,
301
+ projection_class_embeddings_input_dim: Optional[int] = None,
302
+ projection_camera_embeddings_input_dim: Optional[int] = None,
303
+ class_embeddings_concat: bool = False,
304
+ mid_block_only_cross_attention: Optional[bool] = None,
305
+ cross_attention_norm: Optional[str] = None,
306
+ addition_embed_type_num_heads=64,
307
+ num_views: int = 1,
308
+ cd_attention_last: bool = False,
309
+ cd_attention_mid: bool = False,
310
+ multiview_attention: bool = True,
311
+ sparse_mv_attention: bool = False,
312
+ selfattn_block: str = "custom",
313
+ mvcd_attention: bool = False,
314
+ regress_elevation: bool = False,
315
+ regress_focal_length: bool = False,
316
+ num_regress_blocks: int = 4,
317
+ use_dino: bool = False,
318
+ addition_downsample: bool = False,
319
+ addition_channels: Optional[Tuple[int]] = (1280, 1280, 1280),
320
+ ):
321
+ super().__init__()
322
+
323
+ self.sample_size = sample_size
324
+ self.num_views = num_views
325
+ self.mvcd_attention = mvcd_attention
326
+ if num_attention_heads is not None:
327
+ raise ValueError(
328
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
329
+ )
330
+
331
+ # If `num_attention_heads` is not defined (which is the case for most models)
332
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
333
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
334
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
335
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
336
+ # which is why we correct for the naming here.
337
+ num_attention_heads = num_attention_heads or attention_head_dim
338
+
339
+ # Check inputs
340
+ if len(down_block_types) != len(up_block_types):
341
+ raise ValueError(
342
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
343
+ )
344
+
345
+ if len(block_out_channels) != len(down_block_types):
346
+ raise ValueError(
347
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
348
+ )
349
+
350
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
351
+ raise ValueError(
352
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
353
+ )
354
+
355
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
356
+ raise ValueError(
357
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
358
+ )
359
+
360
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
361
+ raise ValueError(
362
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
363
+ )
364
+
365
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
366
+ raise ValueError(
367
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
368
+ )
369
+
370
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
371
+ raise ValueError(
372
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
373
+ )
374
+
375
+ # input
376
+ conv_in_padding = (conv_in_kernel - 1) // 2
377
+ self.conv_in = nn.Conv2d(
378
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
379
+ )
380
+
381
+ # time
382
+ if time_embedding_type == "fourier":
383
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
384
+ if time_embed_dim % 2 != 0:
385
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
386
+ self.time_proj = GaussianFourierProjection(
387
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
388
+ )
389
+ timestep_input_dim = time_embed_dim
390
+ elif time_embedding_type == "positional":
391
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
392
+
393
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
394
+ timestep_input_dim = block_out_channels[0]
395
+ else:
396
+ raise ValueError(
397
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
398
+ )
399
+
400
+ self.time_embedding = TimestepEmbedding(
401
+ timestep_input_dim,
402
+ time_embed_dim,
403
+ act_fn=act_fn,
404
+ post_act_fn=timestep_post_act,
405
+ cond_proj_dim=time_cond_proj_dim,
406
+ )
407
+
408
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
409
+ encoder_hid_dim_type = "text_proj"
410
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
411
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
412
+
413
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
414
+ raise ValueError(
415
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
416
+ )
417
+
418
+ if encoder_hid_dim_type == "text_proj":
419
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
420
+ elif encoder_hid_dim_type == "text_image_proj":
421
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
422
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
423
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
424
+ self.encoder_hid_proj = TextImageProjection(
425
+ text_embed_dim=encoder_hid_dim,
426
+ image_embed_dim=cross_attention_dim,
427
+ cross_attention_dim=cross_attention_dim,
428
+ )
429
+ elif encoder_hid_dim_type == "image_proj":
430
+ # Kandinsky 2.2
431
+ self.encoder_hid_proj = ImageProjection(
432
+ image_embed_dim=encoder_hid_dim,
433
+ cross_attention_dim=cross_attention_dim,
434
+ )
435
+ elif encoder_hid_dim_type is not None:
436
+ raise ValueError(
437
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
438
+ )
439
+ else:
440
+ self.encoder_hid_proj = None
441
+
442
+ # class embedding
443
+ if class_embed_type is None and num_class_embeds is not None:
444
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
445
+ elif class_embed_type == "timestep":
446
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
447
+ elif class_embed_type == "identity":
448
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
449
+ elif class_embed_type == "projection":
450
+ if projection_class_embeddings_input_dim is None:
451
+ raise ValueError(
452
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
453
+ )
454
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
455
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
456
+ # 2. it projects from an arbitrary input dimension.
457
+ #
458
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
459
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
460
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
461
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
462
+ elif class_embed_type == "simple_projection":
463
+ if projection_class_embeddings_input_dim is None:
464
+ raise ValueError(
465
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
466
+ )
467
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
468
+ else:
469
+ self.class_embedding = None
470
+
471
+ if addition_embed_type == "text":
472
+ if encoder_hid_dim is not None:
473
+ text_time_embedding_from_dim = encoder_hid_dim
474
+ else:
475
+ text_time_embedding_from_dim = cross_attention_dim
476
+
477
+ self.add_embedding = TextTimeEmbedding(
478
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
479
+ )
480
+ elif addition_embed_type == "text_image":
481
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
482
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
483
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
484
+ self.add_embedding = TextImageTimeEmbedding(
485
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
486
+ )
487
+ elif addition_embed_type == "text_time":
488
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
489
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
490
+ elif addition_embed_type == "image":
491
+ # Kandinsky 2.2
492
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
493
+ elif addition_embed_type == "image_hint":
494
+ # Kandinsky 2.2 ControlNet
495
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
496
+ elif addition_embed_type is not None:
497
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
498
+
499
+ if time_embedding_act_fn is None:
500
+ self.time_embed_act = None
501
+ else:
502
+ self.time_embed_act = get_activation(time_embedding_act_fn)
503
+
504
+ self.down_blocks = nn.ModuleList([])
505
+ self.up_blocks = nn.ModuleList([])
506
+
507
+ if isinstance(only_cross_attention, bool):
508
+ if mid_block_only_cross_attention is None:
509
+ mid_block_only_cross_attention = only_cross_attention
510
+
511
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
512
+
513
+ if mid_block_only_cross_attention is None:
514
+ mid_block_only_cross_attention = False
515
+
516
+ if isinstance(num_attention_heads, int):
517
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
518
+
519
+ if isinstance(attention_head_dim, int):
520
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
521
+
522
+ if isinstance(cross_attention_dim, int):
523
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
524
+
525
+ if isinstance(layers_per_block, int):
526
+ layers_per_block = [layers_per_block] * len(down_block_types)
527
+
528
+ if isinstance(transformer_layers_per_block, int):
529
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
530
+
531
+ if class_embeddings_concat:
532
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
533
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
534
+ # regular time embeddings
535
+ blocks_time_embed_dim = time_embed_dim * 2
536
+ else:
537
+ blocks_time_embed_dim = time_embed_dim
538
+
539
+ # down
540
+ output_channel = block_out_channels[0]
541
+ for i, down_block_type in enumerate(down_block_types):
542
+ input_channel = output_channel
543
+ output_channel = block_out_channels[i]
544
+ is_final_block = i == len(block_out_channels) - 1
545
+
546
+ down_block = get_down_block(
547
+ down_block_type,
548
+ num_layers=layers_per_block[i],
549
+ transformer_layers_per_block=transformer_layers_per_block[i],
550
+ in_channels=input_channel,
551
+ out_channels=output_channel,
552
+ temb_channels=blocks_time_embed_dim,
553
+ add_downsample=not is_final_block,
554
+ resnet_eps=norm_eps,
555
+ resnet_act_fn=act_fn,
556
+ resnet_groups=norm_num_groups,
557
+ cross_attention_dim=cross_attention_dim[i],
558
+ num_attention_heads=num_attention_heads[i],
559
+ downsample_padding=downsample_padding,
560
+ dual_cross_attention=dual_cross_attention,
561
+ use_linear_projection=use_linear_projection,
562
+ only_cross_attention=only_cross_attention[i],
563
+ upcast_attention=upcast_attention,
564
+ resnet_time_scale_shift=resnet_time_scale_shift,
565
+ resnet_skip_time_act=resnet_skip_time_act,
566
+ resnet_out_scale_factor=resnet_out_scale_factor,
567
+ cross_attention_norm=cross_attention_norm,
568
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
569
+ num_views=num_views,
570
+ cd_attention_last=cd_attention_last,
571
+ cd_attention_mid=cd_attention_mid,
572
+ multiview_attention=multiview_attention,
573
+ sparse_mv_attention=sparse_mv_attention,
574
+ selfattn_block=selfattn_block,
575
+ mvcd_attention=mvcd_attention,
576
+ use_dino=use_dino
577
+ )
578
+ self.down_blocks.append(down_block)
579
+
580
+ # mid
581
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
582
+ self.mid_block = UNetMidBlock2DCrossAttn(
583
+ transformer_layers_per_block=transformer_layers_per_block[-1],
584
+ in_channels=block_out_channels[-1],
585
+ temb_channels=blocks_time_embed_dim,
586
+ resnet_eps=norm_eps,
587
+ resnet_act_fn=act_fn,
588
+ output_scale_factor=mid_block_scale_factor,
589
+ resnet_time_scale_shift=resnet_time_scale_shift,
590
+ cross_attention_dim=cross_attention_dim[-1],
591
+ num_attention_heads=num_attention_heads[-1],
592
+ resnet_groups=norm_num_groups,
593
+ dual_cross_attention=dual_cross_attention,
594
+ use_linear_projection=use_linear_projection,
595
+ upcast_attention=upcast_attention,
596
+ )
597
+ # custom MV2D attention block
598
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
599
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
600
+ transformer_layers_per_block=transformer_layers_per_block[-1],
601
+ in_channels=block_out_channels[-1],
602
+ temb_channels=blocks_time_embed_dim,
603
+ resnet_eps=norm_eps,
604
+ resnet_act_fn=act_fn,
605
+ output_scale_factor=mid_block_scale_factor,
606
+ resnet_time_scale_shift=resnet_time_scale_shift,
607
+ cross_attention_dim=cross_attention_dim[-1],
608
+ num_attention_heads=num_attention_heads[-1],
609
+ resnet_groups=norm_num_groups,
610
+ dual_cross_attention=dual_cross_attention,
611
+ use_linear_projection=use_linear_projection,
612
+ upcast_attention=upcast_attention,
613
+ num_views=num_views,
614
+ cd_attention_last=cd_attention_last,
615
+ cd_attention_mid=cd_attention_mid,
616
+ multiview_attention=multiview_attention,
617
+ sparse_mv_attention=sparse_mv_attention,
618
+ selfattn_block=selfattn_block,
619
+ mvcd_attention=mvcd_attention,
620
+ use_dino=use_dino
621
+ )
622
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
623
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
624
+ in_channels=block_out_channels[-1],
625
+ temb_channels=blocks_time_embed_dim,
626
+ resnet_eps=norm_eps,
627
+ resnet_act_fn=act_fn,
628
+ output_scale_factor=mid_block_scale_factor,
629
+ cross_attention_dim=cross_attention_dim[-1],
630
+ attention_head_dim=attention_head_dim[-1],
631
+ resnet_groups=norm_num_groups,
632
+ resnet_time_scale_shift=resnet_time_scale_shift,
633
+ skip_time_act=resnet_skip_time_act,
634
+ only_cross_attention=mid_block_only_cross_attention,
635
+ cross_attention_norm=cross_attention_norm,
636
+ )
637
+ elif mid_block_type is None:
638
+ self.mid_block = None
639
+ else:
640
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
641
+
642
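+ # Optional auxiliary branch on the bottleneck features: max-pool downsample, a stack of
+ # BasicConvBlocks, a zero-initialized 1x1 projection, and a bilinear upsample back to the
+ # original resolution.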
+ self.addition_downsample = addition_downsample
643
+ if self.addition_downsample:
644
+ inc = block_out_channels[-1]
645
+ self.downsample = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
646
+ self.conv_block = nn.ModuleList()
647
+ self.conv_block.append(BasicConvBlock(inc, addition_channels[0], stride=1))
648
+ for dim_ in addition_channels[1:-1]:
649
+ self.conv_block.append(BasicConvBlock(dim_, dim_, stride=1))
650
+ self.conv_block.append(BasicConvBlock(dim_, inc))
651
+ self.addition_conv_out = nn.Conv2d(inc, inc, kernel_size=1, bias=False)
652
+ nn.init.zeros_(self.addition_conv_out.weight.data)
653
+ self.addition_act_out = nn.SiLU()
654
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
655
+
656
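+ # Optional camera heads: ResidualLiner MLPs predict elevation and/or focal length from the
+ # pooled bottleneck features; `camera_embedding` maps camera parameters into the
+ # time-embedding dimension.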
+ self.regress_elevation = regress_elevation
657
+ self.regress_focal_length = regress_focal_length
658
+ if regress_elevation or regress_focal_length:
659
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
660
+ self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim)
661
+
662
+ regress_in_dim = block_out_channels[-1]*2 if mvcd_attention else block_out_channels[-1]
663
+
664
+ if regress_elevation:
665
+ self.elevation_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks)
666
+ if regress_focal_length:
667
+ self.focal_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks)
668
+ '''
669
+ self.regress_elevation = regress_elevation
670
+ self.regress_focal_length = regress_focal_length
671
+ if regress_elevation and (not regress_focal_length):
672
+ print("Regressing elevation")
673
+ cam_dim = 1
674
+ elif regress_focal_length and (not regress_elevation):
675
+ print("Regressing focal length")
676
+ cam_dim = 6
677
+ elif regress_elevation and regress_focal_length:
678
+ print("Regressing both elevation and focal length")
679
+ cam_dim = 7
680
+ else:
681
+ cam_dim = 0
682
+ assert projection_camera_embeddings_input_dim == 2*cam_dim, "projection_camera_embeddings_input_dim should be 2*cam_dim"
683
+ if regress_elevation or regress_focal_length:
684
+ self.elevation_regressor = nn.ModuleList([
685
+ nn.Linear(block_out_channels[-1], 1280),
686
+ nn.SiLU(),
687
+ nn.Linear(1280, 1280),
688
+ nn.SiLU(),
689
+ nn.Linear(1280, cam_dim)
690
+ ])
691
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
692
+ self.focal_act = nn.Softmax(dim=-1)
693
+ self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim)
694
+ '''
695
+
696
+ # count how many layers upsample the images
697
+ self.num_upsamplers = 0
698
+
699
+ # up
700
+ reversed_block_out_channels = list(reversed(block_out_channels))
701
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
702
+ reversed_layers_per_block = list(reversed(layers_per_block))
703
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
704
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
705
+ only_cross_attention = list(reversed(only_cross_attention))
706
+
707
+ output_channel = reversed_block_out_channels[0]
708
+ for i, up_block_type in enumerate(up_block_types):
709
+ is_final_block = i == len(block_out_channels) - 1
710
+
711
+ prev_output_channel = output_channel
712
+ output_channel = reversed_block_out_channels[i]
713
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
714
+
715
+ # add upsample block for all BUT final layer
716
+ if not is_final_block:
717
+ add_upsample = True
718
+ self.num_upsamplers += 1
719
+ else:
720
+ add_upsample = False
721
+
722
+ up_block = get_up_block(
723
+ up_block_type,
724
+ num_layers=reversed_layers_per_block[i] + 1,
725
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
726
+ in_channels=input_channel,
727
+ out_channels=output_channel,
728
+ prev_output_channel=prev_output_channel,
729
+ temb_channels=blocks_time_embed_dim,
730
+ add_upsample=add_upsample,
731
+ resnet_eps=norm_eps,
732
+ resnet_act_fn=act_fn,
733
+ resnet_groups=norm_num_groups,
734
+ cross_attention_dim=reversed_cross_attention_dim[i],
735
+ num_attention_heads=reversed_num_attention_heads[i],
736
+ dual_cross_attention=dual_cross_attention,
737
+ use_linear_projection=use_linear_projection,
738
+ only_cross_attention=only_cross_attention[i],
739
+ upcast_attention=upcast_attention,
740
+ resnet_time_scale_shift=resnet_time_scale_shift,
741
+ resnet_skip_time_act=resnet_skip_time_act,
742
+ resnet_out_scale_factor=resnet_out_scale_factor,
743
+ cross_attention_norm=cross_attention_norm,
744
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
745
+ num_views=num_views,
746
+ cd_attention_last=cd_attention_last,
747
+ cd_attention_mid=cd_attention_mid,
748
+ multiview_attention=multiview_attention,
749
+ sparse_mv_attention=sparse_mv_attention,
750
+ selfattn_block=selfattn_block,
751
+ mvcd_attention=mvcd_attention,
752
+ use_dino=use_dino
753
+ )
754
+ self.up_blocks.append(up_block)
755
+ prev_output_channel = output_channel
756
+
757
+ # out
758
+ if norm_num_groups is not None:
759
+ self.conv_norm_out = nn.GroupNorm(
760
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
761
+ )
762
+
763
+ self.conv_act = get_activation(act_fn)
764
+
765
+ else:
766
+ self.conv_norm_out = None
767
+ self.conv_act = None
768
+
769
+ conv_out_padding = (conv_out_kernel - 1) // 2
770
+ self.conv_out = nn.Conv2d(
771
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
772
+ )
773
+
774
+ @property
775
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
776
+ r"""
777
+ Returns:
778
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
779
+ indexed by their weight names.
780
+ """
781
+ # set recursively
782
+ processors = {}
783
+
784
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
785
+ if hasattr(module, "set_processor"):
786
+ processors[f"{name}.processor"] = module.processor
787
+
788
+ for sub_name, child in module.named_children():
789
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
790
+
791
+ return processors
792
+
793
+ for name, module in self.named_children():
794
+ fn_recursive_add_processors(name, module, processors)
795
+
796
+ return processors
797
+
798
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
799
+ r"""
800
+ Sets the attention processor to use to compute attention.
801
+
802
+ Parameters:
803
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
804
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
805
+ for **all** `Attention` layers.
806
+
807
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
808
+ processor. This is strongly recommended when setting trainable attention processors.
809
+
810
+ """
811
+ count = len(self.attn_processors.keys())
812
+
813
+ if isinstance(processor, dict) and len(processor) != count:
814
+ raise ValueError(
815
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
816
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
817
+ )
818
+
819
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
820
+ if hasattr(module, "set_processor"):
821
+ if not isinstance(processor, dict):
822
+ module.set_processor(processor)
823
+ else:
824
+ module.set_processor(processor.pop(f"{name}.processor"))
825
+
826
+ for sub_name, child in module.named_children():
827
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
828
+
829
+ for name, module in self.named_children():
830
+ fn_recursive_attn_processor(name, module, processor)
831
+
832
+ def set_default_attn_processor(self):
833
+ """
834
+ Disables custom attention processors and sets the default attention implementation.
835
+ """
836
+ self.set_attn_processor(AttnProcessor())
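As a quick illustration of the processor API defined above, the sketch below enumerates and resets the attention processors on a stock diffusers UNet. It is a minimal usage sketch, not code from this repository; the model id is only an example (the same one that appears in the docstrings further down).

```python
# Minimal usage sketch (not part of this repo): list and reset the attention
# processors of a diffusers UNet via the same interface defined above.
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# Keys look like "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor".
print(len(unet.attn_processors), "attention processors found")

# Replace every processor with the default implementation, which is exactly
# what set_default_attn_processor() does.
unet.set_attn_processor(AttnProcessor())
```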
837
+
838
+ def set_attention_slice(self, slice_size):
839
+ r"""
840
+ Enable sliced attention computation.
841
+
842
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
843
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
844
+
845
+ Args:
846
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
847
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
848
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
849
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
850
+ must be a multiple of `slice_size`.
851
+ """
852
+ sliceable_head_dims = []
853
+
854
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
855
+ if hasattr(module, "set_attention_slice"):
856
+ sliceable_head_dims.append(module.sliceable_head_dim)
857
+
858
+ for child in module.children():
859
+ fn_recursive_retrieve_sliceable_dims(child)
860
+
861
+ # retrieve number of attention layers
862
+ for module in self.children():
863
+ fn_recursive_retrieve_sliceable_dims(module)
864
+
865
+ num_sliceable_layers = len(sliceable_head_dims)
866
+
867
+ if slice_size == "auto":
868
+ # half the attention head size is usually a good trade-off between
869
+ # speed and memory
870
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
871
+ elif slice_size == "max":
872
+ # make smallest slice possible
873
+ slice_size = num_sliceable_layers * [1]
874
+
875
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
876
+
877
+ if len(slice_size) != len(sliceable_head_dims):
878
+ raise ValueError(
879
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
880
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
881
+ )
882
+
883
+ for i in range(len(slice_size)):
884
+ size = slice_size[i]
885
+ dim = sliceable_head_dims[i]
886
+ if size is not None and size > dim:
887
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
888
+
889
+ # Recursively walk through all the children.
890
+ # Any children which exposes the set_attention_slice method
891
+ # gets the message
892
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
893
+ if hasattr(module, "set_attention_slice"):
894
+ module.set_attention_slice(slice_size.pop())
895
+
896
+ for child in module.children():
897
+ fn_recursive_set_attention_slice(child, slice_size)
898
+
899
+ reversed_slice_size = list(reversed(slice_size))
900
+ for module in self.children():
901
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
902
+
903
+ def _set_gradient_checkpointing(self, module, value=False):
904
+ if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
905
+ module.gradient_checkpointing = value
906
+
907
+ def forward(
908
+ self,
909
+ sample: torch.FloatTensor,
910
+ timestep: Union[torch.Tensor, float, int],
911
+ encoder_hidden_states: torch.Tensor,
912
+ class_labels: Optional[torch.Tensor] = None,
913
+ timestep_cond: Optional[torch.Tensor] = None,
914
+ attention_mask: Optional[torch.Tensor] = None,
915
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
916
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
917
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
918
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
919
+ encoder_attention_mask: Optional[torch.Tensor] = None,
920
+ dino_feature: Optional[torch.Tensor] = None,
921
+ return_dict: bool = True,
922
+ vis_max_min: bool = False,
923
+ ) -> Union[UNetMV2DConditionOutput, Tuple]:
924
+ r"""
925
+ The [`UNet2DConditionModel`] forward method.
926
+
927
+ Args:
928
+ sample (`torch.FloatTensor`):
929
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
930
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
931
+ encoder_hidden_states (`torch.FloatTensor`):
932
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
933
+ encoder_attention_mask (`torch.Tensor`):
934
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
935
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
936
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
937
+ return_dict (`bool`, *optional*, defaults to `True`):
938
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
939
+ tuple.
940
+ cross_attention_kwargs (`dict`, *optional*):
941
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
942
+ added_cond_kwargs: (`dict`, *optional*):
943
+ A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that
944
+ are passed along to the UNet blocks.
945
+
946
+ Returns:
947
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
948
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
949
+ a `tuple` is returned where the first element is the sample tensor.
950
+ """
951
+ record_max_min = {}
952
+ # By default samples have to be at least a multiple of the overall upsampling factor.
953
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
954
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
955
+ # on the fly if necessary.
956
+ default_overall_up_factor = 2**self.num_upsamplers
957
+
958
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
959
+ forward_upsample_size = False
960
+ upsample_size = None
961
+
962
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
963
+ logger.info("Forward upsample size to force interpolation output size.")
964
+ forward_upsample_size = True
965
+
966
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
967
+ # expects mask of shape:
968
+ # [batch, key_tokens]
969
+ # adds singleton query_tokens dimension:
970
+ # [batch, 1, key_tokens]
971
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
972
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
973
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
974
+ if attention_mask is not None:
975
+ # assume that mask is expressed as:
976
+ # (1 = keep, 0 = discard)
977
+ # convert mask into a bias that can be added to attention scores:
978
+ # (keep = +0, discard = -10000.0)
979
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
980
+ attention_mask = attention_mask.unsqueeze(1)
981
+
982
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
983
+ if encoder_attention_mask is not None:
984
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
985
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
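The mask handling above is easy to check in isolation. The following self-contained sketch (with made-up shapes) reproduces the 1/0 to 0/-10000 bias conversion and shows that masked keys end up with near-zero attention weight after the softmax.

```python
# Self-contained sketch of the keep/discard mask -> additive bias conversion
# performed above (1 = keep -> 0.0, 0 = discard -> -10000.0).
import torch

attention_mask = torch.tensor([[1, 1, 0, 0]])        # (batch, key_tokens)
bias = (1 - attention_mask.float()) * -10000.0       # (batch, key_tokens)
bias = bias.unsqueeze(1)                             # (batch, 1, key_tokens)

scores = torch.zeros(1, 2, 3, 4)                     # (batch, heads, query, key)
probs = (scores + bias).softmax(dim=-1)              # bias broadcasts over heads/queries
print(probs[0, 0, 0])                                # ~[0.5, 0.5, 0.0, 0.0]
```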
986
+
987
+ # 0. center input if necessary
988
+ if self.config.center_input_sample:
989
+ sample = 2 * sample - 1.0
990
+ # 1. time
991
+ timesteps = timestep
992
+ if not torch.is_tensor(timesteps):
993
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
994
+ # This would be a good case for the `match` statement (Python 3.10+)
995
+ is_mps = sample.device.type == "mps"
996
+ if isinstance(timestep, float):
997
+ dtype = torch.float32 if is_mps else torch.float64
998
+ else:
999
+ dtype = torch.int32 if is_mps else torch.int64
1000
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1001
+ elif len(timesteps.shape) == 0:
1002
+ timesteps = timesteps[None].to(sample.device)
1003
+
1004
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1005
+ timesteps = timesteps.expand(sample.shape[0])
1006
+
1007
+ t_emb = self.time_proj(timesteps)
1008
+
1009
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1010
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1011
+ # there might be better ways to encapsulate this.
1012
+ t_emb = t_emb.to(dtype=sample.dtype)
1013
+
1014
+ emb = self.time_embedding(t_emb, timestep_cond)
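For readers unfamiliar with the two-stage timestep conditioning used here, the sketch below reproduces it with diffusers' `Timesteps` and `TimestepEmbedding` modules; the channel sizes are illustrative, not this model's actual config.

```python
# Minimal sketch (illustrative sizes) of the timestep conditioning path above:
# sinusoidal projection (Timesteps) followed by a small MLP (TimestepEmbedding).
import torch
from diffusers.models.embeddings import Timesteps, TimestepEmbedding

time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280)

timesteps = torch.tensor([981])            # a single denoising step index ...
timesteps = timesteps.expand(4)            # ... broadcast to the batch size
t_emb = time_proj(timesteps)               # (4, 320), always returned as float32
emb = time_embedding(t_emb)                # (4, 1280), later combined with block features
print(emb.shape)
```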
1015
+ aug_emb = None
1016
+ if self.class_embedding is not None:
1017
+ if class_labels is None:
1018
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
1019
+
1020
+ if self.config.class_embed_type == "timestep":
1021
+ class_labels = self.time_proj(class_labels)
1022
+
1023
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1024
+ # there might be better ways to encapsulate this.
1025
+ class_labels = class_labels.to(dtype=sample.dtype)
1026
+
1027
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1028
+ if self.config.class_embeddings_concat:
1029
+ emb = torch.cat([emb, class_emb], dim=-1)
1030
+ else:
1031
+ emb = emb + class_emb
1032
+
1033
+ if self.config.addition_embed_type == "text":
1034
+ aug_emb = self.add_embedding(encoder_hidden_states)
1035
+ elif self.config.addition_embed_type == "text_image":
1036
+ # Kandinsky 2.1 - style
1037
+ if "image_embeds" not in added_cond_kwargs:
1038
+ raise ValueError(
1039
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1040
+ )
1041
+
1042
+ image_embs = added_cond_kwargs.get("image_embeds")
1043
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1044
+ aug_emb = self.add_embedding(text_embs, image_embs)
1045
+ elif self.config.addition_embed_type == "text_time":
1046
+ # SDXL - style
1047
+ if "text_embeds" not in added_cond_kwargs:
1048
+ raise ValueError(
1049
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1050
+ )
1051
+ text_embeds = added_cond_kwargs.get("text_embeds")
1052
+ if "time_ids" not in added_cond_kwargs:
1053
+ raise ValueError(
1054
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1055
+ )
1056
+ time_ids = added_cond_kwargs.get("time_ids")
1057
+ time_embeds = self.add_time_proj(time_ids.flatten())
1058
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1059
+
1060
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1061
+ add_embeds = add_embeds.to(emb.dtype)
1062
+ aug_emb = self.add_embedding(add_embeds)
1063
+ elif self.config.addition_embed_type == "image":
1064
+ # Kandinsky 2.2 - style
1065
+ if "image_embeds" not in added_cond_kwargs:
1066
+ raise ValueError(
1067
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1068
+ )
1069
+ image_embs = added_cond_kwargs.get("image_embeds")
1070
+ aug_emb = self.add_embedding(image_embs)
1071
+ elif self.config.addition_embed_type == "image_hint":
1072
+ # Kandinsky 2.2 - style
1073
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1074
+ raise ValueError(
1075
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1076
+ )
1077
+ image_embs = added_cond_kwargs.get("image_embeds")
1078
+ hint = added_cond_kwargs.get("hint")
1079
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1080
+ sample = torch.cat([sample, hint], dim=1)
1081
+
1082
+ emb = emb + aug_emb if aug_emb is not None else emb
1083
+ emb_pre_act = emb
1084
+ if self.time_embed_act is not None:
1085
+ emb = self.time_embed_act(emb)
1086
+
1087
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1088
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1089
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1090
+ # Kandinsky 2.1 - style
1091
+ if "image_embeds" not in added_cond_kwargs:
1092
+ raise ValueError(
1093
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1094
+ )
1095
+
1096
+ image_embeds = added_cond_kwargs.get("image_embeds")
1097
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1098
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1099
+ # Kandinsky 2.2 - style
1100
+ if "image_embeds" not in added_cond_kwargs:
1101
+ raise ValueError(
1102
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1103
+ )
1104
+ image_embeds = added_cond_kwargs.get("image_embeds")
1105
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1106
+ # 2. pre-process
1107
+ sample = self.conv_in(sample)
1108
+ # 3. down
1109
+
1110
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1111
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
1112
+
1113
+ down_block_res_samples = (sample,)
1114
+ for i, downsample_block in enumerate(self.down_blocks):
1115
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1116
+ # For t2i-adapter CrossAttnDownBlock2D
1117
+ additional_residuals = {}
1118
+ if is_adapter and len(down_block_additional_residuals) > 0:
1119
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
1120
+
1121
+ sample, res_samples = downsample_block(
1122
+ hidden_states=sample,
1123
+ temb=emb,
1124
+ encoder_hidden_states=encoder_hidden_states,
1125
+ dino_feature=dino_feature,
1126
+ attention_mask=attention_mask,
1127
+ cross_attention_kwargs=cross_attention_kwargs,
1128
+ encoder_attention_mask=encoder_attention_mask,
1129
+ **additional_residuals,
1130
+ )
1131
+ else:
1132
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1133
+
1134
+ if is_adapter and len(down_block_additional_residuals) > 0:
1135
+ sample += down_block_additional_residuals.pop(0)
1136
+
1137
+ down_block_res_samples += res_samples
1138
+
1139
+ if is_controlnet:
1140
+ new_down_block_res_samples = ()
1141
+
1142
+ for down_block_res_sample, down_block_additional_residual in zip(
1143
+ down_block_res_samples, down_block_additional_residuals
1144
+ ):
1145
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1146
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1147
+
1148
+ down_block_res_samples = new_down_block_res_samples
1149
+
1150
+ if self.addition_downsample:
1151
+ global_sample = sample
1152
+ global_sample = self.downsample(global_sample)
1153
+ for layer in self.conv_block:
1154
+ global_sample = layer(global_sample)
1155
+ global_sample = self.addition_act_out(self.addition_conv_out(global_sample))
1156
+ global_sample = self.upsample(global_sample)
1157
+ # 4. mid
1158
+ if self.mid_block is not None:
1159
+ sample = self.mid_block(
1160
+ sample,
1161
+ emb,
1162
+ encoder_hidden_states=encoder_hidden_states,
1163
+ dino_feature=dino_feature,
1164
+ attention_mask=attention_mask,
1165
+ cross_attention_kwargs=cross_attention_kwargs,
1166
+ encoder_attention_mask=encoder_attention_mask,
1167
+ )
1168
+ # 4.1 regress elevation and focal length
1169
+ # predict elevation -> embed -> projection -> add to time emb
1170
+ if self.regress_elevation or self.regress_focal_length:
1171
+ pool_embeds = self.pool(sample.detach()).squeeze(-1).squeeze(-1) # (2B, C)
1172
+ if self.mvcd_attention:
1173
+ pool_embeds_normal, pool_embeds_color = torch.chunk(pool_embeds, 2, dim=0)
1174
+ pool_embeds = torch.cat([pool_embeds_normal, pool_embeds_color], dim=-1) # (B, 2C)
1175
+ pose_pred = []
1176
+ if self.regress_elevation:
1177
+ ele_pred = self.elevation_regressor(pool_embeds)
1178
+ ele_pred = rearrange(ele_pred, '(b v) c -> b v c', v=self.num_views)
1179
+ ele_pred = torch.mean(ele_pred, dim=1)
1180
+ pose_pred.append(ele_pred) # b, c
1181
+
1182
+ if self.regress_focal_length:
1183
+ focal_pred = self.focal_regressor(pool_embeds)
1184
+ focal_pred = rearrange(focal_pred, '(b v) c -> b v c', v=self.num_views)
1185
+ focal_pred = torch.mean(focal_pred, dim=1)
1186
+ pose_pred.append(focal_pred)
1187
+ pose_pred = torch.cat(pose_pred, dim=-1)
1188
+ # 'e_de_da_sincos', (B, 2)
1189
+ pose_embeds = torch.cat([
1190
+ torch.sin(pose_pred),
1191
+ torch.cos(pose_pred)
1192
+ ], dim=-1)
1193
+ pose_embeds = self.camera_embedding(pose_embeds)
1194
+ pose_embeds = torch.repeat_interleave(pose_embeds, self.num_views, 0)
1195
+ if self.mvcd_attention:
1196
+ pose_embeds = torch.cat([pose_embeds,] * 2, dim=0)
1197
+
1198
+ emb = pose_embeds + emb_pre_act
1199
+ if self.time_embed_act is not None:
1200
+ emb = self.time_embed_act(emb)
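The camera-conditioning step above is compact but a little dense. The sketch below (hypothetical values and sizes) isolates it: the regressed elevation/focal predictions are encoded as sin/cos pairs, projected by a `TimestepEmbedding`, repeated per view, and added back onto the pre-activation time embedding.

```python
# Hedged sketch of the sin/cos camera-pose conditioning used above; the values,
# batch size, and embedding width are illustrative only.
import torch
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding

num_views, time_embed_dim = 6, 1280
pose_pred = torch.tensor([[0.3, 1.1]])                      # (B, 2): elevation, focal proxy
pose_embeds = torch.cat([torch.sin(pose_pred),
                         torch.cos(pose_pred)], dim=-1)     # (B, 4) "sincos" encoding

camera_embedding = TimestepEmbedding(in_channels=4, time_embed_dim=time_embed_dim)
pose_embeds = camera_embedding(pose_embeds)                 # (B, 1280)
pose_embeds = torch.repeat_interleave(pose_embeds, num_views, dim=0)  # one copy per view

emb_pre_act = torch.zeros(num_views, time_embed_dim)        # stand-in for the UNet time emb
emb = F.silu(pose_embeds + emb_pre_act)                     # activation only if time_embed_act
print(emb.shape)                                            # is configured; SiLU is an example
```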
1201
+
1202
+ if is_controlnet:
1203
+ sample = sample + mid_block_additional_residual
1204
+
1205
+ if self.addition_downsample:
1206
+ sample = sample + global_sample
1207
+
1208
+ # 5. up
1209
+ for i, upsample_block in enumerate(self.up_blocks):
1210
+ is_final_block = i == len(self.up_blocks) - 1
1211
+
1212
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1213
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1214
+
1215
+ # if we have not reached the final block and need to forward the
1216
+ # upsample size, we do it here
1217
+ if not is_final_block and forward_upsample_size:
1218
+ upsample_size = down_block_res_samples[-1].shape[2:]
1219
+
1220
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1221
+ sample = upsample_block(
1222
+ hidden_states=sample,
1223
+ temb=emb,
1224
+ res_hidden_states_tuple=res_samples,
1225
+ encoder_hidden_states=encoder_hidden_states,
1226
+ dino_feature=dino_feature,
1227
+ cross_attention_kwargs=cross_attention_kwargs,
1228
+ upsample_size=upsample_size,
1229
+ attention_mask=attention_mask,
1230
+ encoder_attention_mask=encoder_attention_mask,
1231
+ )
1232
+ else:
1233
+ sample = upsample_block(
1234
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1235
+ )
1236
+ if torch.isnan(sample).any() or torch.isinf(sample).any():
1237
+ print("NAN in sample, stop training.")
1238
+ exit()
1239
+ # 6. post-process
1240
+ if self.conv_norm_out:
1241
+ sample = self.conv_norm_out(sample)
1242
+ sample = self.conv_act(sample)
1243
+ sample = self.conv_out(sample)
1244
+ if not return_dict:
1245
+ return (sample, pose_pred)
1246
+ if self.regress_elevation or self.regress_focal_length:
1247
+ return UNetMV2DConditionOutput(sample=sample), pose_pred
1248
+ else:
1249
+ return UNetMV2DConditionOutput(sample=sample)
1250
+
1251
+
1252
+ @classmethod
1253
+ def from_pretrained_2d(
1254
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1255
+ camera_embedding_type: str, num_views: int, sample_size: int,
1256
+ zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
1257
+ projection_camera_embeddings_input_dim: int=2,
1258
+ cd_attention_last: bool = False, num_regress_blocks: int = 4,
1259
+ cd_attention_mid: bool = False, multiview_attention: bool = True,
1260
+ sparse_mv_attention: bool = False, selfattn_block: str = 'custom', mvcd_attention: bool = False,
1261
+ in_channels: int = 8, out_channels: int = 4, unclip: bool = False, regress_elevation: bool = False, regress_focal_length: bool = False,
1262
+ init_mvattn_with_selfattn: bool= False, use_dino: bool = False, addition_downsample: bool = False,
1263
+ **kwargs
1264
+ ):
1265
+ r"""
1266
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
1267
+
1268
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
1269
+ train the model, set it back in training mode with `model.train()`.
1270
+
1271
+ Parameters:
1272
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
1273
+ Can be either:
1274
+
1275
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1276
+ the Hub.
1277
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1278
+ with [`~ModelMixin.save_pretrained`].
1279
+
1280
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1281
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1282
+ is not used.
1283
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1284
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1285
+ dtype is automatically derived from the model's weights.
1286
+ force_download (`bool`, *optional*, defaults to `False`):
1287
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1288
+ cached versions if they exist.
1289
+ resume_download (`bool`, *optional*, defaults to `False`):
1290
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1291
+ incompletely downloaded files are deleted.
1292
+ proxies (`Dict[str, str]`, *optional*):
1293
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1294
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1295
+ output_loading_info (`bool`, *optional*, defaults to `False`):
1296
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
1297
+ local_files_only(`bool`, *optional*, defaults to `False`):
1298
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1299
+ won't be downloaded from the Hub.
1300
+ use_auth_token (`str` or *bool*, *optional*):
1301
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1302
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1303
+ revision (`str`, *optional*, defaults to `"main"`):
1304
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1305
+ allowed by Git.
1306
+ from_flax (`bool`, *optional*, defaults to `False`):
1307
+ Load the model weights from a Flax checkpoint save file.
1308
+ subfolder (`str`, *optional*, defaults to `""`):
1309
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1310
+ mirror (`str`, *optional*):
1311
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
1312
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
1313
+ information.
1314
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
1315
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
1316
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
1317
+ same device.
1318
+
1319
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
1320
+ more information about each option see [designing a device
1321
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
1322
+ max_memory (`Dict`, *optional*):
1323
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
1324
+ each GPU and the available CPU RAM if unset.
1325
+ offload_folder (`str` or `os.PathLike`, *optional*):
1326
+ The path to offload weights if `device_map` contains the value `"disk"`.
1327
+ offload_state_dict (`bool`, *optional*):
1328
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
1329
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
1330
+ when there is some disk offload.
1331
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
1332
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
1333
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
1334
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
1335
+ argument to `True` will raise an error.
1336
+ variant (`str`, *optional*):
1337
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
1338
+ loading `from_flax`.
1339
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1340
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
1341
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
1342
+ weights. If set to `False`, `safetensors` weights are not loaded.
1343
+
1344
+ <Tip>
1345
+
1346
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
1347
+ `huggingface-cli login`. You can also activate the special
1348
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
1349
+ firewalled environment.
1350
+
1351
+ </Tip>
1352
+
1353
+ Example:
1354
+
1355
+ ```py
1356
+ from diffusers import UNet2DConditionModel
1357
+
1358
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
1359
+ ```
1360
+
1361
+ If you get the error message below, you need to finetune the weights for your downstream task:
1362
+
1363
+ ```bash
1364
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
1365
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
1366
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1367
+ ```
1368
+ """
1369
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1370
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1371
+ force_download = kwargs.pop("force_download", False)
1372
+ from_flax = kwargs.pop("from_flax", False)
1373
+ resume_download = kwargs.pop("resume_download", False)
1374
+ proxies = kwargs.pop("proxies", None)
1375
+ output_loading_info = kwargs.pop("output_loading_info", False)
1376
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1377
+ use_auth_token = kwargs.pop("use_auth_token", None)
1378
+ revision = kwargs.pop("revision", None)
1379
+ torch_dtype = kwargs.pop("torch_dtype", None)
1380
+ subfolder = kwargs.pop("subfolder", None)
1381
+ device_map = kwargs.pop("device_map", None)
1382
+ max_memory = kwargs.pop("max_memory", None)
1383
+ offload_folder = kwargs.pop("offload_folder", None)
1384
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1385
+ variant = kwargs.pop("variant", None)
1386
+ use_safetensors = kwargs.pop("use_safetensors", None)
1387
+
1388
+ if use_safetensors:
1389
+ raise ValueError(
1390
+ "`use_safetensors`=True but safetensors is not installed. Please install it with `pip install safetensors`."
1391
+ )
1392
+
1393
+ allow_pickle = False
1394
+ if use_safetensors is None:
1395
+ use_safetensors = True
1396
+ allow_pickle = True
1397
+
1398
+ if device_map is not None and not is_accelerate_available():
1399
+ raise NotImplementedError(
1400
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1401
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1402
+ )
1403
+
1404
+ # Check if we can handle device_map and dispatching the weights
1405
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1406
+ raise NotImplementedError(
1407
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1408
+ " `device_map=None`."
1409
+ )
1410
+
1411
+ # Load config if we don't provide a configuration
1412
+ config_path = pretrained_model_name_or_path
1413
+
1414
+ user_agent = {
1415
+ "diffusers": __version__,
1416
+ "file_type": "model",
1417
+ "framework": "pytorch",
1418
+ }
1419
+
1420
+ # load config
1421
+ config, unused_kwargs, commit_hash = cls.load_config(
1422
+ config_path,
1423
+ cache_dir=cache_dir,
1424
+ return_unused_kwargs=True,
1425
+ return_commit_hash=True,
1426
+ force_download=force_download,
1427
+ resume_download=resume_download,
1428
+ proxies=proxies,
1429
+ local_files_only=local_files_only,
1430
+ use_auth_token=use_auth_token,
1431
+ revision=revision,
1432
+ subfolder=subfolder,
1433
+ device_map=device_map,
1434
+ max_memory=max_memory,
1435
+ offload_folder=offload_folder,
1436
+ offload_state_dict=offload_state_dict,
1437
+ user_agent=user_agent,
1438
+ **kwargs,
1439
+ )
1440
+
1441
+ # modify config
1442
+ config["_class_name"] = cls.__name__
1443
+ config['in_channels'] = in_channels
1444
+ config['out_channels'] = out_channels
1445
+ config['sample_size'] = sample_size # training resolution
1446
+ config['num_views'] = num_views
1447
+ config['cd_attention_last'] = cd_attention_last
1448
+ config['cd_attention_mid'] = cd_attention_mid
1449
+ config['multiview_attention'] = multiview_attention
1450
+ config['sparse_mv_attention'] = sparse_mv_attention
1451
+ config['selfattn_block'] = selfattn_block
1452
+ config['mvcd_attention'] = mvcd_attention
1453
+ config["down_block_types"] = [
1454
+ "CrossAttnDownBlockMV2D",
1455
+ "CrossAttnDownBlockMV2D",
1456
+ "CrossAttnDownBlockMV2D",
1457
+ "DownBlock2D"
1458
+ ]
1459
+ config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
1460
+ config["up_block_types"] = [
1461
+ "UpBlock2D",
1462
+ "CrossAttnUpBlockMV2D",
1463
+ "CrossAttnUpBlockMV2D",
1464
+ "CrossAttnUpBlockMV2D"
1465
+ ]
1466
+
1467
+
1468
+ config['regress_elevation'] = regress_elevation # true
1469
+ config['regress_focal_length'] = regress_focal_length # true
1470
+ config['projection_camera_embeddings_input_dim'] = projection_camera_embeddings_input_dim # 2 for elevation and 10 for focal_length
1471
+ config['use_dino'] = use_dino
1472
+ config['num_regress_blocks'] = num_regress_blocks
1473
+ config['addition_downsample'] = addition_downsample
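A hedged usage sketch of this loader follows. The class name is assumed to be `UNetMV2DConditionModel` (it is not visible in this hunk), and the checkpoint path and argument values are illustrative rather than the repository's defaults.

```python
# Hypothetical call to from_pretrained_2d; path, class name, and values are
# assumptions for illustration, not the repo's actual configuration.
from mvdiffusion.models.unet_mv2d_condition import UNetMV2DConditionModel

unet = UNetMV2DConditionModel.from_pretrained_2d(
    "/path/to/stable-diffusion-2-1-unclip",     # local unCLIP checkpoint (hypothetical)
    subfolder="unet",
    camera_embedding_type="e_de_da_sincos",     # illustrative value
    num_views=6,
    sample_size=64,                             # latent resolution for 512px images
    in_channels=8,                              # noise latents + reference-image latents
    out_channels=4,
    regress_elevation=True,
    regress_focal_length=True,
    projection_camera_embeddings_input_dim=4,   # 2 * (elevation + focal) sin/cos terms
    mvcd_attention=True,
)
```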
1474
+ # load model
1475
+ model_file = None
1476
+ if from_flax:
1477
+ raise NotImplementedError
1478
+ else:
1479
+ if use_safetensors:
1480
+ try:
1481
+ model_file = _get_model_file(
1482
+ pretrained_model_name_or_path,
1483
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1484
+ cache_dir=cache_dir,
1485
+ force_download=force_download,
1486
+ resume_download=resume_download,
1487
+ proxies=proxies,
1488
+ local_files_only=local_files_only,
1489
+ use_auth_token=use_auth_token,
1490
+ revision=revision,
1491
+ subfolder=subfolder,
1492
+ user_agent=user_agent,
1493
+ commit_hash=commit_hash,
1494
+ )
1495
+ except IOError as e:
1496
+ if not allow_pickle:
1497
+ raise e
1498
+ pass
1499
+ if model_file is None:
1500
+ model_file = _get_model_file(
1501
+ pretrained_model_name_or_path,
1502
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1503
+ cache_dir=cache_dir,
1504
+ force_download=force_download,
1505
+ resume_download=resume_download,
1506
+ proxies=proxies,
1507
+ local_files_only=local_files_only,
1508
+ use_auth_token=use_auth_token,
1509
+ revision=revision,
1510
+ subfolder=subfolder,
1511
+ user_agent=user_agent,
1512
+ commit_hash=commit_hash,
1513
+ )
1514
+
1515
+ model = cls.from_config(config, **unused_kwargs)
1516
+ import copy
1517
+ state_dict_pretrain = load_state_dict(model_file, variant=variant)
1518
+ state_dict = copy.deepcopy(state_dict_pretrain)
1519
+
1520
+ if init_mvattn_with_selfattn:
1521
+ for key in state_dict_pretrain:
1522
+ if 'attn1' in key:
1523
+ key_mv = key.replace('attn1', 'attn_mv')
1524
+ state_dict[key_mv] = state_dict_pretrain[key]
1525
+ if 'to_out.0.weight' in key:
1526
+ nn.init.zeros_(state_dict[key_mv].data)
1527
+ if 'transformer_blocks' in key and 'norm1' in key: # only transformer-block norm layers, so the norm layers inside resnet blocks are left untouched
1528
+ key_mv = key.replace('norm1', 'norm_mv')
1529
+ state_dict[key_mv] = state_dict_pretrain[key]
1530
+ # del state_dict_pretrain
1531
+
1532
+ model._convert_deprecated_attention_blocks(state_dict)
1533
+
1534
+ conv_in_weight = state_dict['conv_in.weight']
1535
+ conv_out_weight = state_dict['conv_out.weight']
1536
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
1537
+ model,
1538
+ state_dict,
1539
+ model_file,
1540
+ pretrained_model_name_or_path,
1541
+ ignore_mismatched_sizes=True,
1542
+ )
1543
+ if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
1544
+ # initialize from the original SD structure
1545
+ model.conv_in.weight.data[:,:4] = conv_in_weight
1546
+
1547
+ # optionally zero-initialize the newly added input channels
1548
+ if zero_init_conv_in:
1549
+ model.conv_in.weight.data[:,4:] = 0.
1550
+
1551
+ if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
1552
+ # initialize from the original SD structure
1553
+ model.conv_out.weight.data[:,:4] = conv_out_weight
1554
+ if out_channels == 8: # copy for the last 4 channels
1555
+ model.conv_out.weight.data[:, 4:] = conv_out_weight
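The partial weight copy above (keep the pretrained 4-channel kernels, zero the new ones) is worth seeing in isolation; the self-contained sketch below shows that the enlarged convolution initially behaves exactly like the original layer.

```python
# Self-contained sketch of the channel-expansion trick used for conv_in above:
# pretrained weights fill the first 4 input channels, the new channels start at zero.
import torch
import torch.nn as nn

pretrained_conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)   # stand-in for SD weights
new_conv_in = nn.Conv2d(8, 320, kernel_size=3, padding=1)

with torch.no_grad():
    new_conv_in.weight[:, :4] = pretrained_conv_in.weight
    new_conv_in.weight[:, 4:] = 0.0                                # zero_init_conv_in=True
    new_conv_in.bias.copy_(pretrained_conv_in.bias)

x = torch.randn(1, 4, 64, 64)
extra = torch.randn(1, 4, 64, 64)                                  # e.g. reference-image latents
same = torch.allclose(pretrained_conv_in(x),
                      new_conv_in(torch.cat([x, extra], dim=1)), atol=1e-5)
print(same)                                                        # True
```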
1556
+
1557
+ if zero_init_camera_projection: # true
1558
+ params = [p for p in model.camera_embedding.parameters()]
1559
+ torch.nn.init.zeros_(params[-1].data)
1560
+
1561
+ loading_info = {
1562
+ "missing_keys": missing_keys,
1563
+ "unexpected_keys": unexpected_keys,
1564
+ "mismatched_keys": mismatched_keys,
1565
+ "error_msgs": error_msgs,
1566
+ }
1567
+
1568
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1569
+ raise ValueError(
1570
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1571
+ )
1572
+ elif torch_dtype is not None:
1573
+ model = model.to(torch_dtype)
1574
+
1575
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1576
+
1577
+ # Set model in evaluation mode to deactivate DropOut modules by default
1578
+ model.eval()
1579
+ if output_loading_info:
1580
+ return model, loading_info
1581
+ return model
1582
+
1583
+ @classmethod
1584
+ def _load_pretrained_model_2d(
1585
+ cls,
1586
+ model,
1587
+ state_dict,
1588
+ resolved_archive_file,
1589
+ pretrained_model_name_or_path,
1590
+ ignore_mismatched_sizes=False,
1591
+ ):
1592
+ # Retrieve missing & unexpected_keys
1593
+ model_state_dict = model.state_dict()
1594
+ loaded_keys = list(state_dict.keys())
1595
+
1596
+ expected_keys = list(model_state_dict.keys())
1597
+
1598
+ original_loaded_keys = loaded_keys
1599
+
1600
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
1601
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
1602
+
1603
+ # Make sure we are able to load base models as well as derived models (with heads)
1604
+ model_to_load = model
1605
+
1606
+ def _find_mismatched_keys(
1607
+ state_dict,
1608
+ model_state_dict,
1609
+ loaded_keys,
1610
+ ignore_mismatched_sizes,
1611
+ ):
1612
+ mismatched_keys = []
1613
+ if ignore_mismatched_sizes:
1614
+ for checkpoint_key in loaded_keys:
1615
+ model_key = checkpoint_key
1616
+
1617
+ if (
1618
+ model_key in model_state_dict
1619
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
1620
+ ):
1621
+ mismatched_keys.append(
1622
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1623
+ )
1624
+ del state_dict[checkpoint_key]
1625
+ return mismatched_keys
1626
+
1627
+ if state_dict is not None:
1628
+ # Whole checkpoint
1629
+ mismatched_keys = _find_mismatched_keys(
1630
+ state_dict,
1631
+ model_state_dict,
1632
+ original_loaded_keys,
1633
+ ignore_mismatched_sizes,
1634
+ )
1635
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
1636
+
1637
+ if len(error_msgs) > 0:
1638
+ error_msg = "\n\t".join(error_msgs)
1639
+ if "size mismatch" in error_msg:
1640
+ error_msg += (
1641
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
1642
+ )
1643
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
1644
+
1645
+ if len(unexpected_keys) > 0:
1646
+ logger.warning(
1647
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
1648
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
1649
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
1650
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
1651
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
1652
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
1653
+ " identical (initializing a BertForSequenceClassification model from a"
1654
+ " BertForSequenceClassification model)."
1655
+ )
1656
+ else:
1657
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1658
+ if len(missing_keys) > 0:
1659
+ logger.warning(
1660
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1661
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
1662
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1663
+ )
1664
+ elif len(mismatched_keys) == 0:
1665
+ logger.info(
1666
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
1667
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
1668
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
1669
+ " without further training."
1670
+ )
1671
+ if len(mismatched_keys) > 0:
1672
+ mismatched_warning = "\n".join(
1673
+ [
1674
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1675
+ for key, shape1, shape2 in mismatched_keys
1676
+ ]
1677
+ )
1678
+ logger.warning(
1679
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1680
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1681
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1682
+ " able to use it for predictions and inference."
1683
+ )
1684
+
1685
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1686
+
mvdiffusion/pipelines/pipeline_mvdiffusion_unclip.py ADDED
@@ -0,0 +1,633 @@
1
+ import inspect
2
+ import warnings
3
+ from typing import Callable, List, Optional, Union, Dict, Any
4
+ import PIL
5
+ import torch
6
+ from packaging import version
7
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPFeatureExtractor, CLIPTokenizer, CLIPTextModel
8
+ from diffusers.utils.import_utils import is_accelerate_available
9
+ from diffusers.configuration_utils import FrozenDict
10
+ from diffusers.image_processor import VaeImageProcessor
11
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
12
+ from diffusers.models.embeddings import get_timestep_embedding
13
+ from diffusers.schedulers import KarrasDiffusionSchedulers
14
+ from diffusers.utils import deprecate, logging
15
+ from diffusers.utils.torch_utils import randn_tensor
16
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
17
+ from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
18
+ import os
19
+ import torchvision.transforms.functional as TF
20
+ from einops import rearrange
21
+ logger = logging.get_logger(__name__)
22
+
23
+ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
24
+ """
25
+ Pipeline for text-guided image to image generation using stable unCLIP.
26
+
27
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
28
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
29
+
30
+ Args:
31
+ feature_extractor ([`CLIPFeatureExtractor`]):
32
+ Feature extractor for image pre-processing before being encoded.
33
+ image_encoder ([`CLIPVisionModelWithProjection`]):
34
+ CLIP vision model for encoding images.
35
+ image_normalizer ([`StableUnCLIPImageNormalizer`]):
36
+ Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
37
+ embeddings after the noise has been applied.
38
+ image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
39
+ Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
40
+ by `noise_level` in `StableUnCLIPPipeline.__call__`.
41
+ tokenizer (`CLIPTokenizer`):
42
+ Tokenizer of class
43
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
44
+ text_encoder ([`CLIPTextModel`]):
45
+ Frozen text-encoder.
46
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
47
+ scheduler ([`KarrasDiffusionSchedulers`]):
48
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
49
+ vae ([`AutoencoderKL`]):
50
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
51
+ """
52
+ # image encoding components
53
+ feature_extractor: CLIPFeatureExtractor
54
+ image_encoder: CLIPVisionModelWithProjection
55
+ # image noising components
56
+ image_normalizer: StableUnCLIPImageNormalizer
57
+ image_noising_scheduler: KarrasDiffusionSchedulers
58
+ # regular denoising components
59
+ tokenizer: CLIPTokenizer
60
+ text_encoder: CLIPTextModel
61
+ unet: UNet2DConditionModel
62
+ scheduler: KarrasDiffusionSchedulers
63
+ vae: AutoencoderKL
64
+
65
+ def __init__(
66
+ self,
67
+ # image encoding components
68
+ feature_extractor: CLIPFeatureExtractor,
69
+ image_encoder: CLIPVisionModelWithProjection,
70
+ # image noising components
71
+ image_normalizer: StableUnCLIPImageNormalizer,
72
+ image_noising_scheduler: KarrasDiffusionSchedulers,
73
+ # regular denoising components
74
+ tokenizer: CLIPTokenizer,
75
+ text_encoder: CLIPTextModel,
76
+ unet: UNet2DConditionModel,
77
+ scheduler: KarrasDiffusionSchedulers,
78
+ # vae
79
+ vae: AutoencoderKL,
80
+ num_views: int = 4,
81
+ ):
82
+ super().__init__()
83
+
84
+ self.register_modules(
85
+ feature_extractor=feature_extractor,
86
+ image_encoder=image_encoder,
87
+ image_normalizer=image_normalizer,
88
+ image_noising_scheduler=image_noising_scheduler,
89
+ tokenizer=tokenizer,
90
+ text_encoder=text_encoder,
91
+ unet=unet,
92
+ scheduler=scheduler,
93
+ vae=vae,
94
+ )
95
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
96
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
97
+ self.num_views: int = num_views
98
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
99
+ def enable_vae_slicing(self):
100
+ r"""
101
+ Enable sliced VAE decoding.
102
+
103
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
104
+ steps. This is useful to save some memory and allow larger batch sizes.
105
+ """
106
+ self.vae.enable_slicing()
107
+
108
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
109
+ def disable_vae_slicing(self):
110
+ r"""
111
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
112
+ computing decoding in one step.
113
+ """
114
+ self.vae.disable_slicing()
115
+
116
+ def enable_sequential_cpu_offload(self, gpu_id=0):
117
+ r"""
118
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
119
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
120
+ when their specific submodule has its `forward` method called.
121
+ """
122
+ if is_accelerate_available():
123
+ from accelerate import cpu_offload
124
+ else:
125
+ raise ImportError("Please install accelerate via `pip install accelerate`")
126
+
127
+ device = torch.device(f"cuda:{gpu_id}")
128
+
129
+ # TODO: self.image_normalizer.{scale,unscale} are not covered by the offload hooks, so they fail if added to the list
130
+ models = [
131
+ self.image_encoder,
132
+ self.text_encoder,
133
+ self.unet,
134
+ self.vae,
135
+ ]
136
+ for cpu_offloaded_model in models:
137
+ if cpu_offloaded_model is not None:
138
+ cpu_offload(cpu_offloaded_model, device)
139
+
140
+ @property
141
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
142
+ def _execution_device(self):
143
+ r"""
144
+ Returns the device on which the pipeline's models will be executed. After calling
145
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
146
+ hooks.
147
+ """
148
+ if not hasattr(self.unet, "_hf_hook"):
149
+ return self.device
150
+ for module in self.unet.modules():
151
+ if (
152
+ hasattr(module, "_hf_hook")
153
+ and hasattr(module._hf_hook, "execution_device")
154
+ and module._hf_hook.execution_device is not None
155
+ ):
156
+ return torch.device(module._hf_hook.execution_device)
157
+ return self.device
158
+
159
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
160
+ def _encode_prompt(
161
+ self,
162
+ prompt,
163
+ device,
164
+ num_images_per_prompt,
165
+ do_classifier_free_guidance,
166
+ negative_prompt=None,
167
+ prompt_embeds: Optional[torch.FloatTensor] = None,
168
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
169
+ lora_scale: Optional[float] = None,
170
+ ):
171
+ r"""
172
+ Encodes the prompt into text encoder hidden states.
173
+
174
+ Args:
175
+ prompt (`str` or `List[str]`, *optional*):
176
+ prompt to be encoded
177
+ device: (`torch.device`):
178
+ torch device
179
+ num_images_per_prompt (`int`):
180
+ number of images that should be generated per prompt
181
+ do_classifier_free_guidance (`bool`):
182
+ whether to use classifier free guidance or not
183
+ negative_prompt (`str` or `List[str]`, *optional*):
184
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
185
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
186
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
187
+ prompt_embeds (`torch.FloatTensor`, *optional*):
188
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
189
+ provided, text embeddings will be generated from `prompt` input argument.
190
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
191
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
192
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
193
+ argument.
194
+ """
195
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
196
+
197
+ if do_classifier_free_guidance:
198
+ # For classifier free guidance, we need to do two forward passes.
199
+ # Here we concatenate the unconditional and text embeddings into a single batch
200
+ # to avoid doing two forward passes
201
+ normal_prompt_embeds, color_prompt_embeds = torch.chunk(prompt_embeds, 2, dim=0)
202
+
203
+ prompt_embeds = torch.cat([normal_prompt_embeds, normal_prompt_embeds, color_prompt_embeds, color_prompt_embeds], 0)
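+ # The text embeddings are simply duplicated across the unconditional/conditional halves; the unconditional
+ # branch is realized by zeroing the image embeddings and image latents in `_encode_image` below.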
204
+
205
+ return prompt_embeds
206
+
207
+ def _encode_image(
208
+ self,
209
+ image_pil,
210
+ device,
211
+ num_images_per_prompt,
212
+ do_classifier_free_guidance,
213
+ noise_level: int=0,
214
+ generator: Optional[torch.Generator] = None
215
+ ):
216
+ dtype = next(self.image_encoder.parameters()).dtype
217
+ # ______________________________clip image embedding______________________________
218
+ image = self.feature_extractor(images=image_pil, return_tensors="pt").pixel_values
219
+ image = image.to(device=device, dtype=dtype)
220
+ image_embeds = self.image_encoder(image).image_embeds
221
+
222
+ image_embeds = self.noise_image_embeddings(
223
+ image_embeds=image_embeds,
224
+ noise_level=noise_level,
225
+ generator=generator,
226
+ )
227
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
228
+ # image_embeds = image_embeds.unsqueeze(1)
229
+ # note: the conditioning input is the same across views
230
+ image_embeds = image_embeds.repeat(num_images_per_prompt, 1)
231
+
232
+ if do_classifier_free_guidance:
233
+ normal_image_embeds, color_image_embeds = torch.chunk(image_embeds, 2, dim=0)
234
+ negative_prompt_embeds = torch.zeros_like(normal_image_embeds)
235
+
236
+ # For classifier free guidance, we need to do two forward passes.
237
+ # Here we concatenate the unconditional and text embeddings into a single batch
238
+ # to avoid doing two forward passes
239
+ image_embeds = torch.cat([negative_prompt_embeds, normal_image_embeds, negative_prompt_embeds, color_image_embeds], 0)
240
+
241
+ # _____________________________vae input latents__________________________________________________
242
+ image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(dtype=self.vae.dtype, device=device)
243
+ image_pt = image_pt * 2.0 - 1.0
244
+ image_latents = self.vae.encode(image_pt).latent_dist.mode() * self.vae.config.scaling_factor
245
+ # Note: repeated differently from the official pipelines
246
+ image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1)
247
+
248
+ if do_classifier_free_guidance:
249
+ normal_image_latents, color_image_latents = torch.chunk(image_latents, 2, dim=0)
250
+ image_latents = torch.cat([torch.zeros_like(normal_image_latents), normal_image_latents,
251
+ torch.zeros_like(color_image_latents), color_image_latents], 0)
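+ # Batch layout after CFG duplication: [normal-uncond, normal-cond, color-uncond, color-cond],
+ # matching the ordering of `prompt_embeds` above and of the latent duplication in the denoising loop.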
252
+
253
+ return image_embeds, image_latents
254
+
255
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
256
+ def decode_latents(self, latents):
257
+ latents = 1 / self.vae.config.scaling_factor * latents
258
+ image = self.vae.decode(latents).sample
259
+ image = (image / 2 + 0.5).clamp(0, 1)
260
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
261
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
262
+ return image
263
+
264
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
265
+ def prepare_extra_step_kwargs(self, generator, eta):
266
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
267
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
268
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
269
+ # and should be between [0, 1]
270
+
271
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
272
+ extra_step_kwargs = {}
273
+ if accepts_eta:
274
+ extra_step_kwargs["eta"] = eta
275
+
276
+ # check if the scheduler accepts generator
277
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
278
+ if accepts_generator:
279
+ extra_step_kwargs["generator"] = generator
280
+ return extra_step_kwargs
281
+
282
+ def check_inputs(
283
+ self,
284
+ prompt,
285
+ image,
286
+ height,
287
+ width,
288
+ callback_steps,
289
+ noise_level,
290
+ ):
291
+ if height % 8 != 0 or width % 8 != 0:
292
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
293
+
294
+ if (callback_steps is None) or (
295
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
296
+ ):
297
+ raise ValueError(
298
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
299
+ f" {type(callback_steps)}."
300
+ )
301
+
302
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
303
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
304
+
305
+
306
+ if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
307
+ raise ValueError(
308
+ f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
309
+ )
310
+
311
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
312
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
313
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
314
+ if isinstance(generator, list) and len(generator) != batch_size:
315
+ raise ValueError(
316
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
317
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
318
+ )
319
+
320
+ if latents is None:
321
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
322
+ else:
323
+ latents = latents.to(device)
324
+
325
+ # scale the initial noise by the standard deviation required by the scheduler
326
+ latents = latents * self.scheduler.init_noise_sigma
327
+ return latents
328
+
329
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings
330
+ def noise_image_embeddings(
331
+ self,
332
+ image_embeds: torch.Tensor,
333
+ noise_level: int,
334
+ noise: Optional[torch.FloatTensor] = None,
335
+ generator: Optional[torch.Generator] = None,
336
+ ):
337
+ """
338
+ Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
339
+ `noise_level` increases the variance in the final un-noised images.
340
+
341
+ The noise is applied in two ways:
342
+ 1. A noise schedule is applied directly to the embeddings.
343
+ 2. A vector of sinusoidal time embeddings is appended to the output.
344
+
345
+ In both cases, the amount of noise is controlled by the same `noise_level`.
346
+
347
+ The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
348
+ """
349
+ if noise is None:
350
+ noise = randn_tensor(
351
+ image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
352
+ )
353
+
354
+ noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
355
+
356
+ image_embeds = self.image_normalizer.scale(image_embeds)
357
+
358
+ image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
359
+
360
+ image_embeds = self.image_normalizer.unscale(image_embeds)
361
+
362
+ noise_level = get_timestep_embedding(
363
+ timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
364
+ )
365
+
366
+ # `get_timestep_embedding` does not contain any weights and will always return f32 tensors,
367
+ # but we might actually be running in fp16, so we need to cast here.
368
+ # there might be better ways to encapsulate this.
369
+ noise_level = noise_level.to(image_embeds.dtype)
370
+
371
+ image_embeds = torch.cat((image_embeds, noise_level), 1)
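+ # Concatenating the sinusoidal noise-level embedding doubles the feature dimension of `image_embeds`,
+ # since the timestep embedding is generated with embedding_dim == image_embeds.shape[-1].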
372
+
373
+ return image_embeds
374
+
375
+ @torch.no_grad()
376
+ # @replace_example_docstring(EXAMPLE_DOC_STRING)
377
+ def __call__(
378
+ self,
379
+ image: Union[torch.FloatTensor, PIL.Image.Image],
380
+ prompt: Union[str, List[str]],
381
+ prompt_embeds: torch.FloatTensor = None,
382
+ dino_feature: torch.FloatTensor = None,
383
+ height: Optional[int] = None,
384
+ width: Optional[int] = None,
385
+ num_inference_steps: int = 20,
386
+ guidance_scale: float = 10,
387
+ negative_prompt: Optional[Union[str, List[str]]] = None,
388
+ num_images_per_prompt: Optional[int] = 1,
389
+ eta: float = 0.0,
390
+ generator: Optional[torch.Generator] = None,
391
+ latents: Optional[torch.FloatTensor] = None,
392
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
393
+ output_type: Optional[str] = "pil",
394
+ return_dict: bool = True,
395
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
396
+ callback_steps: int = 1,
397
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
398
+ noise_level: int = 0,
399
+ image_embeds: Optional[torch.FloatTensor] = None,
400
+ return_elevation_focal: Optional[bool] = False,
401
+ gt_img_in: Optional[torch.FloatTensor] = None,
402
+ ):
403
+ r"""
404
+ Function invoked when calling the pipeline for generation.
405
+
406
+ Args:
407
+ prompt (`str` or `List[str]`, *optional*):
408
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
410
+ instead.
410
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
411
+ `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which
412
+ the unet will be conditioned on. In this pipeline the image is additionally encoded by the VAE and
413
+ concatenated channel-wise to the noisy latents as a conditioning signal; it is not used as the initial
414
+ latents of the denoising process itself.
415
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
416
+ The height in pixels of the generated image.
417
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
418
+ The width in pixels of the generated image.
419
+ num_inference_steps (`int`, *optional*, defaults to 20):
420
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
421
+ expense of slower inference.
422
+ guidance_scale (`float`, *optional*, defaults to 10.0):
423
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
424
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
425
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
426
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
427
+ usually at the expense of lower image quality.
428
+ negative_prompt (`str` or `List[str]`, *optional*):
429
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
430
+ `negative_prompt_embeds` instead.
431
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
432
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
433
+ The number of images to generate per prompt.
434
+ eta (`float`, *optional*, defaults to 0.0):
435
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
436
+ [`schedulers.DDIMScheduler`], will be ignored for others.
437
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
438
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
439
+ to make generation deterministic.
440
+ latents (`torch.FloatTensor`, *optional*):
441
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
442
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
443
+ tensor will be generated by sampling using the supplied random `generator`.
444
+ prompt_embeds (`torch.FloatTensor`, *optional*):
445
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
446
+ provided, text embeddings will be generated from `prompt` input argument.
447
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
448
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
449
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
450
+ argument.
451
+ output_type (`str`, *optional*, defaults to `"pil"`):
452
+ The output format of the generated image. Choose between
453
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
454
+ return_dict (`bool`, *optional*, defaults to `True`):
455
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
456
+ plain tuple.
457
+ callback (`Callable`, *optional*):
458
+ A function that will be called every `callback_steps` steps during inference. The function will be
459
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
460
+ callback_steps (`int`, *optional*, defaults to 1):
461
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
462
+ called at every step.
463
+ cross_attention_kwargs (`dict`, *optional*):
464
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
465
+ `self.processor` in
466
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
467
+ noise_level (`int`, *optional*, defaults to `0`):
468
+ The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
469
+ the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details.
470
+ image_embeds (`torch.FloatTensor`, *optional*):
471
+ Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in
472
+ the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as
473
+ `latents`.
474
+
475
+ Examples:
476
+
477
+ Returns:
478
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is
479
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
480
+ """
481
+ # 0. Default height and width to unet
482
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
483
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
484
+
485
+ # 1. Check inputs. Raise error if not correct
486
+ self.check_inputs(
487
+ prompt=prompt,
488
+ image=image,
489
+ height=height,
490
+ width=width,
491
+ callback_steps=callback_steps,
492
+ noise_level=noise_level
493
+ )
494
+
495
+ # 2. Define call parameters
496
+ if isinstance(image, list):
497
+ batch_size = len(image)
498
+ elif isinstance(image, torch.Tensor):
499
+ batch_size = image.shape[0]
500
+ assert batch_size >= self.num_views and batch_size % self.num_views == 0
501
+ elif isinstance(image, PIL.Image.Image):
502
+ image = [image]*self.num_views*2
503
+ batch_size = self.num_views*2
504
+
505
+ if isinstance(prompt, str):
506
+ prompt = [prompt] * self.num_views * 2
507
+
508
+ device = self._execution_device
509
+
510
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
511
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
512
+ # corresponds to doing no classifier free guidance.
513
+ do_classifier_free_guidance = guidance_scale != 1.0
514
+
515
+ # 3. Encode input prompt
516
+ text_encoder_lora_scale = (
517
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
518
+ )
519
+ prompt_embeds = self._encode_prompt(
520
+ prompt=prompt,
521
+ device=device,
522
+ num_images_per_prompt=num_images_per_prompt,
523
+ do_classifier_free_guidance=do_classifier_free_guidance,
524
+ negative_prompt=negative_prompt,
525
+ prompt_embeds=prompt_embeds,
526
+ negative_prompt_embeds=negative_prompt_embeds,
527
+ lora_scale=text_encoder_lora_scale,
528
+ )
529
+
530
+
531
+ # 4. Encode the input image
532
+ if isinstance(image, list):
533
+ image_pil = image
534
+ elif isinstance(image, torch.Tensor):
535
+ image_pil = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
536
+ noise_level = torch.tensor([noise_level], device=device)
537
+ image_embeds, image_latents = self._encode_image(
538
+ image_pil=image_pil,
539
+ device=device,
540
+ num_images_per_prompt=num_images_per_prompt,
541
+ do_classifier_free_guidance=do_classifier_free_guidance,
542
+ noise_level=noise_level,
543
+ generator=generator,
544
+ )
545
+
546
+ # 5. Prepare timesteps
547
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
548
+ timesteps = self.scheduler.timesteps
549
+
550
+ # 6. Prepare latent variables
551
+ num_channels_latents = self.unet.config.out_channels
552
+ if gt_img_in is not None:
553
+ latents = gt_img_in * self.scheduler.init_noise_sigma
554
+ else:
555
+ latents = self.prepare_latents(
556
+ batch_size=batch_size,
557
+ num_channels_latents=num_channels_latents,
558
+ height=height,
559
+ width=width,
560
+ dtype=prompt_embeds.dtype,
561
+ device=device,
562
+ generator=generator,
563
+ latents=latents,
564
+ )
565
+
566
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
567
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
568
+
569
+ eles, focals = [], []
570
+ # 8. Denoising loop
571
+ for i, t in enumerate(self.progress_bar(timesteps)):
572
+ if do_classifier_free_guidance:
573
+ normal_latents, color_latents = torch.chunk(latents, 2, dim=0)
574
+ latent_model_input = torch.cat([normal_latents, normal_latents, color_latents, color_latents], 0)
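+ # Duplicate the latents so the batch lines up with the [uncond, cond] pairs prepared for the normal
+ # and color domains in `_encode_image`; the uncond/cond distinction comes from the zeroed image inputs.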
575
+ else:
576
+ latent_model_input = latents
577
+ latent_model_input = torch.cat([
578
+ latent_model_input, image_latents
579
+ ], dim=1)
580
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
581
+
582
+ # predict the noise residual
583
+ unet_out = self.unet(
584
+ latent_model_input,
585
+ t,
586
+ encoder_hidden_states=prompt_embeds,
587
+ dino_feature=dino_feature,
588
+ class_labels=image_embeds,
589
+ cross_attention_kwargs=cross_attention_kwargs,
590
+ return_dict=False)
591
+
592
+ noise_pred = unet_out[0]
593
+ if return_elevation_focal:
594
+ uncond_pose, pose = torch.chunk(unet_out[1], 2, 0)
595
+ pose = uncond_pose + guidance_scale * (pose - uncond_pose)
596
+ ele = pose[:, 0].detach().cpu().numpy() # b
597
+ eles.append(ele)
598
+ focal = pose[:, 1].detach().cpu().numpy()
599
+ focals.append(focal)
600
+
601
+ # perform guidance
602
+ if do_classifier_free_guidance:
603
+ normal_noise_pred_uncond, normal_noise_pred_text, color_noise_pred_uncond, color_noise_pred_text = torch.chunk(noise_pred, 4, dim=0)
604
+
605
+ noise_pred_uncond, noise_pred_text = torch.cat([normal_noise_pred_uncond, color_noise_pred_uncond], 0), torch.cat([normal_noise_pred_text, color_noise_pred_text], 0)
606
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
607
+
608
+ # compute the previous noisy sample x_t -> x_t-1
609
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
610
+
611
+ if callback is not None and i % callback_steps == 0:
612
+ callback(i, t, latents)
613
+
614
+ # 9. Post-processing
615
+ if not output_type == "latent":
616
+ if num_channels_latents == 8:
617
+ latents = torch.cat([latents[:, :4], latents[:, 4:]], dim=0)
618
+ with torch.no_grad():
619
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
620
+ else:
621
+ image = latents
622
+
623
+ image = self.image_processor.postprocess(image, output_type=output_type)
624
+
625
+ # Offload last model to CPU
626
+ # if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
627
+ # self.final_offload_hook.offload()
628
+ if not return_dict:
629
+ return (image, )
630
+ if return_elevation_focal:
631
+ return ImagePipelineOutput(images=image), eles, focals
632
+ else:
633
+ return ImagePipelineOutput(images=image)
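
For orientation, a minimal, hypothetical driver sketch for the `__call__` signature above; `pipeline` and `prompt_embeds` are assumed to be constructed elsewhere (as `app.py` does), and the concrete argument values and file names are illustrative only, not part of this commit.

# Hypothetical usage sketch: `pipeline` is assumed to be an instance of the unCLIP multi-view pipeline
# above, and `prompt_embeds` the pre-computed normal/color text embeddings shipped under
# mvdiffusion/data/fixed_prompt_embeds_6view.
import torch
from PIL import Image

cond_image = Image.open("examples/k2.png").convert("RGB")
generator = torch.Generator(device="cuda").manual_seed(42)
# pipeline.enable_sequential_cpu_offload()  # optional: trades speed for a much smaller GPU footprint

out = pipeline(
    image=cond_image,              # a single PIL image is broadcast to num_views * 2 internally
    prompt="",
    prompt_embeds=prompt_embeds,   # required, since _encode_prompt consumes embeddings directly
    height=512, width=512,
    num_inference_steps=20,
    guidance_scale=3.0,            # illustrative value
    noise_level=0,
    generator=generator,
)
mv_images = out.images             # list of generated normal and color views
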
requirements.txt CHANGED
@@ -1,15 +1,35 @@
1
- torch==2.0.1
2
- transformers
3
- flax
4
- -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
5
- jaxlib
6
- opencv-python
7
- git+https://github.com/huggingface/diffusers@main
8
- transformers
9
- openai
10
- einops
11
- torch
12
- torchvision
13
- accelerate
14
- gradio
15
- numpy<2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu118
2
+ diffusers[torch]==0.26.0
3
+ transformers==4.37.2
4
+ torch==2.1.2
5
+ torchvision==0.16.2
6
+ torchaudio==2.1.2
7
+ xformers==0.0.23.post1
8
+ decord==0.6.0
9
+ click==8.1.7
10
+ pytorch-lightning==1.9.0
11
+ omegaconf==2.2.3
12
+ nerfacc==0.3.3
13
+ trimesh==3.9.8
14
+ pyhocon==0.3.57
15
+ icecream==2.1.0
16
+ PyMCubes==0.1.2
17
+ accelerate==0.21.0
18
+ modelcards
19
+ einops
20
+ ftfy
21
+ piq
22
+ matplotlib
23
+ opencv-python
24
+ imageio
25
+ imageio-ffmpeg
26
+ scipy
27
+ pyransac3d
28
+ torch_efficient_distloss
29
+ tensorboard
30
+ rembg
31
+ segment_anything
32
+ gradio==4.31.5
33
+ moviepy
34
+ kornia
35
+ fire
sam_pt/sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
3
+ size 2564550879
utils/misc.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from omegaconf import OmegaConf
3
+ from packaging import version
4
+
5
+
6
+ # ============ Register OmegaConf Resolvers ============= #
7
+ OmegaConf.register_new_resolver('calc_exp_lr_decay_rate', lambda factor, n: factor**(1./n))
8
+ OmegaConf.register_new_resolver('add', lambda a, b: a + b)
9
+ OmegaConf.register_new_resolver('sub', lambda a, b: a - b)
10
+ OmegaConf.register_new_resolver('mul', lambda a, b: a * b)
11
+ OmegaConf.register_new_resolver('div', lambda a, b: a / b)
12
+ OmegaConf.register_new_resolver('idiv', lambda a, b: a // b)
13
+ OmegaConf.register_new_resolver('basename', lambda p: os.path.basename(p))
14
+ # ======================================================= #
15
+
16
+
17
+ def prompt(question):
18
+ inp = input(f"{question} (y/n)").lower().strip()
19
+ if inp and inp == 'y':
20
+ return True
21
+ if inp and inp == 'n':
22
+ return False
23
+ return prompt(question)
24
+
25
+
26
+ def load_config(*yaml_files, cli_args=[]):
27
+ yaml_confs = [OmegaConf.load(f) for f in yaml_files]
28
+ cli_conf = OmegaConf.from_cli(cli_args)
29
+ conf = OmegaConf.merge(*yaml_confs, cli_conf)
30
+ OmegaConf.resolve(conf)
31
+ return conf
32
+
33
+
34
+ def config_to_primitive(config, resolve=True):
35
+ return OmegaConf.to_container(config, resolve=resolve)
36
+
37
+
38
+ def dump_config(path, config):
39
+ with open(path, 'w') as fp:
40
+ OmegaConf.save(config=config, f=fp)
41
+
42
+ def get_rank():
43
+ # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
44
+ # therefore LOCAL_RANK needs to be checked first
45
+ rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
46
+ for key in rank_keys:
47
+ rank = os.environ.get(key)
48
+ if rank is not None:
49
+ return int(rank)
50
+ return 0
51
+
52
+
53
+ def parse_version(ver):
54
+ return version.parse(ver)
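
A small, assumed usage sketch for the helpers above; the CLI override key is illustrative, not taken from the shipped config.

# Hypothetical usage of load_config together with the registered OmegaConf resolvers.
from utils.misc import load_config, get_rank

cfg = load_config(
    "configs/test_unclip-512-6view.yaml",                # config shipped in this commit
    cli_args=["pretrained_model_name_or_path=./ckpt"],   # example dotlist override; key is illustrative
)
print(get_rank())  # 0 when not running under a distributed launcher
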
utils/utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision.utils import make_grid
2
+ from PIL import Image, ImageDraw, ImageFont
3
+ import numpy as np
4
+ import torch
5
+ def make_grid_(imgs, save_file, nrow=10, pad_value=1):
6
+ if isinstance(imgs, list):
7
+ if isinstance(imgs[0], Image.Image):
8
+ imgs = [torch.from_numpy(np.array(img)/255.) for img in imgs]
9
+ elif isinstance(imgs[0], np.ndarray):
10
+ imgs = [torch.from_numpy(img/255.) for img in imgs]
11
+ imgs = torch.stack(imgs, 0).permute(0, 3, 1, 2)
12
+ if isinstance(imgs, np.ndarray):
13
+ imgs = torch.from_numpy(imgs)
14
+
15
+ img_grid = make_grid(imgs, nrow=nrow, padding=2, pad_value=pad_value)
16
+ img_grid = img_grid.permute(1, 2, 0).numpy()
17
+ img_grid = (img_grid * 255).astype(np.uint8)
18
+ img_grid = Image.fromarray(img_grid)
19
+ img_grid.save(save_file)
20
+
21
+ def draw_caption(img, text, pos, size=100, color=(128, 128, 128)):
22
+ draw = ImageDraw.Draw(img)
23
+ # font = ImageFont.truetype(size= size)
24
+ font = ImageFont.load_default()
25
+ font = font.font_variant(size=size)
26
+ draw.text(pos, text, color, font=font)
27
+ return img
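
Finally, a short, assumed usage example for the two helpers in utils/utils.py; file names and caption text are placeholders.

# Hypothetical usage: tile a couple of example images into one grid and caption it.
from PIL import Image
from utils.utils import make_grid_, draw_caption

views = [Image.open(p).convert("RGB").resize((256, 256))
         for p in ("examples/k2.png", "examples/monkey.png")]
make_grid_(views, "grid.png", nrow=2)

grid = Image.open("grid.png")
grid = draw_caption(grid, "Era3D multi-view demo", pos=(10, 10), size=32)
grid.save("grid_captioned.png")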