haodongli commited on
Commit
44189a1
1 Parent(s): d34366d
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🚀
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.21.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from gradio_imageslider import ImageSlider
3
+ import functools
4
+ import os
5
+ import tempfile
6
+ import diffusers
7
+ import gradio as gr
8
+ import imageio as imageio
9
+ import numpy as np
10
+ import spaces
11
+ import torch as torch
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+ from pathlib import Path
15
+ import gradio
16
+ from gradio.utils import get_cache_folder
17
+ from infer import lotus
18
+
19
+ # def process_image_check(path_input):
20
+ # if path_input is None:
21
+ # raise gr.Error(
22
+ # "Missing image in the first pane: upload a file or use one from the gallery below."
23
+ # )
24
+
25
+ # def infer(path_input, seed=0):
26
+ # print(f"==> Processing image {path_input}")
27
+ # return path_input
28
+ # return [path_input, path_input]
29
+ # # name_base, name_ext = os.path.splitext(os.path.basename(path_input))
30
+ # # print(f"==> Processing image {name_base}{name_ext}")
31
+ # # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ # # print(f"==> Device: {device}")
33
+ # # output_g, output_d = lotus(path_input, 'depth', seed, device)
34
+ # # if not os.path.exists("files/output"):
35
+ # # os.makedirs("files/output")
36
+ # # g_save_path = os.path.join("files/output", f"{name_base}_g{name_ext}")
37
+ # # d_save_path = os.path.join("files/output", f"{name_base}_d{name_ext}")
38
+ # # output_g.save(g_save_path)
39
+ # # output_d.save(d_save_path)
40
+ # # yield [path_input, g_save_path], [path_input, d_save_path]
41
+
42
+ # def run_demo_server():
43
+ # gradio_theme = gr.themes.Default()
44
+
45
+ # with gr.Blocks(
46
+ # theme=gradio_theme,
47
+ # title="LOTUS (Depth)",
48
+ # css="""
49
+ # #download {
50
+ # height: 118px;
51
+ # }
52
+ # .slider .inner {
53
+ # width: 5px;
54
+ # background: #FFF;
55
+ # }
56
+ # .viewport {
57
+ # aspect-ratio: 4/3;
58
+ # }
59
+ # .tabs button.selected {
60
+ # font-size: 20px !important;
61
+ # color: crimson !important;
62
+ # }
63
+ # h1 {
64
+ # text-align: center;
65
+ # display: block;
66
+ # }
67
+ # h2 {
68
+ # text-align: center;
69
+ # display: block;
70
+ # }
71
+ # h3 {
72
+ # text-align: center;
73
+ # display: block;
74
+ # }
75
+ # .md_feedback li {
76
+ # margin-bottom: 0px !important;
77
+ # }
78
+ # """,
79
+ # head="""
80
+ # <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
81
+ # <script>
82
+ # window.dataLayer = window.dataLayer || [];
83
+ # function gtag() {dataLayer.push(arguments);}
84
+ # gtag('js', new Date());
85
+ # gtag('config', 'G-1FWSVCGZTG');
86
+ # </script>
87
+ # """,
88
+ # ) as demo:
89
+ # gr.Markdown(
90
+ # """
91
+ # # LOTUS: Diffusion-based Visual Foundation Model for High-quality Dense Prediction
92
+ # <p align="center">
93
+ # <a title="Page" href="https://lotus3d.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
94
+ # <img src="https://img.shields.io/badge/Project-Website-pink?logo=googlechrome&logoColor=white">
95
+ # </a>
96
+ # <a title="arXiv" href="https://arxiv.org/abs/2409.18124" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
97
+ # <img src="https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=white">
98
+ # </a>
99
+ # <a title="Github" href="https://github.com/EnVision-Research/Lotus" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
100
+ # <img src="https://img.shields.io/github/stars/EnVision-Research/Lotus?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
101
+ # </a>
102
+ # <a title="Social" href="https://x.com/haodongli00/status/1839524569058582884" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
103
+ # <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
104
+ # </a>
105
+ # """
106
+ # )
107
+ # with gr.Tabs(elem_classes=["tabs"]):
108
+ # with gr.Tab("IMAGE"):
109
+ # with gr.Row():
110
+ # with gr.Column():
111
+ # image_input = gr.Image(
112
+ # label="Input Image",
113
+ # type="filepath",
114
+ # )
115
+ # seed = gr.Number(
116
+ # label="Seed",
117
+ # minimum=0,
118
+ # maximum=999999,
119
+ # )
120
+ # with gr.Row():
121
+ # image_submit_btn = gr.Button(
122
+ # value="Predict Depth!", variant="primary"
123
+ # )
124
+ # # image_reset_btn = gr.Button(value="Reset")
125
+ # with gr.Column():
126
+ # image_output_g = gr.Image(
127
+ # label="Output (Generative)",
128
+ # type="filepath",
129
+ # )
130
+ # # image_output_g = ImageSlider(
131
+ # # label="Output (Generative)",
132
+ # # type="filepath",
133
+ # # show_download_button=True,
134
+ # # show_share_button=True,
135
+ # # interactive=False,
136
+ # # elem_classes="slider",
137
+ # # position=0.25,
138
+ # # )
139
+ # # with gr.Row():
140
+ # # image_output_d = gr.Image(
141
+ # # label="Output (Generative)",
142
+ # # type="filepath",
143
+ # # )
144
+ # # image_output_d = ImageSlider(
145
+ # # label="Output (Discriminative)",
146
+ # # type="filepath",
147
+ # # show_download_button=True,
148
+ # # show_share_button=True,
149
+ # # interactive=False,
150
+ # # elem_classes="slider",
151
+ # # position=0.25,
152
+ # # )
153
+
154
+ # # gr.Examples(
155
+ # # fn=infer,
156
+ # # examples=sorted([
157
+ # # os.path.join("files", "images", name)
158
+ # # for name in os.listdir(os.path.join("files", "images"))
159
+ # # ]),
160
+ # # inputs=[image_input],
161
+ # # outputs=[image_output_g],
162
+ # # cache_examples=True,
163
+ # # )
164
+
165
+ # with gr.Tab("VIDEO"):
166
+ # with gr.Column():
167
+ # gr.Markdown("Coming soon")
168
+
169
+ # ### Image
170
+ # image_submit_btn.click(
171
+ # fn=infer,
172
+ # inputs=[
173
+ # image_input
174
+ # ],
175
+ # outputs=image_output_g,
176
+ # concurrency_limit=1,
177
+ # )
178
+ # # image_reset_btn.click(
179
+ # # fn=lambda: (
180
+ # # None,
181
+ # # None,
182
+ # # None,
183
+ # # ),
184
+ # # inputs=[],
185
+ # # outputs=image_output_g,
186
+ # # queue=False,
187
+ # # )
188
+
189
+ # ### Video
190
+
191
+ # ### Server launch
192
+ # demo.queue(
193
+ # api_open=False,
194
+ # ).launch(
195
+ # server_name="0.0.0.0",
196
+ # server_port=7860,
197
+ # )
198
+
199
+ # def main():
200
+ # os.system("pip freeze")
201
+ # run_demo_server()
202
+
203
+ # if __name__ == "__main__":
204
+ # main()
205
+
206
+ def flip_text(x):
207
+ return x[::-1]
208
+
209
+ def flip_image(x):
210
+ return np.fliplr(x)
211
+
212
+ with gr.Blocks() as demo:
213
+ gr.Markdown("Flip text or image files using this demo.")
214
+ with gr.Tab("Flip Text"):
215
+ text_input = gr.Textbox()
216
+ text_output = gr.Textbox()
217
+ text_button = gr.Button("Flip")
218
+ with gr.Tab("Flip Image"):
219
+ with gr.Row():
220
+ image_input = gr.Image()
221
+ image_output = gr.Image()
222
+ image_button = gr.Button("Flip")
223
+
224
+ with gr.Accordion("Open for More!", open=False):
225
+ gr.Markdown("Look at me...")
226
+ temp_slider = gr.Slider(
227
+ 0, 1,
228
+ value=0.1,
229
+ step=0.1,
230
+ interactive=True,
231
+ label="Slide me",
232
+ )
233
+
234
+ text_button.click(flip_text, inputs=text_input, outputs=text_output)
235
+ image_button.click(flip_image, inputs=image_input, outputs=image_output)
236
+
237
+ demo.launch(share=True)
files/images/00.png ADDED
files/output/00_d.png ADDED
files/output/00_g.png ADDED
files/output/01_d.jpeg ADDED
files/output/01_g.jpeg ADDED
files/videos/obama.mp4 ADDED
Binary file (320 kB). View file
 
infer.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from utils.args import parse_args
2
+ import logging
3
+ import os
4
+ import argparse
5
+ from pathlib import Path
6
+ from PIL import Image
7
+
8
+ import numpy as np
9
+ import torch
10
+ from tqdm.auto import tqdm
11
+ from diffusers.utils import check_min_version
12
+
13
+ from pipeline import LotusGPipeline, LotusDPipeline
14
+ from utils.image_utils import colorize_depth_map
15
+ from utils.seed_all import seed_all
16
+
17
+ check_min_version('0.28.0.dev0')
18
+
19
+ def infer_pipe(pipe, image_input, task_name, seed, device):
20
+ if seed is None:
21
+ generator = None
22
+ else:
23
+ generator = torch.Generator(device=device).manual_seed(seed)
24
+
25
+ test_image = Image.open(image_input).convert('RGB')
26
+ test_image = np.array(test_image).astype(np.float32)
27
+ test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
28
+ test_image = test_image / 127.5 - 1.0
29
+ test_image = test_image.to(device)
30
+
31
+ task_emb = torch.tensor([1, 0]).float().unsqueeze(0).repeat(1, 1).to(device)
32
+ task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
33
+
34
+ # Run
35
+ pred = pipe(
36
+ rgb_in=test_image,
37
+ prompt='',
38
+ num_inference_steps=1,
39
+ generator=generator,
40
+ # guidance_scale=0,
41
+ output_type='np',
42
+ timesteps=[999],
43
+ task_emb=task_emb,
44
+ ).images[0]
45
+
46
+ # Post-process the prediction
47
+ if task_name == 'depth':
48
+ output_npy = pred.mean(axis=-1)
49
+ output_color = colorize_depth_map(output_npy)
50
+ else:
51
+ output_npy = pred
52
+ output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
53
+
54
+ return output_color
55
+
56
+ def lotus(image_input, task_name, seed, device):
57
+ if task_name == 'depth':
58
+ model_g = 'jingheya/lotus-depth-g-v1-0'
59
+ model_d = 'jingheya/lotus-depth-d-v1-0'
60
+ else:
61
+ model_g = 'jingheya/lotus-normal-g-v1-0'
62
+ model_d = 'jingheya/lotus-normal-d-v1-0'
63
+
64
+ dtype = torch.float32
65
+ pipe_g = LotusGPipeline.from_pretrained(
66
+ model_g,
67
+ torch_dtype=dtype,
68
+ )
69
+ pipe_d = LotusDPipeline.from_pretrained(
70
+ model_d,
71
+ torch_dtype=dtype,
72
+ )
73
+ pipe_g.to(device)
74
+ pipe_d.to(device)
75
+ logging.info(f"Successfully loading pipeline from {model_g} and {model_d}.")
76
+ output_g = infer_pipe(pipe_g, image_input, task_name, seed, device)
77
+ output_d = infer_pipe(pipe_d, image_input, task_name, seed, device)
78
+ return output_g, output_d
79
+
80
+ def parse_args():
81
+ '''Set the Args'''
82
+ parser = argparse.ArgumentParser(
83
+ description="Run Lotus..."
84
+ )
85
+ # model settings
86
+ parser.add_argument(
87
+ "--pretrained_model_name_or_path",
88
+ type=str,
89
+ default=None,
90
+ help="pretrained model path from hugging face or local dir",
91
+ )
92
+ parser.add_argument(
93
+ "--prediction_type",
94
+ type=str,
95
+ default="sample",
96
+ help="The used prediction_type. ",
97
+ )
98
+ parser.add_argument(
99
+ "--timestep",
100
+ type=int,
101
+ default=999,
102
+ )
103
+ parser.add_argument(
104
+ "--mode",
105
+ type=str,
106
+ default="regression", # "generation"
107
+ help="Whether to use the generation or regression pipeline."
108
+ )
109
+ parser.add_argument(
110
+ "--task_name",
111
+ type=str,
112
+ default="depth", # "normal"
113
+ )
114
+ parser.add_argument(
115
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
116
+ )
117
+
118
+ # inference settings
119
+ parser.add_argument("--seed", type=int, default=None, help="Random seed.")
120
+ parser.add_argument(
121
+ "--output_dir", type=str, required=True, help="Output directory."
122
+ )
123
+ parser.add_argument(
124
+ "--input_dir", type=str, required=True, help="Input directory."
125
+ )
126
+ parser.add_argument(
127
+ "--half_precision",
128
+ action="store_true",
129
+ help="Run with half-precision (16-bit float), might lead to suboptimal result.",
130
+ )
131
+
132
+ args = parser.parse_args()
133
+
134
+ return args
135
+
136
+ def main():
137
+ logging.basicConfig(level=logging.INFO)
138
+ logging.info(f"Run inference...")
139
+
140
+ args = parse_args()
141
+
142
+ # -------------------- Preparation --------------------
143
+ # Random seed
144
+ if args.seed is not None:
145
+ seed_all(args.seed)
146
+
147
+ # Output directories
148
+ os.makedirs(args.output_dir, exist_ok=True)
149
+ logging.info(f"Output dir = {args.output_dir}")
150
+
151
+ output_dir_color = os.path.join(args.output_dir, f'{args.task_name}_vis')
152
+ output_dir_npy = os.path.join(args.output_dir, f'{args.task_name}')
153
+ if not os.path.exists(output_dir_color): os.makedirs(output_dir_color)
154
+ if not os.path.exists(output_dir_npy): os.makedirs(output_dir_npy)
155
+
156
+ # half_precision
157
+ if args.half_precision:
158
+ dtype = torch.float16
159
+ logging.info(f"Running with half precision ({dtype}).")
160
+ else:
161
+ dtype = torch.float32
162
+
163
+ # -------------------- Device --------------------
164
+ if torch.cuda.is_available():
165
+ device = torch.device("cuda")
166
+ else:
167
+ device = torch.device("cpu")
168
+ logging.warning("CUDA is not available. Running on CPU will be slow.")
169
+ logging.info(f"Device = {device}")
170
+
171
+ # -------------------- Data --------------------
172
+ root_dir = Path(args.input_dir)
173
+ test_images = list(root_dir.rglob('*.png')) + list(root_dir.rglob('*.jpg'))
174
+ test_images = sorted(test_images)
175
+ print('==> There are', len(test_images), 'images for validation.')
176
+ # -------------------- Model --------------------
177
+
178
+ if args.mode == 'generation':
179
+ pipeline = LotusGPipeline.from_pretrained(
180
+ args.pretrained_model_name_or_path,
181
+ torch_dtype=dtype,
182
+ )
183
+ elif args.mode == 'regression':
184
+ pipeline = LotusDPipeline.from_pretrained(
185
+ args.pretrained_model_name_or_path,
186
+ torch_dtype=dtype,
187
+ )
188
+ else:
189
+ raise ValueError(f'Invalid mode: {args.mode}')
190
+ logging.info(f"Successfully loading pipeline from {args.pretrained_model_name_or_path}.")
191
+
192
+ pipeline = pipeline.to(device)
193
+ pipeline.set_progress_bar_config(disable=True)
194
+
195
+ if args.enable_xformers_memory_efficient_attention:
196
+ pipeline.enable_xformers_memory_efficient_attention()
197
+
198
+
199
+ if args.seed is None:
200
+ generator = None
201
+ else:
202
+ generator = torch.Generator(device=device).manual_seed(args.seed)
203
+
204
+ # -------------------- Inference and saving --------------------
205
+ with torch.no_grad():
206
+ for i in tqdm(range(len(test_images))):
207
+ # Preprocess validation image
208
+ test_image = Image.open(test_images[i]).convert('RGB')
209
+ test_image = np.array(test_image).astype(np.float32)
210
+ test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
211
+ test_image = test_image / 127.5 - 1.0
212
+ test_image = test_image.to(device)
213
+
214
+ task_emb = torch.tensor([1, 0]).float().unsqueeze(0).repeat(1, 1).to(device)
215
+ task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
216
+
217
+ # Run
218
+ pred = pipeline(
219
+ rgb_in=test_image,
220
+ prompt='',
221
+ num_inference_steps=1,
222
+ generator=generator,
223
+ # guidance_scale=0,
224
+ output_type='np',
225
+ timesteps=[args.timestep],
226
+ task_emb=task_emb,
227
+ ).images[0]
228
+
229
+ # Post-process the prediction
230
+ save_file_name = os.path.basename(test_images[i])[:-4]
231
+ if args.task_name == 'depth':
232
+ output_npy = pred.mean(axis=-1)
233
+ output_color = colorize_depth_map(output_npy)
234
+ else:
235
+ output_npy = pred
236
+ output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
237
+
238
+ output_color.save(os.path.join(output_dir_color, f'{save_file_name}.png'))
239
+ np.save(os.path.join(output_dir_npy, f'{save_file_name}.npy'), output_npy)
240
+
241
+ print('==> Inference is done. \n==> Results saved to:', args.output_dir)
242
+
243
+ if __name__ == '__main__':
244
+ main()
pipeline.py ADDED
@@ -0,0 +1,1285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import inspect
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from packaging import version
8
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
9
+ import tensorboard
10
+
11
+ from diffusers.configuration_utils import FrozenDict
12
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
13
+ from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
14
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
15
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
16
+ from diffusers.schedulers import KarrasDiffusionSchedulers
17
+ from diffusers.utils import (
18
+ USE_PEFT_BACKEND,
19
+ deprecate,
20
+ logging,
21
+ replace_example_docstring,
22
+ scale_lora_layers,
23
+ unscale_lora_layers,
24
+ )
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
27
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
28
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
29
+ from diffusers import StableDiffusionPipeline
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
35
+ """
36
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
37
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
38
+ """
39
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
40
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
41
+ # rescale the results from guidance (fixes overexposure)
42
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
43
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
44
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
45
+ return noise_cfg
46
+
47
+
48
+ def retrieve_timesteps(
49
+ scheduler,
50
+ num_inference_steps: Optional[int] = None,
51
+ device: Optional[Union[str, torch.device]] = None,
52
+ timesteps: Optional[List[int]] = None,
53
+ **kwargs,
54
+ ):
55
+ """
56
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
57
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
58
+
59
+ Args:
60
+ scheduler (`SchedulerMixin`):
61
+ The scheduler to get timesteps from.
62
+ num_inference_steps (`int`):
63
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
64
+ must be `None`.
65
+ device (`str` or `torch.device`, *optional*):
66
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
67
+ timesteps (`List[int]`, *optional*):
68
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
69
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
70
+ must be `None`.
71
+
72
+ Returns:
73
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
74
+ second element is the number of inference steps.
75
+ """
76
+ if timesteps is not None:
77
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
78
+ if not accepts_timesteps:
79
+ raise ValueError(
80
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
81
+ f" timestep schedules. Please check whether you are using the correct scheduler."
82
+ )
83
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
84
+ timesteps = scheduler.timesteps
85
+ num_inference_steps = len(timesteps)
86
+ else:
87
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
88
+ timesteps = scheduler.timesteps
89
+ return timesteps, num_inference_steps
90
+
91
+
92
+ class DirectDiffusionPipeline(
93
+ DiffusionPipeline,
94
+ StableDiffusionMixin,
95
+ TextualInversionLoaderMixin,
96
+ LoraLoaderMixin,
97
+ IPAdapterMixin,
98
+ FromSingleFileMixin,
99
+ ):
100
+ r"""
101
+ Pipeline for text-to-image generation using Stable Diffusion.
102
+
103
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
104
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
105
+
106
+ The pipeline also inherits the following loading methods:
107
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
108
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
109
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
110
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
111
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
112
+
113
+ Args:
114
+ vae ([`AutoencoderKL`]):
115
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
116
+ text_encoder ([`~transformers.CLIPTextModel`]):
117
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
118
+ tokenizer ([`~transformers.CLIPTokenizer`]):
119
+ A `CLIPTokenizer` to tokenize text.
120
+ unet ([`UNet2DConditionModel`]):
121
+ A `UNet2DConditionModel` to denoise the encoded image latents.
122
+ scheduler ([`SchedulerMixin`]):
123
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
124
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
125
+ safety_checker ([`StableDiffusionSafetyChecker`]):
126
+ Classification module that estimates whether generated images could be considered offensive or harmful.
127
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
128
+ about a model's potential harms.
129
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
130
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
131
+ """
132
+
133
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
134
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
135
+ _exclude_from_cpu_offload = ["safety_checker"]
136
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
137
+
138
+ def __init__(
139
+ self,
140
+ vae: AutoencoderKL,
141
+ text_encoder: CLIPTextModel,
142
+ tokenizer: CLIPTokenizer,
143
+ unet: UNet2DConditionModel,
144
+ scheduler: KarrasDiffusionSchedulers,
145
+ safety_checker: StableDiffusionSafetyChecker,
146
+ feature_extractor: CLIPImageProcessor,
147
+ image_encoder: CLIPVisionModelWithProjection = None,
148
+ requires_safety_checker: bool = True,
149
+ ):
150
+ super().__init__()
151
+
152
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
153
+ deprecation_message = (
154
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
155
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
156
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
157
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
158
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
159
+ " file"
160
+ )
161
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
162
+ new_config = dict(scheduler.config)
163
+ new_config["steps_offset"] = 1
164
+ scheduler._internal_dict = FrozenDict(new_config)
165
+
166
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
167
+ deprecation_message = (
168
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
169
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
170
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
171
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
172
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
173
+ )
174
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
175
+ new_config = dict(scheduler.config)
176
+ new_config["clip_sample"] = False
177
+ scheduler._internal_dict = FrozenDict(new_config)
178
+
179
+ if safety_checker is None and requires_safety_checker:
180
+ logger.warning(
181
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
182
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
183
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
184
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
185
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
186
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
187
+ )
188
+
189
+ if safety_checker is not None and feature_extractor is None:
190
+ raise ValueError(
191
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
192
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
193
+ )
194
+
195
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
196
+ version.parse(unet.config._diffusers_version).base_version
197
+ ) < version.parse("0.9.0.dev0")
198
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
199
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
200
+ deprecation_message = (
201
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
202
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
203
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
204
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
205
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
206
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
207
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
208
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
209
+ " the `unet/config.json` file"
210
+ )
211
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
212
+ new_config = dict(unet.config)
213
+ new_config["sample_size"] = 64
214
+ unet._internal_dict = FrozenDict(new_config)
215
+
216
+ self.register_modules(
217
+ vae=vae,
218
+ text_encoder=text_encoder,
219
+ tokenizer=tokenizer,
220
+ unet=unet,
221
+ scheduler=scheduler,
222
+ safety_checker=safety_checker,
223
+ feature_extractor=feature_extractor,
224
+ image_encoder=image_encoder,
225
+ )
226
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
227
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
228
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
229
+
230
+ def _encode_prompt(
231
+ self,
232
+ prompt,
233
+ device,
234
+ num_images_per_prompt,
235
+ do_classifier_free_guidance,
236
+ negative_prompt=None,
237
+ prompt_embeds: Optional[torch.FloatTensor] = None,
238
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
239
+ lora_scale: Optional[float] = None,
240
+ **kwargs,
241
+ ):
242
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
243
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
244
+
245
+ prompt_embeds_tuple = self.encode_prompt(
246
+ prompt=prompt,
247
+ device=device,
248
+ num_images_per_prompt=num_images_per_prompt,
249
+ do_classifier_free_guidance=do_classifier_free_guidance,
250
+ negative_prompt=negative_prompt,
251
+ prompt_embeds=prompt_embeds,
252
+ negative_prompt_embeds=negative_prompt_embeds,
253
+ lora_scale=lora_scale,
254
+ **kwargs,
255
+ )
256
+
257
+ # concatenate for backwards comp
258
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
259
+
260
+ return prompt_embeds
261
+
262
+ def encode_prompt(
263
+ self,
264
+ prompt,
265
+ device,
266
+ num_images_per_prompt,
267
+ do_classifier_free_guidance,
268
+ negative_prompt=None,
269
+ padding_type="do_not_pad",
270
+ prompt_embeds: Optional[torch.FloatTensor] = None,
271
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
272
+ lora_scale: Optional[float] = None,
273
+ clip_skip: Optional[int] = None,
274
+ ):
275
+ r"""
276
+ Encodes the prompt into text encoder hidden states.
277
+
278
+ Args:
279
+ prompt (`str` or `List[str]`, *optional*):
280
+ prompt to be encoded
281
+ device: (`torch.device`):
282
+ torch device
283
+ num_images_per_prompt (`int`):
284
+ number of images that should be generated per prompt
285
+ do_classifier_free_guidance (`bool`):
286
+ whether to use classifier free guidance or not
287
+ negative_prompt (`str` or `List[str]`, *optional*):
288
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
289
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
290
+ less than `1`).
291
+ prompt_embeds (`torch.FloatTensor`, *optional*):
292
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
293
+ provided, text embeddings will be generated from `prompt` input argument.
294
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
295
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
296
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
297
+ argument.
298
+ lora_scale (`float`, *optional*):
299
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
300
+ clip_skip (`int`, *optional*):
301
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
302
+ the output of the pre-final layer will be used for computing the prompt embeddings.
303
+ """
304
+ # set lora scale so that monkey patched LoRA
305
+ # function of text encoder can correctly access it
306
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
307
+ self._lora_scale = lora_scale
308
+
309
+ # dynamically adjust the LoRA scale
310
+ if not USE_PEFT_BACKEND:
311
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
312
+ else:
313
+ scale_lora_layers(self.text_encoder, lora_scale)
314
+
315
+ if prompt is not None and isinstance(prompt, str):
316
+ batch_size = 1
317
+ elif prompt is not None and isinstance(prompt, list):
318
+ batch_size = len(prompt)
319
+ else:
320
+ batch_size = prompt_embeds.shape[0]
321
+
322
+ if prompt_embeds is None:
323
+ # textual inversion: process multi-vector tokens if necessary
324
+ if isinstance(self, TextualInversionLoaderMixin):
325
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
326
+
327
+ text_inputs = self.tokenizer(
328
+ prompt,
329
+ padding=padding_type,
330
+ max_length=self.tokenizer.model_max_length,
331
+ truncation=True,
332
+ return_tensors="pt",
333
+ )
334
+ text_input_ids = text_inputs.input_ids
335
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
336
+
337
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
338
+ text_input_ids, untruncated_ids
339
+ ):
340
+ removed_text = self.tokenizer.batch_decode(
341
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
342
+ )
343
+ logger.warning(
344
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
345
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
346
+ )
347
+
348
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
349
+ attention_mask = text_inputs.attention_mask.to(device)
350
+ else:
351
+ attention_mask = None
352
+
353
+ if clip_skip is None:
354
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
355
+ prompt_embeds = prompt_embeds[0]
356
+ else:
357
+ prompt_embeds = self.text_encoder(
358
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
359
+ )
360
+ # Access the `hidden_states` first, that contains a tuple of
361
+ # all the hidden states from the encoder layers. Then index into
362
+ # the tuple to access the hidden states from the desired layer.
363
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
364
+ # We also need to apply the final LayerNorm here to not mess with the
365
+ # representations. The `last_hidden_states` that we typically use for
366
+ # obtaining the final prompt representations passes through the LayerNorm
367
+ # layer.
368
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
369
+
370
+ if self.text_encoder is not None:
371
+ prompt_embeds_dtype = self.text_encoder.dtype
372
+ elif self.unet is not None:
373
+ prompt_embeds_dtype = self.unet.dtype
374
+ else:
375
+ prompt_embeds_dtype = prompt_embeds.dtype
376
+
377
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
378
+
379
+ bs_embed, seq_len, _ = prompt_embeds.shape
380
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
381
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
382
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
383
+
384
+ # get unconditional embeddings for classifier free guidance
385
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
386
+ uncond_tokens: List[str]
387
+ if negative_prompt is None:
388
+ uncond_tokens = [""] * batch_size
389
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
390
+ raise TypeError(
391
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
392
+ f" {type(prompt)}."
393
+ )
394
+ elif isinstance(negative_prompt, str):
395
+ uncond_tokens = [negative_prompt]
396
+ elif batch_size != len(negative_prompt):
397
+ raise ValueError(
398
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
399
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
400
+ " the batch size of `prompt`."
401
+ )
402
+ else:
403
+ uncond_tokens = negative_prompt
404
+
405
+ # textual inversion: process multi-vector tokens if necessary
406
+ if isinstance(self, TextualInversionLoaderMixin):
407
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
408
+
409
+ max_length = prompt_embeds.shape[1]
410
+ uncond_input = self.tokenizer(
411
+ uncond_tokens,
412
+ padding="max_length",
413
+ max_length=max_length,
414
+ truncation=True,
415
+ return_tensors="pt",
416
+ )
417
+
418
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
419
+ attention_mask = uncond_input.attention_mask.to(device)
420
+ else:
421
+ attention_mask = None
422
+
423
+ negative_prompt_embeds = self.text_encoder(
424
+ uncond_input.input_ids.to(device),
425
+ attention_mask=attention_mask,
426
+ )
427
+ negative_prompt_embeds = negative_prompt_embeds[0]
428
+
429
+ if do_classifier_free_guidance:
430
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
431
+ seq_len = negative_prompt_embeds.shape[1]
432
+
433
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
434
+
435
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
436
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
437
+
438
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
439
+ # Retrieve the original scale by scaling back the LoRA layers
440
+ unscale_lora_layers(self.text_encoder, lora_scale)
441
+
442
+ return prompt_embeds, negative_prompt_embeds
443
+
444
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
445
+ dtype = next(self.image_encoder.parameters()).dtype
446
+
447
+ if not isinstance(image, torch.Tensor):
448
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
449
+
450
+ image = image.to(device=device, dtype=dtype)
451
+ if output_hidden_states:
452
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
453
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
454
+ uncond_image_enc_hidden_states = self.image_encoder(
455
+ torch.zeros_like(image), output_hidden_states=True
456
+ ).hidden_states[-2]
457
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
458
+ num_images_per_prompt, dim=0
459
+ )
460
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
461
+ else:
462
+ image_embeds = self.image_encoder(image).image_embeds
463
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
464
+ uncond_image_embeds = torch.zeros_like(image_embeds)
465
+
466
+ return image_embeds, uncond_image_embeds
467
+
468
+ def prepare_ip_adapter_image_embeds(
469
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
470
+ ):
471
+ if ip_adapter_image_embeds is None:
472
+ if not isinstance(ip_adapter_image, list):
473
+ ip_adapter_image = [ip_adapter_image]
474
+
475
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
476
+ raise ValueError(
477
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
478
+ )
479
+
480
+ image_embeds = []
481
+ for single_ip_adapter_image, image_proj_layer in zip(
482
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
483
+ ):
484
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
485
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
486
+ single_ip_adapter_image, device, 1, output_hidden_state
487
+ )
488
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
489
+ single_negative_image_embeds = torch.stack(
490
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
491
+ )
492
+
493
+ if do_classifier_free_guidance:
494
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
495
+ single_image_embeds = single_image_embeds.to(device)
496
+
497
+ image_embeds.append(single_image_embeds)
498
+ else:
499
+ repeat_dims = [1]
500
+ image_embeds = []
501
+ for single_image_embeds in ip_adapter_image_embeds:
502
+ if do_classifier_free_guidance:
503
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
504
+ single_image_embeds = single_image_embeds.repeat(
505
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
506
+ )
507
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
508
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
509
+ )
510
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
511
+ else:
512
+ single_image_embeds = single_image_embeds.repeat(
513
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
514
+ )
515
+ image_embeds.append(single_image_embeds)
516
+
517
+ return image_embeds
518
+
519
+ def run_safety_checker(self, image, device, dtype):
520
+ if self.safety_checker is None:
521
+ has_nsfw_concept = None
522
+ else:
523
+ if torch.is_tensor(image):
524
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
525
+ else:
526
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
527
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
528
+ image, has_nsfw_concept = self.safety_checker(
529
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
530
+ )
531
+ return image, has_nsfw_concept
532
+
533
+ def decode_latents(self, latents):
534
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
535
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
536
+
537
+ latents = 1 / self.vae.config.scaling_factor * latents
538
+ image = self.vae.decode(latents, return_dict=False)[0]
539
+ image = (image / 2 + 0.5).clamp(0, 1)
540
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
541
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
542
+ return image
543
+
544
+ def prepare_extra_step_kwargs(self, generator, eta):
545
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
546
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
547
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
548
+ # and should be between [0, 1]
549
+
550
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
551
+ extra_step_kwargs = {}
552
+ if accepts_eta:
553
+ extra_step_kwargs["eta"] = eta
554
+
555
+ # check if the scheduler accepts generator
556
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
557
+ if accepts_generator:
558
+ extra_step_kwargs["generator"] = generator
559
+ return extra_step_kwargs
560
+
561
+ def check_inputs(
562
+ self,
563
+ prompt,
564
+ height,
565
+ width,
566
+ callback_steps,
567
+ negative_prompt=None,
568
+ prompt_embeds=None,
569
+ negative_prompt_embeds=None,
570
+ ip_adapter_image=None,
571
+ ip_adapter_image_embeds=None,
572
+ callback_on_step_end_tensor_inputs=None,
573
+ ):
574
+ if height % 8 != 0 or width % 8 != 0:
575
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
576
+
577
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
578
+ raise ValueError(
579
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
580
+ f" {type(callback_steps)}."
581
+ )
582
+ if callback_on_step_end_tensor_inputs is not None and not all(
583
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
584
+ ):
585
+ raise ValueError(
586
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
587
+ )
588
+
589
+ if prompt is not None and prompt_embeds is not None:
590
+ raise ValueError(
591
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
592
+ " only forward one of the two."
593
+ )
594
+ elif prompt is None and prompt_embeds is None:
595
+ raise ValueError(
596
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
597
+ )
598
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
599
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
600
+
601
+ if negative_prompt is not None and negative_prompt_embeds is not None:
602
+ raise ValueError(
603
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
604
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
605
+ )
606
+
607
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
608
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
609
+ raise ValueError(
610
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
611
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
612
+ f" {negative_prompt_embeds.shape}."
613
+ )
614
+
615
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
616
+ raise ValueError(
617
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
618
+ )
619
+
620
+ if ip_adapter_image_embeds is not None:
621
+ if not isinstance(ip_adapter_image_embeds, list):
622
+ raise ValueError(
623
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
624
+ )
625
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
626
+ raise ValueError(
627
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
628
+ )
629
+
630
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
631
+ shape = (
632
+ batch_size,
633
+ num_channels_latents,
634
+ int(height) // self.vae_scale_factor,
635
+ int(width) // self.vae_scale_factor,
636
+ )
637
+ if isinstance(generator, list) and len(generator) != batch_size:
638
+ raise ValueError(
639
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
640
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
641
+ )
642
+
643
+ if latents is None:
644
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
645
+ else:
646
+ latents = latents.to(device)
647
+
648
+ # scale the initial noise by the standard deviation required by the scheduler
649
+ latents = latents * self.scheduler.init_noise_sigma
650
+ return latents
651
+
652
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
653
+ def get_guidance_scale_embedding(
654
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
655
+ ) -> torch.FloatTensor:
656
+ """
657
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
658
+
659
+ Args:
660
+ w (`torch.Tensor`):
661
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
662
+ embedding_dim (`int`, *optional*, defaults to 512):
663
+ Dimension of the embeddings to generate.
664
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
665
+ Data type of the generated embeddings.
666
+
667
+ Returns:
668
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
669
+ """
670
+ assert len(w.shape) == 1
671
+ w = w * 1000.0
672
+
673
+ half_dim = embedding_dim // 2
674
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
675
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
676
+ emb = w.to(dtype)[:, None] * emb[None, :]
677
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
678
+ if embedding_dim % 2 == 1: # zero pad
679
+ emb = torch.nn.functional.pad(emb, (0, 1))
680
+ assert emb.shape == (w.shape[0], embedding_dim)
681
+ return emb
682
+
683
+ @property
684
+ def guidance_scale(self):
685
+ return self._guidance_scale
686
+
687
+ @property
688
+ def guidance_rescale(self):
689
+ return self._guidance_rescale
690
+
691
+ @property
692
+ def clip_skip(self):
693
+ return self._clip_skip
694
+
695
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
696
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
697
+ # corresponds to doing no classifier free guidance.
698
+ @property
699
+ def do_classifier_free_guidance(self):
700
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
701
+
702
+ @property
703
+ def cross_attention_kwargs(self):
704
+ return self._cross_attention_kwargs
705
+
706
+ @property
707
+ def num_timesteps(self):
708
+ return self._num_timesteps
709
+
710
+ @property
711
+ def interrupt(self):
712
+ return self._interrupt
713
+
714
+ @torch.no_grad()
715
+ def __call__(
716
+ self,
717
+ rgb_in: Optional[torch.FloatTensor] = None,
718
+ prompt: Union[str, List[str]] = None,
719
+ height: Optional[int] = None,
720
+ width: Optional[int] = None,
721
+ num_inference_steps: int = 50,
722
+ timesteps: List[int] = None,
723
+ guidance_scale: float = 7.5,
724
+ negative_prompt: Optional[Union[str, List[str]]] = None,
725
+ num_images_per_prompt: Optional[int] = 1,
726
+ eta: float = 0.0,
727
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
728
+ latents: Optional[torch.FloatTensor] = None,
729
+ prompt_embeds: Optional[torch.FloatTensor] = None,
730
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
731
+ ip_adapter_image: Optional[PipelineImageInput] = None,
732
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
733
+ output_type: Optional[str] = "pil",
734
+ return_dict: bool = True,
735
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
736
+ guidance_rescale: float = 0.0,
737
+ clip_skip: Optional[int] = None,
738
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
739
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
740
+ return_intermediate_timestep_idx: Optional[int] = None,
741
+ **kwargs,
742
+ ):
743
+ r"""
744
+ The call function to the pipeline for generation.
745
+
746
+ Args:
747
+
748
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
749
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
750
+ The height in pixels of the generated image.
751
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
752
+ The width in pixels of the generated image.
753
+ num_inference_steps (`int`, *optional*, defaults to 50):
754
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
755
+ expense of slower inference.
756
+ timesteps (`List[int]`, *optional*):
757
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
758
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
759
+ passed will be used. Must be in descending order.
760
+ guidance_scale (`float`, *optional*, defaults to 7.5):
761
+ A higher guidance scale value encourages the model to generate images closely linked to the text
762
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
763
+ negative_prompt (`str` or `List[str]`, *optional*):
764
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
765
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
766
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
767
+ The number of images to generate per prompt.
768
+ eta (`float`, *optional*, defaults to 0.0):
769
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
770
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
771
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
772
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
773
+ generation deterministic.
774
+ latents (`torch.FloatTensor`, *optional*):
775
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
776
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
777
+ tensor is generated by sampling using the supplied random `generator`.
778
+ prompt_embeds (`torch.FloatTensor`, *optional*):
779
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
780
+ provided, text embeddings are generated from the `prompt` input argument.
781
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
782
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
783
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
784
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
785
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
786
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
787
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
788
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
789
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
790
+ output_type (`str`, *optional*, defaults to `"pil"`):
791
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
792
+ return_dict (`bool`, *optional*, defaults to `True`):
793
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
794
+ plain tuple.
795
+ cross_attention_kwargs (`dict`, *optional*):
796
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
797
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
798
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
799
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
800
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
801
+ using zero terminal SNR.
802
+ clip_skip (`int`, *optional*):
803
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
804
+ the output of the pre-final layer will be used for computing the prompt embeddings.
805
+ callback_on_step_end (`Callable`, *optional*):
806
+ A function that calls at the end of each denoising steps during the inference. The function is called
807
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
808
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
809
+ `callback_on_step_end_tensor_inputs`.
810
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
811
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
812
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
813
+ `._callback_tensor_inputs` attribute of your pipeline class.
814
+
815
+ Examples:
816
+
817
+ Returns:
818
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
819
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
820
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
821
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
822
+ "not-safe-for-work" (nsfw) content.
823
+ """
824
+
825
+ callback = kwargs.pop("callback", None)
826
+ callback_steps = kwargs.pop("callback_steps", None)
827
+
828
+ if callback is not None:
829
+ deprecate(
830
+ "callback",
831
+ "1.0.0",
832
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
833
+ )
834
+ if callback_steps is not None:
835
+ deprecate(
836
+ "callback_steps",
837
+ "1.0.0",
838
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
839
+ )
840
+
841
+ # 0. Default height and width to unet
842
+ height, width = rgb_in.shape[2:]
843
+
844
+ # to deal with lora scaling and other possible forward hooks
845
+
846
+ # # 1. Check inputs. Raise error if not correct
847
+
848
+ self._guidance_scale = guidance_scale
849
+ self._guidance_rescale = guidance_rescale
850
+ self._clip_skip = clip_skip
851
+ self._cross_attention_kwargs = cross_attention_kwargs
852
+ self._interrupt = False
853
+
854
+ # 2. Define call parameters
855
+ batch_size = rgb_in.shape[0]
856
+
857
+ device = self._execution_device
858
+
859
+ # 3. Encode input prompt
860
+ lora_scale = (
861
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
862
+ )
863
+
864
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
865
+ prompt,
866
+ device,
867
+ num_images_per_prompt,
868
+ self.do_classifier_free_guidance,
869
+ negative_prompt,
870
+ prompt_embeds=prompt_embeds,
871
+ negative_prompt_embeds=negative_prompt_embeds,
872
+ lora_scale=lora_scale,
873
+ clip_skip=self.clip_skip,
874
+ )
875
+
876
+ # For classifier free guidance, we need to do two forward passes.
877
+ # Here we concatenate the unconditional and text embeddings into a single batch
878
+ # to avoid doing two forward passes
879
+ if self.do_classifier_free_guidance:
880
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
881
+
882
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
883
+ image_embeds = self.prepare_ip_adapter_image_embeds(
884
+ ip_adapter_image,
885
+ ip_adapter_image_embeds,
886
+ device,
887
+ batch_size * num_images_per_prompt,
888
+ self.do_classifier_free_guidance,
889
+ )
890
+
891
+ # 4. Prepare timesteps
892
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
893
+
894
+ # 5. Prepare latent variables
895
+ num_channels_latents = self.unet.config.in_channels // 2
896
+ latents = self.prepare_latents(
897
+ batch_size * num_images_per_prompt,
898
+ num_channels_latents,
899
+ height,
900
+ width,
901
+ prompt_embeds.dtype,
902
+ device,
903
+ generator,
904
+ latents,
905
+ )
906
+
907
+ rgb_latents = self.vae.encode(rgb_in.to(device)).latent_dist.sample()
908
+ rgb_latents = rgb_latents * self.vae.config.scaling_factor
909
+
910
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
911
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
912
+
913
+ # 6.1 Add image embeds for IP-Adapter
914
+ added_cond_kwargs = (
915
+ {"image_embeds": image_embeds}
916
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
917
+ else None
918
+ )
919
+
920
+ # 6.2 Optionally get Guidance Scale Embedding
921
+ timestep_cond = None
922
+ if self.unet.config.time_cond_proj_dim is not None:
923
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
924
+ timestep_cond = self.get_guidance_scale_embedding(
925
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
926
+ ).to(device=device, dtype=latents.dtype)
927
+
928
+ # 7. Denoising loop
929
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
930
+ self._num_timesteps = len(timesteps)
931
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
932
+ for i, t in enumerate(timesteps):
933
+ if self.interrupt:
934
+ continue
935
+
936
+ # expand the latents if we are doing classifier free guidance
937
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
938
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
939
+
940
+ latent_model_input = torch.cat(
941
+ [rgb_latents, latent_model_input], dim=1
942
+ )
943
+
944
+ # predict the noise residual
945
+ noise_pred = self.unet(
946
+ latent_model_input,
947
+ t,
948
+ encoder_hidden_states=prompt_embeds,
949
+ timestep_cond=timestep_cond,
950
+ cross_attention_kwargs=self.cross_attention_kwargs,
951
+ added_cond_kwargs=added_cond_kwargs,
952
+ return_dict=False,
953
+ )[0]
954
+
955
+ # perform guidance
956
+ if self.do_classifier_free_guidance:
957
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
958
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
959
+
960
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
961
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
962
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
963
+
964
+ # compute the previous noisy sample x_t -> x_t-1
965
+ pred_latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
966
+ latents = pred_latents[0]
967
+
968
+ if callback_on_step_end is not None:
969
+ callback_kwargs = {}
970
+ for k in callback_on_step_end_tensor_inputs:
971
+ callback_kwargs[k] = locals()[k]
972
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
973
+
974
+ latents = callback_outputs.pop("latents", latents)
975
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
976
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
977
+
978
+ # call the callback, if provided
979
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
980
+ progress_bar.update()
981
+ if callback is not None and i % callback_steps == 0:
982
+ step_idx = i // getattr(self.scheduler, "order", 1)
983
+ callback(step_idx, t, latents)
984
+
985
+ if return_intermediate_timestep_idx == i:
986
+ latents = pred_latents[1]
987
+ break
988
+
989
+ if not output_type == "latent":
990
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
991
+ 0
992
+ ]
993
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
994
+ else:
995
+ image = latents
996
+ has_nsfw_concept = None
997
+
998
+ if has_nsfw_concept is None:
999
+ do_denormalize = [True] * image.shape[0]
1000
+ else:
1001
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1002
+
1003
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1004
+
1005
+ # Offload all models
1006
+ self.maybe_free_model_hooks()
1007
+
1008
+ if not return_dict:
1009
+ return (image, has_nsfw_concept)
1010
+
1011
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
1012
+
1013
+ class LotusDPipeline(DirectDiffusionPipeline):
1014
+ @torch.no_grad()
1015
+ def __call__(
1016
+ self,
1017
+ rgb_in: Optional[torch.FloatTensor] = None,
1018
+ task_emb: Optional[torch.FloatTensor] = None,
1019
+ prompt: Union[str, List[str]] = None,
1020
+ timesteps: List[int] = None,
1021
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1022
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1023
+ output_type: Optional[str] = "pil",
1024
+ return_dict: bool = True,
1025
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1026
+ **kwargs,
1027
+ ):
1028
+ r"""
1029
+ The call function to the pipeline for generation.
1030
+
1031
+ Args:
1032
+ rgb_input (`torch.FloatTensor`):
1033
+ Input RGB tensor, range [-1, 1].
1034
+ task_emb (`torch.FloatTensor`)
1035
+ Task switcher for reconstruction or dense prediction (depth or normal).
1036
+ prompt (`str` or `List[str]`, *optional*):
1037
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
1038
+ timesteps (`List[int]`, *optional*):
1039
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1040
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1041
+ passed will be used. Must be in descending order.
1042
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1043
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1044
+ generation deterministic.
1045
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1046
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1047
+ provided, text embeddings are generated from the `prompt` input argument.
1048
+ output_type (`str`, *optional*, defaults to `"pil"`):
1049
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1050
+ return_dict (`bool`, *optional*, defaults to `True`):
1051
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1052
+ plain tuple.
1053
+ cross_attention_kwargs (`dict`, *optional*):
1054
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1055
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1056
+ Examples:
1057
+
1058
+ Returns:
1059
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1060
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1061
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
1062
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
1063
+ "not-safe-for-work" (nsfw) content.
1064
+ """
1065
+
1066
+
1067
+
1068
+ # 1. Define call parameters
1069
+ self._cross_attention_kwargs = cross_attention_kwargs
1070
+
1071
+ device = self._execution_device
1072
+
1073
+ # 2. Encode input prompt
1074
+ prompt_embeds, _ = self.encode_prompt(
1075
+ prompt,
1076
+ device,
1077
+ num_images_per_prompt=1,
1078
+ do_classifier_free_guidance=None,
1079
+ prompt_embeds=prompt_embeds,
1080
+ )
1081
+
1082
+ # 3. Prepare timesteps
1083
+ timesteps = torch.tensor(timesteps, device=device).long()
1084
+
1085
+ # 4. Prepare latent variables
1086
+ rgb_latents = self.vae.encode(rgb_in.to(device)).latent_dist.sample()
1087
+ rgb_latents = rgb_latents * self.vae.config.scaling_factor
1088
+
1089
+ # 5. Denoising
1090
+ t = timesteps[0]
1091
+ latent_model_input = rgb_latents
1092
+
1093
+ pred = self.unet(
1094
+ latent_model_input,
1095
+ t,
1096
+ encoder_hidden_states=prompt_embeds,
1097
+ cross_attention_kwargs=self.cross_attention_kwargs,
1098
+ return_dict=False,
1099
+ class_labels=task_emb,
1100
+ )[0]
1101
+
1102
+ if not output_type == "latent":
1103
+ image = self.vae.decode(pred / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1104
+ 0
1105
+ ]
1106
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1107
+ else:
1108
+ image = pred
1109
+ has_nsfw_concept = None
1110
+
1111
+ if has_nsfw_concept is None:
1112
+ do_denormalize = [True] * image.shape[0]
1113
+ else:
1114
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1115
+
1116
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1117
+
1118
+ # Offload all models
1119
+ self.maybe_free_model_hooks()
1120
+
1121
+ if not return_dict:
1122
+ return (image, has_nsfw_concept)
1123
+
1124
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
1125
+
1126
+ class LotusGPipeline(DirectDiffusionPipeline):
1127
+ @torch.no_grad()
1128
+ def __call__(
1129
+ self,
1130
+ rgb_in: Optional[torch.FloatTensor] = None, # Modification 240430
1131
+ task_emb: Optional[torch.FloatTensor] = None,
1132
+ prompt: Union[str, List[str]] = None,
1133
+ num_inference_steps: int = 50,
1134
+ timesteps: List[int] = None,
1135
+ eta: float = 0.0,
1136
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1137
+ latents: Optional[torch.FloatTensor] = None,
1138
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1139
+ output_type: Optional[str] = "pil",
1140
+ return_dict: bool = True,
1141
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1142
+ **kwargs,
1143
+ ):
1144
+ r"""
1145
+ The call function to the pipeline for generation.
1146
+
1147
+ Args:
1148
+ rgb_input (`torch.FloatTensor`):
1149
+ Input RGB tensor, range [-1, 1].
1150
+ task_emb (`torch.FloatTensor`)
1151
+ The task switcher to transfer the model outout domain between prediction and reconstruction.
1152
+ prompt (`str` or `List[str]`, *optional*):
1153
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
1154
+ num_inference_steps (`int`, *optional*, defaults to 50):
1155
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1156
+ expense of slower inference.
1157
+ timesteps (`List[int]`, *optional*):
1158
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1159
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1160
+ passed will be used. Must be in descending order.
1161
+ eta (`float`, *optional*, defaults to 0.0):
1162
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
1163
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1164
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1165
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1166
+ generation deterministic.
1167
+ latents (`torch.FloatTensor`, *optional*):
1168
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
1169
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1170
+ tensor is generated by sampling using the supplied random `generator`.
1171
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1172
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1173
+ provided, text embeddings are generated from the `prompt` input argument.
1174
+ output_type (`str`, *optional*, defaults to `"pil"`):
1175
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1176
+ return_dict (`bool`, *optional*, defaults to `True`):
1177
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1178
+ plain tuple.
1179
+ cross_attention_kwargs (`dict`, *optional*):
1180
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1181
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1182
+ Examples:
1183
+
1184
+ Returns:
1185
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1186
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1187
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
1188
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
1189
+ "not-safe-for-work" (nsfw) content.
1190
+ """
1191
+
1192
+ self._cross_attention_kwargs = cross_attention_kwargs
1193
+
1194
+ # 1. Default height and width to unet
1195
+ height, width = rgb_in.shape[2:]
1196
+
1197
+ # 2. Define call parameters
1198
+ batch_size = rgb_in.shape[0]
1199
+ device = self._execution_device
1200
+ print("Device: ", device)
1201
+
1202
+ # 3. Encode input prompt
1203
+ prompt_embeds, _ = self.encode_prompt(
1204
+ prompt,
1205
+ device,
1206
+ num_images_per_prompt=1,
1207
+ do_classifier_free_guidance=None,
1208
+ prompt_embeds=prompt_embeds,
1209
+ )
1210
+
1211
+ # 4. Prepare timesteps
1212
+ timesteps = torch.tensor(timesteps, device=device).long()
1213
+
1214
+ # 5. Prepare latent variables
1215
+ num_channels_latents = self.unet.config.in_channels // 2
1216
+ latents = self.prepare_latents(
1217
+ batch_size,
1218
+ num_channels_latents,
1219
+ height,
1220
+ width,
1221
+ prompt_embeds.dtype,
1222
+ device,
1223
+ generator,
1224
+ latents,
1225
+ )
1226
+
1227
+ rgb_latents = self.vae.encode(rgb_in.to(device)).latent_dist.sample()
1228
+ rgb_latents = rgb_latents * self.vae.config.scaling_factor
1229
+
1230
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1231
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1232
+
1233
+ # 7. Denoising loop
1234
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1235
+ self._num_timesteps = len(timesteps)
1236
+
1237
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1238
+ for i, t in enumerate(timesteps):
1239
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
1240
+ latent_model_input = torch.cat(
1241
+ [rgb_latents, latent_model_input], dim=1
1242
+ )
1243
+
1244
+ x0_pred = self.unet(
1245
+ latent_model_input,
1246
+ t,
1247
+ encoder_hidden_states=prompt_embeds,
1248
+ cross_attention_kwargs=self.cross_attention_kwargs,
1249
+ return_dict=False,
1250
+ class_labels=task_emb,
1251
+ )[0]
1252
+
1253
+ if len(timesteps) > 1:
1254
+ # compute the previous noisy sample x_t -> x_t-1
1255
+ latents = self.scheduler.step(x0_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1256
+ else:
1257
+ latents = x0_pred
1258
+
1259
+ # call the callback, if provided
1260
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1261
+ progress_bar.update()
1262
+
1263
+ if not output_type == "latent":
1264
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1265
+ 0
1266
+ ]
1267
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1268
+ else:
1269
+ image = latents
1270
+ has_nsfw_concept = None
1271
+
1272
+ if has_nsfw_concept is None:
1273
+ do_denormalize = [True] * image.shape[0]
1274
+ else:
1275
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1276
+
1277
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1278
+
1279
+ # Offload all models
1280
+ self.maybe_free_model_hooks()
1281
+
1282
+ if not return_dict:
1283
+ return (image, has_nsfw_concept)
1284
+
1285
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.3.1 --index-url https://download.pytorch.org/whl/cu121
2
+ torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121
3
+ torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
4
+ diffusers==0.28.0
5
+ accelerate>=0.16.0
6
+ transformers>=4.25.1
7
+ datasets==2.21.0
8
+ ftfy==6.2.3
9
+ tensorboard==2.17.1
10
+ Jinja2==3.1.3
11
+ peft==0.7.0
12
+ bitsandbytes==0.44.1
13
+ geffnet==1.0.2
14
+ opencv-python==4.10.0.82
15
+ matplotlib==3.8.4
16
+ h5py==3.11.0
17
+ omegaconf==2.3.0
18
+ tabulate==0.9.0
19
+ imageio==2.35.1
20
+ spaces==0.28.3
21
+ gradio==4.21.0
22
+ gradio-imageslider==0.0.16
23
+ gradio_client==0.12.0
utils/__pycache__/image_utils.cpython-310.pyc ADDED
Binary file (2.93 kB). View file
 
utils/__pycache__/image_utils.cpython-39.pyc ADDED
Binary file (2.92 kB). View file
 
utils/__pycache__/seed_all.cpython-310.pyc ADDED
Binary file (472 Bytes). View file
 
utils/__pycache__/seed_all.cpython-39.pyc ADDED
Binary file (476 Bytes). View file
 
utils/args.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+
4
+ def parse_args():
5
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
6
+ parser.add_argument(
7
+ "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
8
+ )
9
+ parser.add_argument(
10
+ "--pretrained_model_name_or_path",
11
+ type=str,
12
+ default=None,
13
+ required=True,
14
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
15
+ )
16
+ parser.add_argument(
17
+ "--revision",
18
+ type=str,
19
+ default=None,
20
+ required=False,
21
+ help="Revision of pretrained model identifier from huggingface.co/models.",
22
+ )
23
+ parser.add_argument(
24
+ "--variant",
25
+ type=str,
26
+ default=None,
27
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
28
+ )
29
+ parser.add_argument(
30
+ "--dataset_name",
31
+ type=str,
32
+ default=None,
33
+ help=(
34
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
35
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
36
+ " or to a folder containing files that 🤗 Datasets can understand."
37
+ ),
38
+ )
39
+ parser.add_argument(
40
+ "--dataset_config_name",
41
+ type=str,
42
+ default=None,
43
+ help="The config of the Dataset, leave as None if there's only one config.",
44
+ )
45
+ parser.add_argument(
46
+ "--train_data_dir_hypersim",
47
+ type=str,
48
+ default=None,
49
+ help=(
50
+ "A folder containing the training data. Folder contents must follow the structure described in"
51
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
52
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
53
+ ),
54
+ )
55
+ parser.add_argument(
56
+ "--train_data_dir_vkitti",
57
+ type=str,
58
+ default=None,
59
+ help=(
60
+ "A folder containing the training data. Folder contents must follow the structure described in"
61
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
62
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
63
+ ),
64
+ )
65
+ parser.add_argument(
66
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
67
+ )
68
+ parser.add_argument(
69
+ "--depth_column", type=str, default="depth", help="The column of the dataset containing a depth file."
70
+ )
71
+ parser.add_argument(
72
+ "--caption_column",
73
+ type=str,
74
+ default="text",
75
+ help="The column of the dataset containing a caption or a list of captions.",
76
+ )
77
+ parser.add_argument(
78
+ "--max_train_samples",
79
+ type=int,
80
+ default=None,
81
+ help=(
82
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
83
+ "value if set."
84
+ ),
85
+ )
86
+ parser.add_argument(
87
+ "--timestep",
88
+ type=int,
89
+ default=999,
90
+ )
91
+ parser.add_argument(
92
+ "--base_test_data_dir",
93
+ type=str,
94
+ default="datasets/eval/"
95
+ )
96
+ parser.add_argument(
97
+ "--task_name",
98
+ type=str,
99
+ default="depth", # "normal"
100
+ )
101
+ parser.add_argument(
102
+ "--validation_images",
103
+ type=str,
104
+ default=None,
105
+ help=("A set of images evaluated every `--validation_steps` and logged to `--report_to`."),
106
+ )
107
+ parser.add_argument(
108
+ "--output_dir",
109
+ type=str,
110
+ default="sd-model-finetuned",
111
+ help="The output directory where the model predictions and checkpoints will be written.",
112
+ )
113
+ parser.add_argument(
114
+ "--cache_dir",
115
+ type=str,
116
+ default=None,
117
+ help="The directory where the downloaded models and datasets will be stored.",
118
+ )
119
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
120
+ parser.add_argument(
121
+ "--resolution_hypersim",
122
+ type=int,
123
+ default=512,
124
+ help=(
125
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
126
+ " resolution"
127
+ ),
128
+ )
129
+ parser.add_argument(
130
+ "--resolution_vkitti",
131
+ type=int,
132
+ default=512,
133
+ help=(
134
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
135
+ " resolution"
136
+ ),
137
+ )
138
+ parser.add_argument(
139
+ "--prob_hypersim",
140
+ type=float,
141
+ default=0.9,
142
+ )
143
+ parser.add_argument(
144
+ "--mix_dataset",
145
+ action="store_true"
146
+ )
147
+ parser.add_argument(
148
+ "--mode",
149
+ type=str,
150
+ default="regression", # "generation"
151
+ help="Whether to use the generation or regression pipeline."
152
+ )
153
+ parser.add_argument(
154
+ "--norm_type",
155
+ type=str,
156
+ choices=['instnorm','truncnorm'],
157
+ default='truncnorm'
158
+ )
159
+ parser.add_argument(
160
+ "--center_crop",
161
+ default=False,
162
+ action="store_true",
163
+ help=(
164
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
165
+ " cropped. The images will be resized to the resolution first before cropping."
166
+ ),
167
+ )
168
+ parser.add_argument(
169
+ "--random_flip",
170
+ action="store_true",
171
+ help="whether to randomly flip images horizontally",
172
+ )
173
+ parser.add_argument(
174
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
175
+ )
176
+ parser.add_argument("--num_train_epochs", type=int, default=100)
177
+ parser.add_argument(
178
+ "--max_train_steps",
179
+ type=int,
180
+ default=None,
181
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
182
+ )
183
+ parser.add_argument(
184
+ "--gradient_accumulation_steps",
185
+ type=int,
186
+ default=1,
187
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
188
+ )
189
+ parser.add_argument(
190
+ "--gradient_checkpointing",
191
+ action="store_true",
192
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
193
+ )
194
+ parser.add_argument(
195
+ "--learning_rate",
196
+ type=float,
197
+ default=1e-4,
198
+ help="Initial learning rate (after the potential warmup period) to use.",
199
+ )
200
+ parser.add_argument(
201
+ "--scale_lr",
202
+ action="store_true",
203
+ default=False,
204
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
205
+ )
206
+ parser.add_argument(
207
+ "--lr_scheduler",
208
+ type=str,
209
+ default="constant",
210
+ help=(
211
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
212
+ ' "constant", "constant_with_warmup"]'
213
+ ),
214
+ )
215
+ parser.add_argument(
216
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
217
+ )
218
+ parser.add_argument(
219
+ "--snr_gamma",
220
+ type=float,
221
+ default=None,
222
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
223
+ "More details here: https://arxiv.org/abs/2303.09556.",
224
+ )
225
+ parser.add_argument(
226
+ "--dream_training",
227
+ action="store_true",
228
+ help=(
229
+ "Use the DREAM training method, which makes training more efficient and accurate at the ",
230
+ "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210",
231
+ ),
232
+ )
233
+ parser.add_argument(
234
+ "--dream_detail_preservation",
235
+ type=float,
236
+ default=1.0,
237
+ help="Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)",
238
+ )
239
+ parser.add_argument(
240
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
241
+ )
242
+ parser.add_argument(
243
+ "--allow_tf32",
244
+ action="store_true",
245
+ help=(
246
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
247
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
248
+ ),
249
+ )
250
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
251
+ parser.add_argument(
252
+ "--non_ema_revision",
253
+ type=str,
254
+ default=None,
255
+ required=False,
256
+ help=(
257
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
258
+ " remote repository specified with --pretrained_model_name_or_path."
259
+ ),
260
+ )
261
+ parser.add_argument(
262
+ "--dataloader_num_workers",
263
+ type=int,
264
+ default=0,
265
+ help=(
266
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
267
+ ),
268
+ )
269
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
270
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
271
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
272
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
273
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
274
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
275
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
276
+ parser.add_argument(
277
+ "--prediction_type",
278
+ type=str,
279
+ default="sample",
280
+ help="The used prediction_type. ",
281
+ )
282
+ parser.add_argument(
283
+ "--hub_model_id",
284
+ type=str,
285
+ default=None,
286
+ help="The name of the repository to keep in sync with the local `output_dir`.",
287
+ )
288
+ parser.add_argument(
289
+ "--logging_dir",
290
+ type=str,
291
+ default="logs",
292
+ help=(
293
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
294
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
295
+ ),
296
+ )
297
+ parser.add_argument(
298
+ "--mixed_precision",
299
+ type=str,
300
+ default=None,
301
+ choices=["no", "fp16", "bf16"],
302
+ help=(
303
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
304
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
305
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
306
+ ),
307
+ )
308
+ parser.add_argument(
309
+ "--report_to",
310
+ type=str,
311
+ default="tensorboard",
312
+ help=(
313
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
314
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
315
+ ),
316
+ )
317
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
318
+ parser.add_argument(
319
+ "--checkpointing_steps",
320
+ type=int,
321
+ default=500,
322
+ help=(
323
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
324
+ " training using `--resume_from_checkpoint`."
325
+ ),
326
+ )
327
+ parser.add_argument(
328
+ "--checkpoints_total_limit",
329
+ type=int,
330
+ default=None,
331
+ help=("Max number of checkpoints to store."),
332
+ )
333
+ parser.add_argument(
334
+ "--resume_from_checkpoint",
335
+ type=str,
336
+ default=None,
337
+ help=(
338
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
339
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
340
+ ),
341
+ )
342
+ parser.add_argument(
343
+ "--checkpoint_dir",
344
+ type=str,
345
+ default=None,
346
+ )
347
+ parser.add_argument(
348
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
349
+ )
350
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
351
+ parser.add_argument("--use_pretrained_sd", action="store_true")
352
+ parser.add_argument(
353
+ "--truncnorm_min",
354
+ type=float,
355
+ default=0.02,
356
+ )
357
+ parser.add_argument(
358
+ "--validation_steps",
359
+ type=int,
360
+ default=500,
361
+ help="Run validation every X steps.",
362
+ )
363
+ parser.add_argument(
364
+ "--tracker_project_name",
365
+ type=str,
366
+ default="text2image-fine-tune",
367
+ help=(
368
+ "The `project_name` argument passed to Accelerator.init_trackers for"
369
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
370
+ ),
371
+ )
372
+ parser.add_argument(
373
+ "--inference",
374
+ action="store_true"
375
+ )
376
+
377
+ args = parser.parse_args()
378
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
379
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
380
+ args.local_rank = env_local_rank
381
+
382
+ # Sanity checks
383
+ if not args.inference and args.dataset_name is None and args.train_data_dir_hypersim is None:
384
+ raise ValueError("Need either a dataset name or a training folder.")
385
+
386
+ # default to using the same revision for the non-ema model if not specified
387
+ if args.non_ema_revision is None:
388
+ args.non_ema_revision = args.revision
389
+
390
+ return args
utils/image_utils.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import matplotlib
3
+ import numpy as np
4
+
5
+ from PIL import Image
6
+
7
+ import torch
8
+ from torchvision.transforms import InterpolationMode
9
+ from torchvision.transforms.functional import resize
10
+
11
+ def concatenate_images(*image_lists):
12
+ # Ensure at least one image list is provided
13
+ if not image_lists or not image_lists[0]:
14
+ raise ValueError("At least one non-empty image list must be provided")
15
+
16
+ # Determine the maximum width of any single row and the total height
17
+ max_width = 0
18
+ total_height = 0
19
+ row_widths = []
20
+ row_heights = []
21
+
22
+ # Compute dimensions for each row
23
+ for image_list in image_lists:
24
+ if image_list: # Ensure the list is not empty
25
+ width = sum(img.width for img in image_list)
26
+ height = image_list[0].height # Assuming all images in the list have the same height
27
+ max_width = max(max_width, width)
28
+ total_height += height
29
+ row_widths.append(width)
30
+ row_heights.append(height)
31
+
32
+ # Create a new image to concatenate everything into
33
+ new_image = Image.new('RGB', (max_width, total_height))
34
+
35
+ # Concatenate each row of images
36
+ y_offset = 0
37
+ for i, image_list in enumerate(image_lists):
38
+ x_offset = 0
39
+ for img in image_list:
40
+ new_image.paste(img, (x_offset, y_offset))
41
+ x_offset += img.width
42
+ y_offset += row_heights[i] # Move the offset down to the next row
43
+
44
+ return new_image
45
+
46
+
47
+ def colorize_depth_map(depth, mask=None):
48
+ cm = matplotlib.colormaps["Spectral"]
49
+ # normalize
50
+ depth = ((depth - depth.min()) / (depth.max() - depth.min()))
51
+ # colorize
52
+ img_colored_np = cm(depth, bytes=False)[:, :, 0:3] # (h,w,3)
53
+ depth_colored = (img_colored_np * 255).astype(np.uint8)
54
+ if mask is not None:
55
+ masked_image = np.zeros_like(depth_colored)
56
+ masked_image[mask.numpy()] = depth_colored[mask.numpy()]
57
+ depth_colored_img = Image.fromarray(masked_image)
58
+ else:
59
+ depth_colored_img = Image.fromarray(depth_colored)
60
+ return depth_colored_img
61
+
62
+
63
+ def resize_max_res(
64
+ img: torch.Tensor,
65
+ max_edge_resolution: int,
66
+ resample_method: InterpolationMode = InterpolationMode.BILINEAR,
67
+ ) -> torch.Tensor:
68
+ """
69
+ Resize image to limit maximum edge length while keeping aspect ratio.
70
+
71
+ Args:
72
+ img (`torch.Tensor`):
73
+ Image tensor to be resized. Expected shape: [B, C, H, W]
74
+ max_edge_resolution (`int`):
75
+ Maximum edge length (pixel).
76
+ resample_method (`PIL.Image.Resampling`):
77
+ Resampling method used to resize images.
78
+
79
+ Returns:
80
+ `torch.Tensor`: Resized image.
81
+ """
82
+ assert 4 == img.dim(), f"Invalid input shape {img.shape}"
83
+
84
+ original_height, original_width = img.shape[-2:]
85
+ downscale_factor = min(
86
+ max_edge_resolution / original_width, max_edge_resolution / original_height
87
+ )
88
+
89
+ new_width = int(original_width * downscale_factor)
90
+ new_height = int(original_height * downscale_factor)
91
+
92
+ resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
93
+ return resized_img
94
+
95
+
96
+ def get_tv_resample_method(method_str: str) -> InterpolationMode:
97
+ resample_method_dict = {
98
+ "bilinear": InterpolationMode.BILINEAR,
99
+ "bicubic": InterpolationMode.BICUBIC,
100
+ "nearest": InterpolationMode.NEAREST_EXACT,
101
+ "nearest-exact": InterpolationMode.NEAREST_EXACT,
102
+ }
103
+ resample_method = resample_method_dict.get(method_str, None)
104
+ if resample_method is None:
105
+ raise ValueError(f"Unknown resampling method: {resample_method}")
106
+ else:
107
+ return resample_method
utils/seed_all.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
+ import numpy as np
22
+ import random
23
+ import torch
24
+
25
+
26
+ def seed_all(seed: int = 0):
27
+ """
28
+ Set random seeds of all components.
29
+ """
30
+ random.seed(seed)
31
+ np.random.seed(seed)
32
+ torch.manual_seed(seed)
33
+ torch.cuda.manual_seed_all(seed)
utils/visualize.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ import torch
5
+
6
+ from matplotlib import cm
7
+ import matplotlib.pyplot as plt
8
+
9
+ import logging
10
+ logger = logging.getLogger('root')
11
+
12
+
13
+
14
+ def tensor_to_numpy(tensor_in):
15
+ """ torch tensor to numpy array
16
+ """
17
+ if tensor_in is not None:
18
+ if tensor_in.ndim == 3:
19
+ # (C, H, W) -> (H, W, C)
20
+ tensor_in = tensor_in.detach().cpu().permute(1, 2, 0).numpy()
21
+ elif tensor_in.ndim == 4:
22
+ # (B, C, H, W) -> (B, H, W, C)
23
+ tensor_in = tensor_in.detach().cpu().permute(0, 2, 3, 1).numpy()
24
+ else:
25
+ raise Exception('invalid tensor size')
26
+ return tensor_in
27
+
28
+ # def unnormalize(img_in, img_stats={'mean': [0.485, 0.456, 0.406],
29
+ # 'std': [0.229, 0.224, 0.225]}):
30
+ def unnormalize(img_in, img_stats={'mean': [0.5,0.5,0.5], 'std': [0.5,0.5,0.5]}):
31
+ """ unnormalize input image
32
+ """
33
+ if torch.is_tensor(img_in):
34
+ img_in = tensor_to_numpy(img_in)
35
+
36
+ img_out = np.zeros_like(img_in)
37
+ for ich in range(3):
38
+ img_out[..., ich] = img_in[..., ich] * img_stats['std'][ich]
39
+ img_out[..., ich] += img_stats['mean'][ich]
40
+ img_out = (img_out * 255.0).astype(np.uint8)
41
+ return img_out
42
+
43
+ def normal_to_rgb(normal, normal_mask=None):
44
+ """ surface normal map to RGB
45
+ (used for visualization)
46
+
47
+ NOTE: x, y, z are mapped to R, G, B
48
+ NOTE: [-1, 1] are mapped to [0, 255]
49
+ """
50
+ if torch.is_tensor(normal):
51
+ normal = tensor_to_numpy(normal)
52
+ normal_mask = tensor_to_numpy(normal_mask)
53
+
54
+ normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True)
55
+ normal_norm[normal_norm < 1e-12] = 1e-12
56
+ normal = normal / normal_norm
57
+
58
+ normal_rgb = (((normal + 1) * 0.5) * 255).astype(np.uint8)
59
+ if normal_mask is not None:
60
+ normal_rgb = normal_rgb * normal_mask # (B, H, W, 3)
61
+ return normal_rgb
62
+
63
+ def kappa_to_alpha(pred_kappa, to_numpy=True):
64
+ """ Confidence kappa to uncertainty alpha
65
+ Assuming AngMF distribution (introduced in https://arxiv.org/abs/2109.09881)
66
+ """
67
+ if torch.is_tensor(pred_kappa) and to_numpy:
68
+ pred_kappa = tensor_to_numpy(pred_kappa)
69
+
70
+ if torch.is_tensor(pred_kappa):
71
+ alpha = ((2 * pred_kappa) / ((pred_kappa ** 2.0) + 1)) \
72
+ + ((torch.exp(- pred_kappa * np.pi) * np.pi) / (1 + torch.exp(- pred_kappa * np.pi)))
73
+ alpha = torch.rad2deg(alpha)
74
+ else:
75
+ alpha = ((2 * pred_kappa) / ((pred_kappa ** 2.0) + 1)) \
76
+ + ((np.exp(- pred_kappa * np.pi) * np.pi) / (1 + np.exp(- pred_kappa * np.pi)))
77
+ alpha = np.degrees(alpha)
78
+
79
+ return alpha
80
+
81
+
82
+ def visualize_normal(target_dir, prefixs, img, pred_norm, pred_kappa,
83
+ gt_norm, gt_norm_mask, pred_error, num_vis=-1):
84
+ """ visualize normal
85
+ """
86
+ error_max = 60.0
87
+
88
+ img = tensor_to_numpy(img) # (B, H, W, 3)
89
+ pred_norm = tensor_to_numpy(pred_norm) # (B, H, W, 3)
90
+ pred_kappa = tensor_to_numpy(pred_kappa) # (B, H, W, 1)
91
+ gt_norm = tensor_to_numpy(gt_norm) # (B, H, W, 3)
92
+ gt_norm_mask = tensor_to_numpy(gt_norm_mask) # (B, H, W, 1)
93
+ pred_error = tensor_to_numpy(pred_error) # (B, H, W, 1)
94
+
95
+ num_vis = len(prefixs) if num_vis == -1 else num_vis
96
+ for i in range(num_vis):
97
+ # img
98
+ img_ = unnormalize(img[i, ...])
99
+ target_path = '%s/%s_img.png' % (target_dir, prefixs[i])
100
+ plt.imsave(target_path, img_)
101
+
102
+ # pred_norm
103
+ target_path = '%s/%s_norm.png' % (target_dir, prefixs[i])
104
+ plt.imsave(target_path, normal_to_rgb(pred_norm[i, ...]))
105
+
106
+ # pred_kappa
107
+ if pred_kappa is not None:
108
+ pred_alpha = kappa_to_alpha(pred_kappa[i, :, :, 0])
109
+ target_path = '%s/%s_pred_alpha.png' % (target_dir, prefixs[i])
110
+ plt.imsave(target_path, pred_alpha, vmin=0.0, vmax=error_max, cmap='jet')
111
+
112
+ # gt_norm, pred_error
113
+ if gt_norm is not None:
114
+ target_path = '%s/%s_gt.png' % (target_dir, prefixs[i])
115
+ plt.imsave(target_path, normal_to_rgb(gt_norm[i, ...], gt_norm_mask[i, ...]))
116
+
117
+ E = pred_error[i, :, :, 0] * gt_norm_mask[i, :, :, 0]
118
+ target_path = '%s/%s_pred_error.png' % (target_dir, prefixs[i])
119
+ plt.imsave(target_path, E, vmin=0, vmax=error_max, cmap='jet')