Upload Files
- example_images/kitten.jpg +0 -0
- example_images/lion.jpeg +0 -0
- example_images/monkey.jpeg +0 -0
- gradio_app.py +365 -0
- main.py +85 -0
- requirements.txt +7 -0
- src/config.py +64 -0
- src/ddpm_scheduler.py +219 -0
- src/enums_utils.py +190 -0
- src/euler_scheduler.py +588 -0
- src/eunms.py +26 -0
- src/images_utils.py +74 -0
- src/inversion_utils.py +86 -0
- src/lcm_scheduler.py +196 -0
- src/lpips.py +147 -0
- src/metric_util.py +61 -0
- src/sd_inversion_pipeline.py +634 -0
- src/sdxl_inversion_pipeline.py +430 -0
- style.css +4 -0
example_images/kitten.jpg
ADDED
example_images/lion.jpeg
ADDED
example_images/monkey.jpeg
ADDED
gradio_app.py
ADDED
@@ -0,0 +1,365 @@
from __future__ import annotations

import gradio as gr
from PIL import Image
import torch

from src.eunms import Model_Type, Scheduler_Type, Gradient_Averaging_Type, Epsilon_Update_Type
from src.enums_utils import model_type_to_size, get_pipes
from src.config import RunConfig
from main import run as run_model


DESCRIPTION = '''# ReNoise: Real Image Inversion Through Iterative Noising
This is a demo for our ''ReNoise: Real Image Inversion Through Iterative Noising'' [paper](https://garibida.github.io/ReNoise-Inversion/). Code is available [here](https://github.com/garibida/ReNoise-Inversion).
Our ReNoise inversion technique can be applied to various diffusion models, including recent few-step ones such as SDXL-Turbo.
This demo performs real image editing using our ReNoise inversion. The input image is resized to 512x512, the optimal size for SDXL Turbo.
'''

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_type = Model_Type.SDXL_Turbo
scheduler_type = Scheduler_Type.EULER
image_size = model_type_to_size(Model_Type.SDXL_Turbo)
pipe_inversion, pipe_inference = get_pipes(model_type, scheduler_type, device=device)

cache_size = 10
prev_configs = [None for i in range(cache_size)]
prev_inv_latents = [None for i in range(cache_size)]
prev_images = [None for i in range(cache_size)]
prev_noises = [None for i in range(cache_size)]

def main_pipeline(
        input_image: str,
        src_prompt: str,
        tgt_prompt: str,
        edit_cfg: float,
        number_of_renoising_iterations: int,
        inersion_strength: float,
        avg_gradients: bool,
        first_step_range_start: int,
        first_step_range_end: int,
        rest_step_range_start: int,
        rest_step_range_end: int,
        lambda_ac: float,
        lambda_kl: float,
        noise_correction: bool):

    global prev_configs, prev_inv_latents, prev_images, prev_noises

    update_epsilon_type = Epsilon_Update_Type.OPTIMIZE if noise_correction else Epsilon_Update_Type.NONE
    avg_gradients_type = Gradient_Averaging_Type.ON_END if avg_gradients else Gradient_Averaging_Type.NONE

    first_step_range = (first_step_range_start, first_step_range_end)
    rest_step_range = (rest_step_range_start, rest_step_range_end)

    config = RunConfig(model_type = model_type,
                       num_inference_steps = 4,
                       num_inversion_steps = 4,
                       guidance_scale = 0.0,
                       max_num_aprox_steps_first_step = first_step_range_end + 1,
                       num_aprox_steps = number_of_renoising_iterations,
                       inversion_max_step = inersion_strength,
                       gradient_averaging_type = avg_gradients_type,
                       gradient_averaging_first_step_range = first_step_range,
                       gradient_averaging_step_range = rest_step_range,
                       scheduler_type = scheduler_type,
                       num_reg_steps = 4,
                       num_ac_rolls = 5,
                       lambda_ac = lambda_ac,
                       lambda_kl = lambda_kl,
                       update_epsilon_type = update_epsilon_type,
                       do_reconstruction = True)
    config.prompt = src_prompt

    inv_latent = None
    noise_list = None
    for i in range(cache_size):
        if prev_configs[i] is not None and prev_configs[i] == config and prev_images[i] == input_image:
            print(f"Using cache for config #{i}")
            inv_latent = prev_inv_latents[i]
            noise_list = prev_noises[i]
            prev_configs.pop(i)
            prev_inv_latents.pop(i)
            prev_images.pop(i)
            prev_noises.pop(i)
            break

    original_image = Image.open(input_image).convert("RGB").resize(image_size)

    res_image, inv_latent, noise, all_latents = run_model(original_image,
                                                          config,
                                                          latents=inv_latent,
                                                          pipe_inversion=pipe_inversion,
                                                          pipe_inference=pipe_inference,
                                                          edit_prompt=tgt_prompt,
                                                          noise=noise_list,
                                                          edit_cfg=edit_cfg)

    prev_configs.append(config)
    prev_inv_latents.append(inv_latent)
    prev_images.append(input_image)
    prev_noises.append(noise)

    if len(prev_configs) > cache_size:
        print("Popping cache")
        prev_configs.pop(0)
        prev_inv_latents.pop(0)
        prev_images.pop(0)
        prev_noises.pop(0)

    return res_image


with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    gr.HTML(
        '''<a href="https://huggingface.co/spaces/orpatashnik/local-prompt-mixing?duplicate=true">
        <img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to run privately without waiting in queue''')

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input image",
                type="filepath",
                height=image_size[0],
                width=image_size[1]
            )
            src_prompt = gr.Text(
                label='Source Prompt',
                max_lines=1,
                placeholder='A kitten is sitting in a basket on a branch',
            )
            tgt_prompt = gr.Text(
                label='Target Prompt',
                max_lines=1,
                placeholder='A plush toy kitten is sitting in a basket on a branch',
            )
            with gr.Accordion("Advanced Options", open=False):
                edit_cfg = gr.Slider(
                    label='Denoise Classifier-Free Guidance Scale',
                    minimum=1.0,
                    maximum=3.5,
                    value=1.0,
                    step=0.1
                )
                number_of_renoising_iterations = gr.Slider(
                    label='Number of ReNoise Iterations',
                    minimum=0,
                    maximum=20,
                    value=9,
                    step=1
                )
                inersion_strength = gr.Slider(
                    label='Inversion Strength',
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.25
                )
                avg_gradients = gr.Checkbox(
                    label="Perform Estimation Averaging"
                )
                first_step_range_start = gr.Slider(
                    label='First Estimation in Average (t < 250)',
                    minimum=0,
                    maximum=21,
                    value=0,
                    step=1
                )
                first_step_range_end = gr.Slider(
                    label='Last Estimation in Average (t < 250)',
                    minimum=0,
                    maximum=21,
                    value=5,
                    step=1
                )
                rest_step_range_start = gr.Slider(
                    label='First Estimation in Average (t > 250)',
                    minimum=0,
                    maximum=21,
                    value=8,
                    step=1
                )
                rest_step_range_end = gr.Slider(
                    label='Last Estimation in Average (t > 250)',
                    minimum=0,
                    maximum=21,
                    value=10,
                    step=1
                )
                num_reg_steps = 4
                num_ac_rolls = 5
                lambda_ac = gr.Slider(
                    label='Lambda AC',
                    minimum=0.0,
                    maximum=50.0,
                    value=20.0,
                    step=1.0
                )
                lambda_kl = gr.Slider(
                    label='Lambda Patch KL',
                    minimum=0.0,
                    maximum=0.4,
                    value=0.065,
                    step=0.005
                )
                noise_correction = gr.Checkbox(
                    label="Perform Noise Correction"
                )

            run_button = gr.Button('Edit')
        with gr.Column():
            # result = gr.Gallery(label='Result')
            result = gr.Image(
                label="Result",
                type="pil",
                height=image_size[0],
                width=image_size[1]
            )

    examples = [
        [
            "example_images/kitten.jpg",  # input_image
            "A kitten is sitting in a basket on a branch",  # src_prompt
            "a lego kitten is sitting in a basket on a branch",  # tgt_prompt
            1.0,  # edit_cfg
            9,  # number_of_renoising_iterations
            1.0,  # inersion_strength
            True,  # avg_gradients
            0,  # first_step_range_start
            5,  # first_step_range_end
            8,  # rest_step_range_start
            10,  # rest_step_range_end
            20.0,  # lambda_ac
            0.055,  # lambda_kl
            False  # noise_correction
        ],
        [
            "example_images/kitten.jpg",  # input_image
            "A kitten is sitting in a basket on a branch",  # src_prompt
            "a broccoli is sitting in a basket on a branch",  # tgt_prompt
            1.0,  # edit_cfg
            9,  # number_of_renoising_iterations
            1.0,  # inersion_strength
            True,  # avg_gradients
            0,  # first_step_range_start
            5,  # first_step_range_end
            8,  # rest_step_range_start
            10,  # rest_step_range_end
            20.0,  # lambda_ac
            0.055,  # lambda_kl
            False  # noise_correction
        ],
        [
            "example_images/kitten.jpg",  # input_image
            "A kitten is sitting in a basket on a branch",  # src_prompt
            "a dog is sitting in a basket on a branch",  # tgt_prompt
            1.0,  # edit_cfg
            9,  # number_of_renoising_iterations
            1.0,  # inersion_strength
            True,  # avg_gradients
            0,  # first_step_range_start
            5,  # first_step_range_end
            8,  # rest_step_range_start
            10,  # rest_step_range_end
            20.0,  # lambda_ac
            0.055,  # lambda_kl
            False  # noise_correction
        ],
        [
            "example_images/monkey.jpeg",  # input_image
            "a monkey sitting on a tree branch in the forest",  # src_prompt
            "a beaver sitting on a tree branch in the forest",  # tgt_prompt
            1.0,  # edit_cfg
            9,  # number_of_renoising_iterations
            1.0,  # inersion_strength
            True,  # avg_gradients
            0,  # first_step_range_start
            5,  # first_step_range_end
            8,  # rest_step_range_start
            10,  # rest_step_range_end
            20.0,  # lambda_ac
            0.055,  # lambda_kl
            True  # noise_correction
        ],
        [
            "example_images/monkey.jpeg",  # input_image
            "a monkey sitting on a tree branch in the forest",  # src_prompt
            "a raccoon sitting on a tree branch in the forest",  # tgt_prompt
            1.0,  # edit_cfg
            9,  # number_of_renoising_iterations
            1.0,  # inersion_strength
            True,  # avg_gradients
            0,  # first_step_range_start
            5,  # first_step_range_end
            8,  # rest_step_range_start
            10,  # rest_step_range_end
            20.0,  # lambda_ac
            0.055,  # lambda_kl
            True  # noise_correction
        ],
        [
            "example_images/lion.jpeg",  # input_image
            "a lion is sitting in the grass at sunset",  # src_prompt
            "a tiger is sitting in the grass at sunset",  # tgt_prompt
            1.0,  # edit_cfg
            9,  # number_of_renoising_iterations
            1.0,  # inersion_strength
            True,  # avg_gradients
            0,  # first_step_range_start
            5,  # first_step_range_end
            8,  # rest_step_range_start
            10,  # rest_step_range_end
            20.0,  # lambda_ac
            0.055,  # lambda_kl
            True  # noise_correction
        ]
    ]

    gr.Examples(examples=examples,
                inputs=[
                    input_image,
                    src_prompt,
                    tgt_prompt,
                    edit_cfg,
                    number_of_renoising_iterations,
                    inersion_strength,
                    avg_gradients,
                    first_step_range_start,
                    first_step_range_end,
                    rest_step_range_start,
                    rest_step_range_end,
                    lambda_ac,
                    lambda_kl,
                    noise_correction
                ],
                outputs=[
                    result
                ],
                fn=main_pipeline,
                cache_examples=True)


    inputs = [
        input_image,
        src_prompt,
        tgt_prompt,
        edit_cfg,
        number_of_renoising_iterations,
        inersion_strength,
        avg_gradients,
        first_step_range_start,
        first_step_range_end,
        rest_step_range_start,
        rest_step_range_end,
        lambda_ac,
        lambda_kl,
        noise_correction
    ]
    outputs = [
        result
    ]
    run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)

demo.queue(max_size=50).launch(share=True)
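
For reference, the edit produced by the first cached Gradio example can also be reproduced without the UI by calling `main_pipeline` directly; the positional arguments follow the same order as the `inputs` list above. This is a sketch only, assuming a CUDA machine with the Space's files in place (note that importing the module also executes the launch call at the bottom of the file).

    # Sketch: invoking the demo's pipeline function programmatically,
    # with the same values as the first cached example above.
    edited = main_pipeline(
        "example_images/kitten.jpg",                         # input_image
        "A kitten is sitting in a basket on a branch",       # src_prompt
        "a lego kitten is sitting in a basket on a branch",  # tgt_prompt
        1.0, 9, 1.0, True, 0, 5, 8, 10, 20.0, 0.055, False)
    edited.save("kitten_lego.png")                           # result is a PIL image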
main.py
ADDED
@@ -0,0 +1,85 @@
import pyrallis
import torch
from PIL import Image
from diffusers.utils.torch_utils import randn_tensor

from src.config import RunConfig, Scheduler_Type
from src.enums_utils import model_type_to_size

@pyrallis.wrap()
def main(cfg: RunConfig):
    run(cfg)

def inversion_callback(pipe, step, timestep, callback_kwargs):
    return callback_kwargs

def inference_callback(pipe, step, timestep, callback_kwargs):
    return callback_kwargs

def run(init_image: Image.Image, cfg: RunConfig, pipe_inversion, pipe_inference, latents = None, edit_prompt = None, edit_cfg = 1.0, noise = None):
    # pyrallis.dump(cfg, open(cfg.output_path / 'config.yaml', 'w'))

    if latents is None and cfg.scheduler_type in (Scheduler_Type.EULER, Scheduler_Type.LCM, Scheduler_Type.DDPM):
        # Draw one fixed noise tensor per inversion step so that inversion and
        # inference consume exactly the same ancestral noise.
        g_cpu = torch.Generator().manual_seed(7865)
        img_size = model_type_to_size(cfg.model_type)
        VQAE_SCALE = 8
        latents_size = (1, 4, img_size[0] // VQAE_SCALE, img_size[1] // VQAE_SCALE)
        noise = [randn_tensor(latents_size, dtype=torch.float16, device=torch.device("cuda:0"), generator=g_cpu) for i in range(cfg.num_inversion_steps)]
        pipe_inversion.scheduler.set_noise_list(noise)
        pipe_inference.scheduler.set_noise_list(noise)
        pipe_inversion.scheduler_inference.set_noise_list(noise)

    if latents is not None and cfg.scheduler_type in (Scheduler_Type.EULER, Scheduler_Type.LCM, Scheduler_Type.DDPM):
        # Reuse the noise list that was produced together with the cached latent.
        pipe_inversion.scheduler.set_noise_list(noise)
        pipe_inference.scheduler.set_noise_list(noise)
        pipe_inversion.scheduler_inference.set_noise_list(noise)


    pipe_inversion.cfg = cfg
    pipe_inference.cfg = cfg
    all_latents = None

    if latents is None:
        print("Inverting...")
        if cfg.save_gpu_mem:
            pipe_inference.to("cpu")
            pipe_inversion.to("cuda")
        res = pipe_inversion(prompt = cfg.prompt,
                             num_inversion_steps = cfg.num_inversion_steps,
                             num_inference_steps = cfg.num_inference_steps,
                             image = init_image,
                             guidance_scale = cfg.guidance_scale,
                             opt_iters = cfg.opt_iters,
                             opt_lr = cfg.opt_lr,
                             callback_on_step_end = inversion_callback,
                             strength = cfg.inversion_max_step,
                             denoising_start = 1.0 - cfg.inversion_max_step,
                             opt_loss_kl_lambda = cfg.loss_kl_lambda,
                             num_aprox_steps = cfg.num_aprox_steps)
        latents = res[0][0]
        all_latents = res[1]

    inv_latent = latents.clone()

    if cfg.do_reconstruction:
        print("Generating...")
        edit_prompt = cfg.prompt if edit_prompt is None else edit_prompt
        guidance_scale = edit_cfg
        if cfg.save_gpu_mem:
            pipe_inversion.to("cpu")
            pipe_inference.to("cuda")
        img = pipe_inference(prompt = edit_prompt,
                             num_inference_steps = cfg.num_inference_steps,
                             negative_prompt = cfg.prompt,
                             callback_on_step_end = inference_callback,
                             image = latents,
                             strength = cfg.inversion_max_step,
                             denoising_start = 1.0 - cfg.inversion_max_step,
                             guidance_scale = guidance_scale).images[0]
    else:
        img = None

    return img, inv_latent, noise, all_latents

if __name__ == "__main__":
    main()
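
The same run() entry point can be driven from a short script that builds the two pipelines itself, roughly as gradio_app.py does. A minimal sketch, assuming a CUDA device and the bundled example image:

    # Minimal sketch of using run() as a library call, mirroring gradio_app.py.
    from PIL import Image
    from src.eunms import Model_Type, Scheduler_Type
    from src.enums_utils import get_pipes, model_type_to_size
    from src.config import RunConfig
    from main import run

    pipe_inversion, pipe_inference = get_pipes(Model_Type.SDXL_Turbo, Scheduler_Type.EULER, device="cuda")
    cfg = RunConfig(model_type=Model_Type.SDXL_Turbo, scheduler_type=Scheduler_Type.EULER,
                    num_inference_steps=4, num_inversion_steps=4, guidance_scale=0.0)
    cfg.prompt = "A kitten is sitting in a basket on a branch"
    image = Image.open("example_images/kitten.jpg").convert("RGB").resize(model_type_to_size(Model_Type.SDXL_Turbo))
    edited, inv_latent, noise, _ = run(image, cfg,
                                       pipe_inversion=pipe_inversion,
                                       pipe_inference=pipe_inference,
                                       edit_prompt="a plush toy kitten is sitting in a basket on a branch")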
requirements.txt
ADDED
@@ -0,0 +1,7 @@
gradio
torch==2.2.1
torchvision==0.17.1
diffusers==0.24.0
transformers==4.32.1
pyrallis==0.3.1
accelerate==0.25.0
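
These pins reflect the environment the Space was built against; locally, the usual "pip install -r requirements.txt" on a CUDA-capable machine is expected to be sufficient.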
src/config.py
ADDED
@@ -0,0 +1,64 @@
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple

from src.eunms import Model_Type, Scheduler_Type, Gradient_Averaging_Type, Epsilon_Update_Type

@dataclass
class RunConfig:
    model_type : Model_Type = Model_Type.SDXL_Turbo

    scheduler_type : Scheduler_Type = Scheduler_Type.EULER

    prompt: str = ""

    num_inference_steps: int = 4

    num_inversion_steps: int = 100

    opt_lr: float = 0.1

    opt_iters: int = 0

    opt_none_inference_steps: bool = False

    guidance_scale: float = 0.0

    # pipe_inversion: DiffusionPipeline = None

    # pipe_inference: DiffusionPipeline = None

    save_gpu_mem: bool = False

    do_reconstruction: bool = True

    loss_kl_lambda: float = 10.0

    max_num_aprox_steps_first_step: int = 1

    num_aprox_steps: int = 10

    inversion_max_step: float = 1.0

    gradient_averaging_type: Gradient_Averaging_Type = Gradient_Averaging_Type.NONE

    gradient_averaging_first_step_range: tuple = (0, 10)

    gradient_averaging_step_range: tuple = (0, 10)

    noise_friendly_inversion: bool = False

    update_epsilon_type: Epsilon_Update_Type = Epsilon_Update_Type.NONE

    # pix2pix-zero

    lambda_ac: float = 20.0

    lambda_kl: float = 20.0

    num_reg_steps: int = 5

    num_ac_rolls: int = 5

    def __post_init__(self):
        pass
src/ddpm_scheduler.py
ADDED
@@ -0,0 +1,219 @@
from diffusers import DDPMScheduler, LCMScheduler
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
import torch
from typing import List, Optional, Tuple, Union
import numpy as np

class DDPMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None

class MyDDPMScheduler(DDPMScheduler):
    def set_noise_list(self, noise_list):
        self.noise_list = noise_list

    def step_and_update(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        next_sample: torch.FloatTensor = None,
        generator=None,
        return_dict: bool = True,
    ) -> Union[DDPMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        t = timestep

        prev_t = self.previous_timestep(t)

        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
                " `v_prediction` for the DDPMScheduler."
            )

        # 3. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise
        variance = 0
        if t > 0:
            v = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5)
            if v > 1e-9:
                self.noise_list[int(t.item() // (1000 // self.num_inference_steps))] = (next_sample - pred_prev_sample) / v
            variance_noise = self.noise_list[int(t.item() // (1000 // self.num_inference_steps))]
            variance = v * variance_noise

        pred_prev_sample = pred_prev_sample + variance

        if not return_dict:
            return (pred_prev_sample,)

        return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator=None,
        return_dict: bool = True,
    ) -> Union[DDPMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        t = timestep

        prev_t = self.previous_timestep(t)

        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
                " `v_prediction` for the DDPMScheduler."
            )

        # 3. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise
        variance = 0
        if t > 0:
            device = model_output.device
            variance_noise = self.noise_list[int(t.item() // (1000 // self.num_inference_steps))]
            if self.variance_type == "fixed_small_log":
                variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
            elif self.variance_type == "learned_range":
                variance = self._get_variance(t, predicted_variance=predicted_variance)
                variance = torch.exp(0.5 * variance) * variance_noise
            else:
                variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise

        pred_prev_sample = pred_prev_sample + variance

        if not return_dict:
            return (pred_prev_sample,)

        return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
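
Unlike the stock diffusers scheduler, this variant never draws fresh ancestral noise inside step(); it reads it from a user-supplied, timestep-indexed list. A sketch of the wiring (it mirrors what main.py does; `pipe` and `num_steps` are assumed to exist, and the latent shape corresponds to a 512x512 image):

    # Sketch: giving the scheduler a fixed, reproducible noise sequence.
    scheduler = MyDDPMScheduler.from_config(pipe.scheduler.config)
    g = torch.Generator().manual_seed(7865)
    noise_list = [randn_tensor((1, 4, 64, 64), generator=g) for _ in range(num_steps)]
    scheduler.set_noise_list(noise_list)
    pipe.scheduler = scheduler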
src/enums_utils.py
ADDED
@@ -0,0 +1,190 @@

import torch
from diffusers import DDIMScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline, AutoPipelineForImage2Image

from src.eunms import Model_Type, Scheduler_Type
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
from src.lcm_scheduler import MyLCMScheduler
from src.ddpm_scheduler import MyDDPMScheduler
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
from src.sd_inversion_pipeline import SDDDIMPipeline

def scheduler_type_to_class(scheduler_type):
    if scheduler_type == Scheduler_Type.DDIM:
        return DDIMScheduler
    elif scheduler_type == Scheduler_Type.EULER:
        return MyEulerAncestralDiscreteScheduler
    elif scheduler_type == Scheduler_Type.LCM:
        return MyLCMScheduler
    elif scheduler_type == Scheduler_Type.DDPM:
        return MyDDPMScheduler
    else:
        raise ValueError("Unknown scheduler type")

def model_type_to_class(model_type):
    if model_type == Model_Type.SDXL:
        return StableDiffusionXLImg2ImgPipeline, SDXLDDIMPipeline
    elif model_type == Model_Type.SDXL_Turbo:
        return AutoPipelineForImage2Image, SDXLDDIMPipeline
    elif model_type == Model_Type.LCM_SDXL:
        return AutoPipelineForImage2Image, SDXLDDIMPipeline
    elif model_type == Model_Type.SD15:
        return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
    elif model_type == Model_Type.SD14:
        return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
    elif model_type == Model_Type.SD21:
        return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
    elif model_type == Model_Type.SD21_Turbo:
        return StableDiffusionImg2ImgPipeline, SDDDIMPipeline
    else:
        raise ValueError("Unknown model type")

def model_type_to_model_name(model_type):
    if model_type == Model_Type.SDXL:
        return "stabilityai/stable-diffusion-xl-base-1.0"
    elif model_type == Model_Type.SDXL_Turbo:
        return "stabilityai/sdxl-turbo"
    elif model_type == Model_Type.LCM_SDXL:
        return "stabilityai/stable-diffusion-xl-base-1.0"
    elif model_type == Model_Type.SD15:
        return "runwayml/stable-diffusion-v1-5"
    elif model_type == Model_Type.SD14:
        return "CompVis/stable-diffusion-v1-4"
    elif model_type == Model_Type.SD21:
        return "stabilityai/stable-diffusion-2-1"
    elif model_type == Model_Type.SD21_Turbo:
        return "stabilityai/sd-turbo"
    else:
        raise ValueError("Unknown model type")


def model_type_to_size(model_type):
    if model_type == Model_Type.SDXL:
        return (1024, 1024)
    elif model_type == Model_Type.SDXL_Turbo:
        return (512, 512)
    elif model_type == Model_Type.LCM_SDXL:
        return (768, 768) #TODO: check
    elif model_type == Model_Type.SD15:
        return (512, 512)
    elif model_type == Model_Type.SD14:
        return (512, 512)
    elif model_type == Model_Type.SD21:
        return (512, 512)
    elif model_type == Model_Type.SD21_Turbo:
        return (512, 512)
    else:
        raise ValueError("Unknown model type")

def is_float16(model_type):
    if model_type == Model_Type.SDXL:
        return True
    elif model_type == Model_Type.SDXL_Turbo:
        return True
    elif model_type == Model_Type.LCM_SDXL:
        return True
    elif model_type == Model_Type.SD15:
        return False
    elif model_type == Model_Type.SD14:
        return False
    elif model_type == Model_Type.SD21:
        return False
    elif model_type == Model_Type.SD21_Turbo:
        return False
    else:
        raise ValueError("Unknown model type")

def is_sd(model_type):
    if model_type == Model_Type.SDXL:
        return False
    elif model_type == Model_Type.SDXL_Turbo:
        return False
    elif model_type == Model_Type.LCM_SDXL:
        return False
    elif model_type == Model_Type.SD15:
        return True
    elif model_type == Model_Type.SD14:
        return True
    elif model_type == Model_Type.SD21:
        return True
    elif model_type == Model_Type.SD21_Turbo:
        return True
    else:
        raise ValueError("Unknown model type")

def _get_pipes(model_type, device):
    model_name = model_type_to_model_name(model_type)
    pipeline_inf, pipeline_inv = model_type_to_class(model_type)

    if is_float16(model_type):
        pipe_inversion = pipeline_inv.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
            safety_checker = None
        ).to(device)

        pipe_inference = pipeline_inf.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
            safety_checker = None
        ).to(device)
    else:
        pipe_inversion = pipeline_inv.from_pretrained(
            model_name,
            use_safetensors=True,
            safety_checker = None
        ).to(device)

        pipe_inference = pipeline_inf.from_pretrained(
            model_name,
            use_safetensors=True,
            safety_checker = None
        ).to(device)

    return pipe_inversion, pipe_inference

def get_pipes(model_type, scheduler_type, device="cuda"):
    scheduler_class = scheduler_type_to_class(scheduler_type)

    pipe_inversion, pipe_inference = _get_pipes(model_type, device)

    pipe_inference.scheduler = scheduler_class.from_config(pipe_inference.scheduler.config)
    pipe_inversion.scheduler = scheduler_class.from_config(pipe_inversion.scheduler.config)
    pipe_inversion.scheduler_inference = scheduler_class.from_config(pipe_inference.scheduler.config)

    if is_sd(model_type):
        pipe_inference.scheduler.add_noise = lambda init_latents, noise, timestep: init_latents
        pipe_inversion.scheduler.add_noise = lambda init_latents, noise, timestep: init_latents
        pipe_inversion.scheduler_inference.add_noise = lambda init_latents, noise, timestep: init_latents

    if model_type == Model_Type.LCM_SDXL:
        adapter_id = "latent-consistency/lcm-lora-sdxl"
        # load and fuse lcm lora
        pipe_inversion.load_lora_weights(adapter_id)
        # pipe_inversion.fuse_lora()
        pipe_inference.load_lora_weights(adapter_id)
        # pipe_inference.fuse_lora()

    return pipe_inversion, pipe_inference
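
get_pipes is the single entry point the rest of the code relies on; switching to one of the SD-family checkpoints only changes the model enum, and the scheduler swap and dtype choice are handled here. A sketch:

    # Sketch: building inversion/inference pipelines for SD 2.1 Turbo instead of SDXL-Turbo.
    pipe_inversion, pipe_inference = get_pipes(Model_Type.SD21_Turbo, Scheduler_Type.EULER, device="cuda")
    print(model_type_to_size(Model_Type.SD21_Turbo))   # (512, 512); float32 weights per is_float16()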
src/euler_scheduler.py
ADDED
@@ -0,0 +1,588 @@
1 |
+
from diffusers import EulerAncestralDiscreteScheduler, LCMScheduler
|
2 |
+
from diffusers.utils import BaseOutput
|
3 |
+
from diffusers.utils.torch_utils import randn_tensor
|
4 |
+
import torch
|
5 |
+
from typing import List, Optional, Tuple, Union
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from src.eunms import Epsilon_Update_Type
|
9 |
+
|
10 |
+
# g_cpu = torch.Generator().manual_seed(7865)
|
11 |
+
# noise = [randn_tensor((1, 4, 64, 64), dtype=torch.float16, device=torch.device("cuda:0"), generator=g_cpu) for i in range(4)]
|
12 |
+
# for i, n in enumerate(noise):
|
13 |
+
# torch.save(n, f"noise_{i}.pt")
|
14 |
+
|
15 |
+
class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
|
16 |
+
"""
|
17 |
+
Output class for the scheduler's `step` function output.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
21 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
22 |
+
denoising loop.
|
23 |
+
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
24 |
+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
25 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
26 |
+
"""
|
27 |
+
|
28 |
+
prev_sample: torch.FloatTensor
|
29 |
+
pred_original_sample: Optional[torch.FloatTensor] = None
|
30 |
+
|
31 |
+
class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
|
32 |
+
def set_noise_list(self, noise_list):
|
33 |
+
self.noise_list = noise_list
|
34 |
+
|
35 |
+
def get_noise_to_remove(self):
|
36 |
+
sigma_from = self.sigmas[self.step_index]
|
37 |
+
sigma_to = self.sigmas[self.step_index + 1]
|
38 |
+
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
|
39 |
+
|
40 |
+
return self.noise_list[self.step_index] * sigma_up\
|
41 |
+
|
42 |
+
def scale_model_input(
|
43 |
+
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
44 |
+
) -> torch.FloatTensor:
|
45 |
+
"""
|
46 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
47 |
+
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
sample (`torch.FloatTensor`):
|
51 |
+
The input sample.
|
52 |
+
timestep (`int`, *optional*):
|
53 |
+
The current timestep in the diffusion chain.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
`torch.FloatTensor`:
|
57 |
+
A scaled input sample.
|
58 |
+
"""
|
59 |
+
|
60 |
+
self._init_step_index(timestep.view((1)))
|
61 |
+
return EulerAncestralDiscreteScheduler.scale_model_input(self, sample, timestep)
|
62 |
+
|
63 |
+
|
64 |
+
def step(
|
65 |
+
self,
|
66 |
+
model_output: torch.FloatTensor,
|
67 |
+
timestep: Union[float, torch.FloatTensor],
|
68 |
+
sample: torch.FloatTensor,
|
69 |
+
generator: Optional[torch.Generator] = None,
|
70 |
+
return_dict: bool = True,
|
71 |
+
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
|
72 |
+
"""
|
73 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
74 |
+
process from the learned model outputs (most often the predicted noise).
|
75 |
+
|
76 |
+
Args:
|
77 |
+
model_output (`torch.FloatTensor`):
|
78 |
+
The direct output from learned diffusion model.
|
79 |
+
timestep (`float`):
|
80 |
+
The current discrete timestep in the diffusion chain.
|
81 |
+
sample (`torch.FloatTensor`):
|
82 |
+
A current instance of a sample created by the diffusion process.
|
83 |
+
generator (`torch.Generator`, *optional*):
|
84 |
+
A random number generator.
|
85 |
+
return_dict (`bool`):
|
86 |
+
Whether or not to return a
|
87 |
+
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
|
91 |
+
If return_dict is `True`,
|
92 |
+
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
|
93 |
+
otherwise a tuple is returned where the first element is the sample tensor.
|
94 |
+
|
95 |
+
"""
|
96 |
+
|
97 |
+
if (
|
98 |
+
isinstance(timestep, int)
|
99 |
+
or isinstance(timestep, torch.IntTensor)
|
100 |
+
or isinstance(timestep, torch.LongTensor)
|
101 |
+
):
|
102 |
+
raise ValueError(
|
103 |
+
(
|
104 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
105 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
106 |
+
" one of the `scheduler.timesteps` as a timestep."
|
107 |
+
),
|
108 |
+
)
|
109 |
+
|
110 |
+
if not self.is_scale_input_called:
|
111 |
+
logger.warning(
|
112 |
+
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
113 |
+
"See `StableDiffusionPipeline` for a usage example."
|
114 |
+
)
|
115 |
+
|
116 |
+
self._init_step_index(timestep.view((1)))
|
117 |
+
|
118 |
+
sigma = self.sigmas[self.step_index]
|
119 |
+
|
120 |
+
# Upcast to avoid precision issues when computing prev_sample
|
121 |
+
sample = sample.to(torch.float32)
|
122 |
+
|
123 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
124 |
+
if self.config.prediction_type == "epsilon":
|
125 |
+
pred_original_sample = sample - sigma * model_output
|
126 |
+
elif self.config.prediction_type == "v_prediction":
|
127 |
+
# * c_out + input * c_skip
|
128 |
+
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
129 |
+
elif self.config.prediction_type == "sample":
|
130 |
+
raise NotImplementedError("prediction_type not implemented yet: sample")
|
131 |
+
else:
|
132 |
+
raise ValueError(
|
133 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
134 |
+
)
|
135 |
+
|
136 |
+
sigma_from = self.sigmas[self.step_index]
|
137 |
+
sigma_to = self.sigmas[self.step_index + 1]
|
138 |
+
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
|
139 |
+
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
140 |
+
|
141 |
+
# 2. Convert to an ODE derivative
|
142 |
+
# derivative = (sample - pred_original_sample) / sigma
|
143 |
+
derivative = model_output
|
144 |
+
|
145 |
+
dt = sigma_down - sigma
|
146 |
+
|
147 |
+
prev_sample = sample + derivative * dt
|
148 |
+
|
149 |
+
device = model_output.device
|
150 |
+
# noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
|
151 |
+
# prev_sample = prev_sample + noise * sigma_up
|
152 |
+
|
153 |
+
prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
|
154 |
+
|
155 |
+
# Cast sample back to model compatible dtype
|
156 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
157 |
+
|
158 |
+
# upon completion increase step index by one
|
159 |
+
self._step_index += 1
|
160 |
+
|
161 |
+
if not return_dict:
|
162 |
+
return (prev_sample,)
|
163 |
+
|
164 |
+
return EulerAncestralDiscreteSchedulerOutput(
|
165 |
+
prev_sample=prev_sample, pred_original_sample=pred_original_sample
|
166 |
+
)
|
167 |
+
|
168 |
+
def step_and_update_noise(
|
169 |
+
self,
|
170 |
+
model_output: torch.FloatTensor,
|
171 |
+
timestep: Union[float, torch.FloatTensor],
|
172 |
+
sample: torch.FloatTensor,
|
173 |
+
expected_prev_sample: torch.FloatTensor,
|
174 |
+
update_epsilon_type=Epsilon_Update_Type.OVERRIDE,
|
175 |
+
generator: Optional[torch.Generator] = None,
|
176 |
+
return_dict: bool = True,
|
177 |
+
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
|
178 |
+
"""
|
179 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
180 |
+
process from the learned model outputs (most often the predicted noise).
|
181 |
+
|
182 |
+
Args:
|
183 |
+
model_output (`torch.FloatTensor`):
|
184 |
+
The direct output from learned diffusion model.
|
185 |
+
timestep (`float`):
|
186 |
+
The current discrete timestep in the diffusion chain.
|
187 |
+
sample (`torch.FloatTensor`):
|
188 |
+
A current instance of a sample created by the diffusion process.
|
189 |
+
generator (`torch.Generator`, *optional*):
|
190 |
+
A random number generator.
|
191 |
+
return_dict (`bool`):
|
192 |
+
Whether or not to return a
|
193 |
+
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
|
197 |
+
If return_dict is `True`,
|
198 |
+
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
|
199 |
+
otherwise a tuple is returned where the first element is the sample tensor.
|
200 |
+
|
201 |
+
"""
|
202 |
+
|
203 |
+
if (
|
204 |
+
isinstance(timestep, int)
|
205 |
+
or isinstance(timestep, torch.IntTensor)
|
206 |
+
or isinstance(timestep, torch.LongTensor)
|
207 |
+
):
|
208 |
+
raise ValueError(
|
209 |
+
(
|
210 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
211 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
212 |
+
" one of the `scheduler.timesteps` as a timestep."
|
213 |
+
),
|
214 |
+
)
|
215 |
+
|
216 |
+
if not self.is_scale_input_called:
|
217 |
+
logger.warning(
|
218 |
+
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
219 |
+
"See `StableDiffusionPipeline` for a usage example."
|
220 |
+
)
|
221 |
+
|
222 |
+
self._init_step_index(timestep.view((1)))
|
223 |
+
|
224 |
+
sigma = self.sigmas[self.step_index]
|
225 |
+
|
226 |
+
# Upcast to avoid precision issues when computing prev_sample
|
227 |
+
sample = sample.to(torch.float32)
|
228 |
+
|
229 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
230 |
+
if self.config.prediction_type == "epsilon":
|
231 |
+
pred_original_sample = sample - sigma * model_output
|
232 |
+
elif self.config.prediction_type == "v_prediction":
|
233 |
+
# * c_out + input * c_skip
|
234 |
+
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
235 |
+
elif self.config.prediction_type == "sample":
|
236 |
+
raise NotImplementedError("prediction_type not implemented yet: sample")
|
237 |
+
else:
|
238 |
+
raise ValueError(
|
239 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
240 |
+
)
|
241 |
+
|
242 |
+
sigma_from = self.sigmas[self.step_index]
|
243 |
+
sigma_to = self.sigmas[self.step_index + 1]
|
244 |
+
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
|
245 |
+
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
246 |
+
|
247 |
+
# 2. Convert to an ODE derivative
|
248 |
+
# derivative = (sample - pred_original_sample) / sigma
|
249 |
+
derivative = model_output
|
250 |
+
|
251 |
+
dt = sigma_down - sigma
|
252 |
+
|
253 |
+
prev_sample = sample + derivative * dt
|
254 |
+
|
255 |
+
device = model_output.device
|
256 |
+
# noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
|
257 |
+
# prev_sample = prev_sample + noise * sigma_up
|
258 |
+
|
259 |
+
if sigma_up > 0:
|
260 |
+
req_noise = (expected_prev_sample - prev_sample) / sigma_up
|
261 |
+
if update_epsilon_type == Epsilon_Update_Type.OVERRIDE:
|
262 |
+
self.noise_list[self.step_index] = req_noise
|
263 |
+
else:
|
264 |
+
for i in range(10):
|
265 |
+
n = torch.autograd.Variable(self.noise_list[self.step_index].detach().clone(), requires_grad=True)
|
266 |
+
loss = torch.norm(n - req_noise.detach())
|
267 |
+
loss.backward()
|
268 |
+
self.noise_list[self.step_index] -= n.grad.detach() * 1.8
|
269 |
+
|
270 |
+
|
271 |
+
prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
|
272 |
+
|
273 |
+
# Cast sample back to model compatible dtype
|
274 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
275 |
+
|
276 |
+
# upon completion increase step index by one
|
277 |
+
self._step_index += 1
|
278 |
+
|
279 |
+
if not return_dict:
|
280 |
+
return (prev_sample,)
|
281 |
+
|
282 |
+
return EulerAncestralDiscreteSchedulerOutput(
|
283 |
+
prev_sample=prev_sample, pred_original_sample=pred_original_sample
|
284 |
+
)
|
    def inv_step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        self._init_step_index(timestep.view((1)))

        sigma = self.sigmas[self.step_index]

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        sigma_from = self.sigmas[self.step_index]
        sigma_to = self.sigmas[self.step_index+1]
        # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
        # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
        sigma_down = sigma_to**2 / sigma_from

        # 2. Convert to an ODE derivative
        # derivative = (sample - pred_original_sample) / sigma
        derivative = model_output

        dt = sigma_down - sigma
        # dt = sigma_down - sigma_from

        prev_sample = sample - derivative * dt

        device = model_output.device
        # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
        # prev_sample = prev_sample + noise * sigma_up

        prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )
    def get_all_sigmas(self) -> torch.FloatTensor:
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        return torch.from_numpy(sigmas)

    def add_noise_off_schedule(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.get_all_sigmas()
        sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            timesteps = timesteps.to(original_samples.device)

        step_indices = 1000 - int(timesteps.item())

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples
    # def update_noise_for_friendly_inversion(
    #     self,
    #     model_output: torch.FloatTensor,
    #     timestep: Union[float, torch.FloatTensor],
    #     z_t: torch.FloatTensor,
    #     z_tp1: torch.FloatTensor,
    #     return_dict: bool = True,
    # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
    #     if (
    #         isinstance(timestep, int)
    #         or isinstance(timestep, torch.IntTensor)
    #         or isinstance(timestep, torch.LongTensor)
    #     ):
    #         raise ValueError(
    #             (
    #                 "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
    #                 " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
    #                 " one of the `scheduler.timesteps` as a timestep."
    #             ),
    #         )

    #     if not self.is_scale_input_called:
    #         logger.warning(
    #             "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
    #             "See `StableDiffusionPipeline` for a usage example."
    #         )

    #     self._init_step_index(timestep.view((1)))

    #     sigma = self.sigmas[self.step_index]

    #     sigma_from = self.sigmas[self.step_index]
    #     sigma_to = self.sigmas[self.step_index+1]
    #     # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    #     sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
    #     # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    #     sigma_down = sigma_to**2 / sigma_from

    #     # 2. Conv = (sample - pred_original_sample) / sigma
    #     derivative = model_output

    #     dt = sigma_down - sigma
    #     # dt = sigma_down - sigma_from

    #     prev_sample = z_t - derivative * dt

    #     if sigma_up > 0:
    #         self.noise_list[self.step_index] = (prev_sample - z_tp1) / sigma_up

    #     prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up


    #     if not return_dict:
    #         return (prev_sample,)

    #     return EulerAncestralDiscreteSchedulerOutput(
    #         prev_sample=prev_sample, pred_original_sample=None
    #     )


    # def step_friendly_inversion(
    #     self,
    #     model_output: torch.FloatTensor,
    #     timestep: Union[float, torch.FloatTensor],
    #     sample: torch.FloatTensor,
    #     generator: Optional[torch.Generator] = None,
    #     return_dict: bool = True,
    #     expected_next_sample: torch.FloatTensor = None,
    # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
    #     """
    #     Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    #     process from the learned model outputs (most often the predicted noise).

    #     Args:
    #         model_output (`torch.FloatTensor`):
    #             The direct output from learned diffusion model.
    #         timestep (`float`):
    #             The current discrete timestep in the diffusion chain.
    #         sample (`torch.FloatTensor`):
    #             A current instance of a sample created by the diffusion process.
    #         generator (`torch.Generator`, *optional*):
    #             A random number generator.
    #         return_dict (`bool`):
    #             Whether or not to return a
    #             [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.

    #     Returns:
    #         [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
    #             If return_dict is `True`,
    #             [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
    #             otherwise a tuple is returned where the first element is the sample tensor.

    #     """

    #     if (
    #         isinstance(timestep, int)
    #         or isinstance(timestep, torch.IntTensor)
    #         or isinstance(timestep, torch.LongTensor)
    #     ):
    #         raise ValueError(
    #             (
    #                 "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
    #                 " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
    #                 " one of the `scheduler.timesteps` as a timestep."
    #             ),
    #         )

    #     if not self.is_scale_input_called:
    #         logger.warning(
    #             "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
    #             "See `StableDiffusionPipeline` for a usage example."
    #         )

    #     self._init_step_index(timestep.view((1)))

    #     sigma = self.sigmas[self.step_index]

    #     # Upcast to avoid precision issues when computing prev_sample
    #     sample = sample.to(torch.float32)

    #     # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    #     if self.config.prediction_type == "epsilon":
    #         pred_original_sample = sample - sigma * model_output
    #     elif self.config.prediction_type == "v_prediction":
    #         # * c_out + input * c_skip
    #         pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
    #     elif self.config.prediction_type == "sample":
    #         raise NotImplementedError("prediction_type not implemented yet: sample")
    #     else:
    #         raise ValueError(
    #             f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
    #         )

    #     sigma_from = self.sigmas[self.step_index]
    #     sigma_to = self.sigmas[self.step_index + 1]
    #     sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    #     sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

    #     # 2. Convert to an ODE derivative
    #     # derivative = (sample - pred_original_sample) / sigma
    #     derivative = model_output

    #     dt = sigma_down - sigma

    #     prev_sample = sample + derivative * dt

    #     device = model_output.device
    #     # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
    #     # prev_sample = prev_sample + noise * sigma_up

    #     if sigma_up > 0:
    #         self.noise_list[self.step_index] = (expected_next_sample - prev_sample) / sigma_up

    #     prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up

    #     # Cast sample back to model compatible dtype
    #     prev_sample = prev_sample.to(model_output.dtype)

    #     # upon completion increase step index by one
    #     self._step_index += 1

    #     if not return_dict:
    #         return (prev_sample,)

    #     return EulerAncestralDiscreteSchedulerOutput(
    #         prev_sample=prev_sample, pred_original_sample=pred_original_sample
    #     )
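Illustrative aside (a sketch, not part of the uploaded files): the `step` path above splits the move from `sigma_from` to `sigma_to` into a deterministic part `sigma_down` and a stochastic part `sigma_up`, so that `sigma_down**2 + sigma_up**2 == sigma_to**2` by construction. A standalone check with made-up sigma values:

import torch

sigma_from, sigma_to = torch.tensor(14.6), torch.tensor(9.7)  # illustrative sigmas only
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
print(torch.allclose(sigma_down**2 + sigma_up**2, sigma_to**2))  # True by construction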
src/eunms.py
ADDED
@@ -0,0 +1,26 @@
from enum import Enum

class Scheduler_Type(Enum):
    DDIM = 1
    EULER = 2
    LCM = 3
    DDPM = 4

class Model_Type(Enum):
    SDXL = 1
    SDXL_Turbo = 2
    LCM_SDXL = 3
    SD15 = 4
    SD21 = 5
    SD21_Turbo = 6
    SD14 = 7

class Gradient_Averaging_Type(Enum):
    NONE = 1
    EACH_ITER = 2
    ON_END = 3

class Epsilon_Update_Type(Enum):
    NONE = 1
    OVERRIDE = 2
    OPTIMIZE = 3
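Illustrative usage of these enums (a sketch, not one of the uploaded files): the demo pairs the SDXL-Turbo model with the Euler scheduler.

from src.eunms import Model_Type, Scheduler_Type

model_type = Model_Type.SDXL_Turbo
scheduler_type = Scheduler_Type.EULER
print(model_type.name, scheduler_type.value)  # SDXL_Turbo 2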
src/images_utils.py
ADDED
@@ -0,0 +1,74 @@
from PIL import Image
import os
import torch

def read_images_in_path(path, size = (512,512)):
    image_paths = []
    for filename in os.listdir(path):
        if filename.endswith(".png") or filename.endswith(".jpg") or filename.endswith(".jpeg"):
            image_path = os.path.join(path, filename)
            image_paths.append(image_path)
    image_paths = sorted(image_paths)
    return [Image.open(image_path).convert("RGB").resize(size) for image_path in image_paths]

def concatenate_images(image_lists, return_list = False):
    num_rows = len(image_lists[0])
    num_columns = len(image_lists)
    image_width = image_lists[0][0].width
    image_height = image_lists[0][0].height

    grid_width = num_columns * image_width
    grid_height = num_rows * image_height if not return_list else image_height
    if not return_list:
        grid_image = [Image.new('RGB', (grid_width, grid_height))]
    else:
        grid_image = [Image.new('RGB', (grid_width, grid_height)) for i in range(num_rows)]

    for i in range(num_rows):
        row_index = i if return_list else 0
        for j in range(num_columns):
            image = image_lists[j][i]
            x_offset = j * image_width
            y_offset = i * image_height if not return_list else 0
            grid_image[row_index].paste(image, (x_offset, y_offset))

    return grid_image if return_list else grid_image[0]

def concatenate_images_single(image_lists):
    num_columns = len(image_lists)
    image_width = image_lists[0].width
    image_height = image_lists[0].height

    grid_width = num_columns * image_width
    grid_height = image_height
    grid_image = Image.new('RGB', (grid_width, grid_height))

    for j in range(num_columns):
        image = image_lists[j]
        x_offset = j * image_width
        y_offset = 0
        grid_image.paste(image, (x_offset, y_offset))

    return grid_image

def get_captions_for_images(images, device):
    from transformers import Blip2Processor, Blip2ForConditionalGeneration

    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
    )  # doctest: +IGNORE_RESULT

    res = []

    for image in images:
        inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)

        generated_ids = model.generate(**inputs)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        res.append(generated_text)

    del processor
    del model

    return res
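A short sketch of composing these helpers (illustrative, not one of the uploaded files): load the bundled example images and paste them into a single row.

from src.images_utils import read_images_in_path, concatenate_images_single

images = read_images_in_path("example_images", size=(512, 512))  # kitten, lion, monkey
grid = concatenate_images_single(images)                         # one row, one column per image
grid.save("grid.jpg")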
src/inversion_utils.py
ADDED
@@ -0,0 +1,86 @@
import torch
from random import randrange
import torch.nn.functional as F


def noise_regularization(
    e_t, noise_pred_optimal, lambda_kl, lambda_ac, num_reg_steps, num_ac_rolls
):
    for _outer in range(num_reg_steps):
        if lambda_kl > 0:
            _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
            l_kld = patchify_latents_kl_divergence(_var, noise_pred_optimal)
            l_kld.backward()
            _grad = _var.grad.detach()
            _grad = torch.clip(_grad, -100, 100)
            e_t = e_t - lambda_kl * _grad
        if lambda_ac > 0:
            for _inner in range(num_ac_rolls):
                _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
                l_ac = auto_corr_loss(_var)
                l_ac.backward()
                _grad = _var.grad.detach() / num_ac_rolls
                e_t = e_t - lambda_ac * _grad
        e_t = e_t.detach()

    return e_t


def auto_corr_loss(x, random_shift=True):
    B, C, H, W = x.shape
    assert B == 1
    x = x.squeeze(0)
    # x must be shape [C,H,W] now
    reg_loss = 0.0
    for ch_idx in range(x.shape[0]):
        noise = x[ch_idx][None, None, :, :]
        while True:
            if random_shift:
                roll_amount = randrange(noise.shape[2] // 2)
            else:
                roll_amount = 1
            reg_loss += (
                noise * torch.roll(noise, shifts=roll_amount, dims=2)
            ).mean() ** 2
            reg_loss += (
                noise * torch.roll(noise, shifts=roll_amount, dims=3)
            ).mean() ** 2
            if noise.shape[2] <= 8:
                break
            noise = F.avg_pool2d(noise, kernel_size=2)
    return reg_loss


def patchify_latents_kl_divergence(x0, x1, patch_size=4, num_channels=4):

    def patchify_tensor(input_tensor):
        patches = (
            input_tensor.unfold(1, patch_size, patch_size)
            .unfold(2, patch_size, patch_size)
            .unfold(3, patch_size, patch_size)
        )
        patches = patches.contiguous().view(-1, num_channels, patch_size, patch_size)
        return patches

    x0 = patchify_tensor(x0)
    x1 = patchify_tensor(x1)

    kl = latents_kl_divergence(x0, x1).sum()
    return kl


def latents_kl_divergence(x0, x1):
    EPSILON = 1e-6
    x0 = x0.view(x0.shape[0], x0.shape[1], -1)
    x1 = x1.view(x1.shape[0], x1.shape[1], -1)
    mu0 = x0.mean(dim=-1)
    mu1 = x1.mean(dim=-1)
    var0 = x0.var(dim=-1)
    var1 = x1.var(dim=-1)
    kl = (
        torch.log((var1 + EPSILON) / (var0 + EPSILON))
        + (var0 + (mu0 - mu1) ** 2) / (var1 + EPSILON)
        - 1
    )
    kl = torch.abs(kl).sum(dim=-1)
    return kl
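A minimal sketch of calling `noise_regularization` on a latent-shaped noise prediction (tensor shapes and lambda values here are illustrative assumptions, not values taken from this repository):

import torch
from src.inversion_utils import noise_regularization

e_t = torch.randn(1, 4, 64, 64)    # predicted noise for a single SD latent (assumed shape)
e_ref = torch.randn(1, 4, 64, 64)  # reference noise used by the patch-KL term
e_t = noise_regularization(e_t, e_ref, lambda_kl=0.065, lambda_ac=20.0,
                           num_reg_steps=4, num_ac_rolls=5)  # illustrative hyperparameters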
src/lcm_scheduler.py
ADDED
@@ -0,0 +1,196 @@
from diffusers import LCMScheduler
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
import torch
from typing import List, Optional, Tuple, Union
import numpy as np

class LCMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    denoised: Optional[torch.FloatTensor] = None

class MyLCMScheduler(LCMScheduler):

    def set_noise_list(self, noise_list):
        self.noise_list = noise_list

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[LCMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
        Returns:
            [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        self._init_step_index(timestep)

        # 1. get previous step value
        prev_step_index = self.step_index + 1
        if prev_step_index < len(self.timesteps):
            prev_timestep = self.timesteps[prev_step_index]
        else:
            prev_timestep = timestep

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 3. Get scalings for boundary conditions
        c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)

        # 4. Compute the predicted original sample x_0 based on the model parameterization
        if self.config.prediction_type == "epsilon":  # noise-prediction
            predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
        elif self.config.prediction_type == "sample":  # x-prediction
            predicted_original_sample = model_output
        elif self.config.prediction_type == "v_prediction":  # v-prediction
            predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
                " `v_prediction` for `LCMScheduler`."
            )

        # 5. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            predicted_original_sample = self._threshold_sample(predicted_original_sample)
        elif self.config.clip_sample:
            predicted_original_sample = predicted_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 6. Denoise model output using boundary conditions
        denoised = c_out * predicted_original_sample + c_skip * sample

        # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
        # Noise is not used on the final timestep of the timestep schedule.
        # This also means that noise is not used for one-step sampling.
        if self.step_index != self.num_inference_steps - 1:
            noise = self.noise_list[self.step_index]
            prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
        else:
            prev_sample = denoised

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample, denoised)

        return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)


    def inv_step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[LCMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
        Returns:
            [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        self._init_step_index(timestep)

        # 1. get previous step value
        prev_step_index = self.step_index + 1
        if prev_step_index < len(self.timesteps):
            prev_timestep = self.timesteps[prev_step_index]
        else:
            prev_timestep = timestep

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 3. Get scalings for boundary conditions
        c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)

        if self.step_index != self.num_inference_steps - 1:
            c_skip_actual = c_skip * alpha_prod_t_prev.sqrt()
            c_out_actual = c_out * alpha_prod_t_prev.sqrt()
            noise = self.noise_list[self.step_index] * beta_prod_t_prev.sqrt()
        else:
            c_skip_actual = c_skip
            c_out_actual = c_out
            noise = 0


        dem = c_out_actual / (alpha_prod_t.sqrt()) + c_skip
        eps_mul = beta_prod_t.sqrt() * c_out_actual / (alpha_prod_t.sqrt())

        prev_sample = (sample + eps_mul * model_output - noise) / dem

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample, prev_sample)

        return LCMSchedulerOutput(prev_sample=prev_sample, denoised=prev_sample)
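The only change over the stock diffusers `LCMScheduler` is the externally supplied `noise_list` consumed by `step` and `inv_step`. A hedged sketch of preparing such a list (latent shape and step count are assumptions; constructing the scheduler itself is omitted here):

import torch

num_inference_steps = 4
latents_shape = (1, 4, 128, 128)  # assumed SDXL latent shape at 1024x1024
noise_list = [torch.randn(latents_shape) for _ in range(num_inference_steps)]
# scheduler is assumed to be an already constructed MyLCMScheduler instance:
# scheduler.set_noise_list(noise_list)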
src/lpips.py
ADDED
@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
from PIL import Image
from itertools import chain
from torchvision import models
from typing import Sequence
from collections import OrderedDict

def get_network(net_type: str = 'vgg'):
    if net_type == 'alex':
        return AlexNet()
    elif net_type == 'squeeze':
        return SqueezeNet()
    elif net_type == 'vgg':
        return VGG16()
    else:
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')

def normalize_activation(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)

class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()

        # register buffer
        self.register_buffer(
            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        x = self.z_score(x)

        output = []
        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output


class SqueezeNet(BaseNet):
    def __init__(self):
        super(SqueezeNet, self).__init__()

        self.layers = models.squeezenet1_1(True).features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)


class AlexNet(BaseNet):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.layers = models.alexnet(True).features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)


class VGG16(BaseNet):
    def __init__(self):
        super(VGG16, self).__init__()

        self.layers = models.vgg16(True).features
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)

class LinLayers(nn.ModuleList):
    def __init__(self, n_channels_list: Sequence[int]):
        super(LinLayers, self).__init__([
            nn.Sequential(
                nn.Identity(),
                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
            ) for nc in n_channels_list
        ])

        for param in self.parameters():
            param.requires_grad = False


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    # build url
    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
        + f'master/lpips/weights/v{version}/{net_type}.pth'

    # download
    old_state_dict = torch.hub.load_state_dict_from_url(
        url, progress=True,
        map_location=None if torch.cuda.is_available() else torch.device('cpu')
    )

    # rename keys
    new_state_dict = OrderedDict()
    for key, val in old_state_dict.items():
        new_key = key
        new_key = new_key.replace('lin', '')
        new_key = new_key.replace('model.', '')
        new_state_dict[new_key] = val

    return new_state_dict

class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).
    Arguments:
        net_type (str): the network type to compare the features:
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
    """
    def __init__(self, net_type: str = 'vgg', version: str = '0.1'):

        assert version in ['0.1'], 'v0.1 is only supported now'

        super(LPIPS, self).__init__()

        # pretrained network
        self.net = get_network(net_type).to("cuda")

        # linear layers
        self.lin = LinLayers(self.net.n_channels_list).to("cuda")
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        feat_x, feat_y = self.net(x), self.net(y)

        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]

        return torch.sum(torch.cat(res, 0)) / x.shape[0]
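Usage sketch for this LPIPS wrapper (illustrative tensors only; note that `__init__` hard-codes CUDA, so a GPU is assumed):

import torch
from src.lpips import LPIPS

lpips = LPIPS(net_type='vgg').eval()           # VGG16 features plus linear heads, built on CUDA
x = torch.rand(1, 3, 512, 512, device="cuda")  # illustrative image batches
y = torch.rand(1, 3, 512, 512, device="cuda")
with torch.no_grad():
    print(lpips(x, y).item())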
src/metric_util.py
ADDED
@@ -0,0 +1,61 @@
import torch
from PIL import Image
from torchvision import transforms

from src.lpips import LPIPS
import torch.nn as nn

dev = 'cuda'
to_tensor_transform = transforms.Compose([transforms.ToTensor()])
mse_loss = nn.MSELoss()

def calculate_l2_difference(image1, image2, device = 'cuda'):
    if isinstance(image1, Image.Image):
        image1 = to_tensor_transform(image1).to(device)
    if isinstance(image2, Image.Image):
        image2 = to_tensor_transform(image2).to(device)

    mse = mse_loss(image1, image2).item()
    return mse

def calculate_psnr(image1, image2, device = 'cuda'):
    max_value = 1.0
    if isinstance(image1, Image.Image):
        image1 = to_tensor_transform(image1).to(device)
    if isinstance(image2, Image.Image):
        image2 = to_tensor_transform(image2).to(device)

    mse = mse_loss(image1, image2)
    psnr = 10 * torch.log10(max_value**2 / mse).item()
    return psnr


loss_fn = LPIPS(net_type='vgg').to(dev).eval()

def calculate_lpips(image1, image2, device = 'cuda'):
    if isinstance(image1, Image.Image):
        image1 = to_tensor_transform(image1).to(device)
    if isinstance(image2, Image.Image):
        image2 = to_tensor_transform(image2).to(device)

    loss = loss_fn(image1, image2).item()
    return loss

def calculate_metrics(image1, image2, device = 'cuda', size=(512, 512)):
    if isinstance(image1, Image.Image):
        image1 = image1.resize(size)
        image1 = to_tensor_transform(image1).to(device)
    if isinstance(image2, Image.Image):
        image2 = image2.resize(size)
        image2 = to_tensor_transform(image2).to(device)

    l2 = calculate_l2_difference(image1, image2, device)
    psnr = calculate_psnr(image1, image2, device)
    lpips = calculate_lpips(image1, image2, device)
    return {"l2": l2, "psnr": psnr, "lpips": lpips}

def get_empty_metrics():
    return {"l2": 0, "psnr": 0, "lpips": 0}

def print_results(results):
    print(f"Reconstruction Metrics: L2: {results['l2']},\t PSNR: {results['psnr']},\t LPIPS: {results['lpips']}")
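Usage sketch for the reconstruction metrics (illustrative; it reuses two of the bundled example images as stand-ins for an input and its reconstruction, and assumes a CUDA device since the LPIPS module above is built on CUDA at import time):

from PIL import Image
from src.metric_util import calculate_metrics, print_results

original = Image.open("example_images/lion.jpeg").convert("RGB")
reconstruction = Image.open("example_images/kitten.jpg").convert("RGB")  # stand-in only
print_results(calculate_metrics(original, reconstruction))  # L2 / PSNR / LPIPS at 512x512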
src/sd_inversion_pipeline.py
ADDED
@@ -0,0 +1,634 @@
# Plug&Play Feature Injection

import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from random import randrange
import PIL
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import custom_bwd, custom_fwd
import torch.nn.functional as F


from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    DDIMScheduler,
)
from diffusers.utils.torch_utils import randn_tensor

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipelineOutput,
    retrieve_timesteps,
    PipelineImageInput
)

from src.eunms import Scheduler_Type, Gradient_Averaging_Type, Epsilon_Update_Type

def _backward_ddim(x_tm1, alpha_t, alpha_tm1, eps_xt):
    """
    let a = alpha_t, b = alpha_{t - 1}
    We have a > b,
    x_{t} - x_{t - 1} = sqrt(a) ((sqrt(1/b) - sqrt(1/a)) * x_{t-1} + (sqrt(1/a - 1) - sqrt(1/b - 1)) * eps_{t-1})
    From https://arxiv.org/pdf/2105.05233.pdf, section F.
    """

    a, b = alpha_t, alpha_tm1
    sa = a**0.5
    sb = b**0.5

    return sa * ((1 / sb) * x_tm1 + ((1 / a - 1) ** 0.5 - (1 / b - 1) ** 0.5) * eps_xt)

43 |
+
class SDDDIMPipeline(StableDiffusionImg2ImgPipeline):
|
44 |
+
# @torch.no_grad()
|
45 |
+
def __call__(
|
46 |
+
self,
|
47 |
+
prompt: Union[str, List[str]] = None,
|
48 |
+
image: PipelineImageInput = None,
|
49 |
+
strength: float = 1.0,
|
50 |
+
num_inversion_steps: Optional[int] = 50,
|
51 |
+
timesteps: List[int] = None,
|
52 |
+
guidance_scale: Optional[float] = 7.5,
|
53 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
54 |
+
num_images_per_prompt: Optional[int] = 1,
|
55 |
+
eta: Optional[float] = 0.0,
|
56 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
57 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
58 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
59 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
60 |
+
output_type: Optional[str] = "pil",
|
61 |
+
return_dict: bool = True,
|
62 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
63 |
+
clip_skip: int = None,
|
64 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
65 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
66 |
+
opt_lr: float = 0.001,
|
67 |
+
opt_iters: int = 1,
|
68 |
+
opt_none_inference_steps: bool = False,
|
69 |
+
opt_loss_kl_lambda: float = 10.0,
|
70 |
+
num_inference_steps: int = 50,
|
71 |
+
num_aprox_steps: int = 100,
|
72 |
+
**kwargs,
|
73 |
+
):
|
74 |
+
callback = kwargs.pop("callback", None)
|
75 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
76 |
+
|
77 |
+
if callback is not None:
|
78 |
+
deprecate(
|
79 |
+
"callback",
|
80 |
+
"1.0.0",
|
81 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
82 |
+
)
|
83 |
+
if callback_steps is not None:
|
84 |
+
deprecate(
|
85 |
+
"callback_steps",
|
86 |
+
"1.0.0",
|
87 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
88 |
+
)
|
89 |
+
|
90 |
+
# 1. Check inputs. Raise error if not correct
|
91 |
+
self.check_inputs(
|
92 |
+
prompt,
|
93 |
+
strength,
|
94 |
+
callback_steps,
|
95 |
+
negative_prompt,
|
96 |
+
prompt_embeds,
|
97 |
+
negative_prompt_embeds,
|
98 |
+
callback_on_step_end_tensor_inputs,
|
99 |
+
)
|
100 |
+
|
101 |
+
self._guidance_scale = guidance_scale
|
102 |
+
self._clip_skip = clip_skip
|
103 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
104 |
+
|
105 |
+
# 2. Define call parameters
|
106 |
+
if prompt is not None and isinstance(prompt, str):
|
107 |
+
batch_size = 1
|
108 |
+
elif prompt is not None and isinstance(prompt, list):
|
109 |
+
batch_size = len(prompt)
|
110 |
+
else:
|
111 |
+
batch_size = prompt_embeds.shape[0]
|
112 |
+
|
113 |
+
device = self._execution_device
|
114 |
+
|
115 |
+
# 3. Encode input prompt
|
116 |
+
text_encoder_lora_scale = (
|
117 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
118 |
+
)
|
119 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
120 |
+
prompt,
|
121 |
+
device,
|
122 |
+
num_images_per_prompt,
|
123 |
+
self.do_classifier_free_guidance,
|
124 |
+
negative_prompt,
|
125 |
+
prompt_embeds=prompt_embeds,
|
126 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
127 |
+
lora_scale=text_encoder_lora_scale,
|
128 |
+
clip_skip=self.clip_skip,
|
129 |
+
)
|
130 |
+
# For classifier free guidance, we need to do two forward passes.
|
131 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
132 |
+
# to avoid doing two forward passes
|
133 |
+
if self.do_classifier_free_guidance:
|
134 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
135 |
+
|
136 |
+
if ip_adapter_image is not None:
|
137 |
+
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
138 |
+
if self.do_classifier_free_guidance:
|
139 |
+
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
140 |
+
|
141 |
+
# 4. Preprocess image
|
142 |
+
image = self.image_processor.preprocess(image)
|
143 |
+
|
144 |
+
# 5. set timesteps
|
145 |
+
timesteps, num_inversion_steps = retrieve_timesteps(self.scheduler, num_inversion_steps, device, timesteps)
|
146 |
+
timesteps, num_inversion_steps = self.get_timesteps(num_inversion_steps, strength, device)
|
147 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
148 |
+
_, num_inference_steps = retrieve_timesteps(self.scheduler_inference, num_inference_steps, device, None)
|
149 |
+
|
150 |
+
# 6. Prepare latent variables
|
151 |
+
with torch.no_grad():
|
152 |
+
latents = self.prepare_latents(
|
153 |
+
image,
|
154 |
+
latent_timestep,
|
155 |
+
batch_size,
|
156 |
+
num_images_per_prompt,
|
157 |
+
prompt_embeds.dtype,
|
158 |
+
device,
|
159 |
+
generator,
|
160 |
+
)
|
161 |
+
|
162 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
163 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
164 |
+
|
165 |
+
# 7.1 Add image embeds for IP-Adapter
|
166 |
+
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
167 |
+
|
168 |
+
# 7.2 Optionally get Guidance Scale Embedding
|
169 |
+
timestep_cond = None
|
170 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
171 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
172 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
173 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
174 |
+
).to(device=device, dtype=latents.dtype)
|
175 |
+
|
176 |
+
# 8. Denoising loop
|
177 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
178 |
+
self._num_timesteps = len(timesteps)
|
179 |
+
prev_timestep = None
|
180 |
+
self.prev_z = torch.clone(latents)
|
181 |
+
self.prev_z4 = torch.clone(latents)
|
182 |
+
self.z_0 = torch.clone(latents)
|
183 |
+
g_cpu = torch.Generator().manual_seed(7865)
|
184 |
+
self.noise = randn_tensor(self.z_0.shape, generator=g_cpu, device=self.z_0.device, dtype=self.z_0.dtype)
|
185 |
+
|
186 |
+
|
187 |
+
all_latents = [latents.clone()]
|
188 |
+
with self.progress_bar(total=num_inversion_steps) as progress_bar:
|
189 |
+
for i, t in enumerate(reversed(timesteps)):
|
190 |
+
|
191 |
+
z_tp1 = self.inversion_step(latents,
|
192 |
+
t,
|
193 |
+
prompt_embeds,
|
194 |
+
added_cond_kwargs,
|
195 |
+
prev_timestep=prev_timestep,
|
196 |
+
num_aprox_steps=num_aprox_steps)
|
197 |
+
|
198 |
+
if t in self.scheduler_inference.timesteps:
|
199 |
+
z_tp1 = self.optimize_z_tp1(z_tp1,
|
200 |
+
latents,
|
201 |
+
t,
|
202 |
+
prompt_embeds,
|
203 |
+
added_cond_kwargs,
|
204 |
+
nom_opt_iters=opt_iters,
|
205 |
+
lr=opt_lr,
|
206 |
+
opt_loss_kl_lambda=opt_loss_kl_lambda)
|
207 |
+
|
208 |
+
prev_timestep = t
|
209 |
+
latents = z_tp1
|
210 |
+
|
211 |
+
all_latents.append(latents.clone())
|
212 |
+
|
213 |
+
if callback_on_step_end is not None:
|
214 |
+
callback_kwargs = {}
|
215 |
+
for k in callback_on_step_end_tensor_inputs:
|
216 |
+
callback_kwargs[k] = locals()[k]
|
217 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
218 |
+
|
219 |
+
latents = callback_outputs.pop("latents", latents)
|
220 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
221 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
222 |
+
|
223 |
+
# call the callback, if provided
|
224 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
225 |
+
progress_bar.update()
|
226 |
+
if callback is not None and i % callback_steps == 0:
|
227 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
228 |
+
callback(step_idx, t, latents)
|
229 |
+
|
230 |
+
image = latents
|
231 |
+
|
232 |
+
# Offload all models
|
233 |
+
self.maybe_free_model_hooks()
|
234 |
+
|
235 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None), all_latents
|
236 |
+
|
237 |
+
def noise_regularization(self, e_t, noise_pred_optimal):
|
238 |
+
for _outer in range(self.cfg.num_reg_steps):
|
239 |
+
if self.cfg.lambda_kl>0:
|
240 |
+
_var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
|
241 |
+
# l_kld = self.kl_divergence(_var)
|
242 |
+
l_kld = self.patchify_latents_kl_divergence(_var, noise_pred_optimal)
|
243 |
+
l_kld.backward()
|
244 |
+
_grad = _var.grad.detach()
|
245 |
+
_grad = torch.clip(_grad, -100, 100)
|
246 |
+
e_t = e_t - self.cfg.lambda_kl*_grad
|
247 |
+
if self.cfg.lambda_ac>0:
|
248 |
+
for _inner in range(self.cfg.num_ac_rolls):
|
249 |
+
_var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
|
250 |
+
l_ac = self.auto_corr_loss(_var)
|
251 |
+
l_ac.backward()
|
252 |
+
_grad = _var.grad.detach()/self.cfg.num_ac_rolls
|
253 |
+
e_t = e_t - self.cfg.lambda_ac*_grad
|
254 |
+
e_t = e_t.detach()
|
255 |
+
|
256 |
+
return e_t
|
257 |
+
|
258 |
+
def auto_corr_loss(self, x, random_shift=True):
|
259 |
+
B,C,H,W = x.shape
|
260 |
+
assert B==1
|
261 |
+
x = x.squeeze(0)
|
262 |
+
# x must be shape [C,H,W] now
|
263 |
+
reg_loss = 0.0
|
264 |
+
for ch_idx in range(x.shape[0]):
|
265 |
+
noise = x[ch_idx][None, None,:,:]
|
266 |
+
while True:
|
267 |
+
if random_shift: roll_amount = randrange(noise.shape[2]//2)
|
268 |
+
else: roll_amount = 1
|
269 |
+
reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=2)).mean()**2
|
270 |
+
reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=3)).mean()**2
|
271 |
+
if noise.shape[2] <= 8:
|
272 |
+
break
|
273 |
+
noise = F.avg_pool2d(noise, kernel_size=2)
|
274 |
+
return reg_loss
|
275 |
+
|
276 |
+
def kl_divergence(self, x):
|
277 |
+
_mu = x.mean()
|
278 |
+
_var = x.var()
|
279 |
+
return _var + _mu**2 - 1 - torch.log(_var+1e-7)
|
280 |
+
|
281 |
+
# @torch.no_grad()
|
282 |
+
def inversion_step(
|
283 |
+
self,
|
284 |
+
z_t: torch.tensor,
|
285 |
+
t: torch.tensor,
|
286 |
+
prompt_embeds,
|
287 |
+
added_cond_kwargs,
|
288 |
+
prev_timestep: Optional[torch.tensor] = None,
|
289 |
+
num_aprox_steps: int = 100
|
290 |
+
) -> torch.tensor:
|
291 |
+
extra_step_kwargs = {}
|
292 |
+
|
293 |
+
avg_range = self.cfg.gradient_averaging_first_step_range if t.item() < 250 else self.cfg.gradient_averaging_step_range
|
294 |
+
|
295 |
+
# When doing more then one approximation step in the first step it adds artifacts
|
296 |
+
if t.item() < 250:
|
297 |
+
num_aprox_steps = min(self.cfg.max_num_aprox_steps_first_step, num_aprox_steps)
|
298 |
+
|
299 |
+
approximated_z_tp1 = z_t.clone()
|
300 |
+
nosie_pred_avg = None
|
301 |
+
|
302 |
+
if self.cfg.num_reg_steps > 0:
|
303 |
+
z_tp1_forward = self.scheduler.add_noise(self.z_0, self.noise, t.view((1))).detach()
|
304 |
+
latent_model_input = torch.cat([z_tp1_forward] * 2) if self.do_classifier_free_guidance else z_tp1_forward
|
305 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
306 |
+
|
307 |
+
with torch.no_grad():
|
308 |
+
# predict the noise residual
|
309 |
+
noise_pred_optimal = self.unet(
|
310 |
+
latent_model_input,
|
311 |
+
t,
|
312 |
+
encoder_hidden_states=prompt_embeds,
|
                    timestep_cond=None,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0].detach()
        else:
            noise_pred_optimal = None

        for i in range(num_aprox_steps + 1):
            latent_model_input = torch.cat([approximated_z_tp1] * 2) if self.do_classifier_free_guidance else approximated_z_tp1
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            with torch.no_grad():
                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=None,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

            if i >= avg_range[0] and i < avg_range[1]:
                j = i - avg_range[0]
                if nosie_pred_avg is None:
                    nosie_pred_avg = noise_pred.clone()
                else:
                    nosie_pred_avg = j * nosie_pred_avg / (j + 1) + noise_pred / (j + 1)
                if self.cfg.gradient_averaging_type == Gradient_Averaging_Type.EACH_ITER:
                    noise_pred = nosie_pred_avg.clone()

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            if i >= avg_range[0] or (self.cfg.gradient_averaging_type == Gradient_Averaging_Type.NONE and i > 0):
                noise_pred = self.noise_regularization(noise_pred, noise_pred_optimal)

            if self.cfg.scheduler_type == Scheduler_Type.EULER:
                approximated_z_tp1 = self.scheduler.inv_step(noise_pred, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach()
            else:
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                alpha_prod_t_prev = (
                    self.scheduler.alphas_cumprod[prev_timestep]
                    if prev_timestep is not None
                    else self.scheduler.final_alpha_cumprod
                )
                approximated_z_tp1 = _backward_ddim(
                    x_tm1=z_t,
                    alpha_t=alpha_prod_t,
                    alpha_tm1=alpha_prod_t_prev,
                    eps_xt=noise_pred,
                )

        if self.cfg.gradient_averaging_type == Gradient_Averaging_Type.ON_END and nosie_pred_avg is not None:
            nosie_pred_avg = self.noise_regularization(nosie_pred_avg, noise_pred_optimal)
            if self.cfg.scheduler_type == Scheduler_Type.EULER:
                approximated_z_tp1 = self.scheduler.inv_step(nosie_pred_avg, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach()
            else:
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                alpha_prod_t_prev = (
                    self.scheduler.alphas_cumprod[prev_timestep]
                    if prev_timestep is not None
                    else self.scheduler.final_alpha_cumprod
                )
                approximated_z_tp1 = _backward_ddim(
                    x_tm1=z_t,
                    alpha_t=alpha_prod_t,
                    alpha_tm1=alpha_prod_t_prev,
                    eps_xt=nosie_pred_avg,
                )

        if self.cfg.update_epsilon_type != Epsilon_Update_Type.NONE:
            latent_model_input = torch.cat([approximated_z_tp1] * 2) if self.do_classifier_free_guidance else approximated_z_tp1
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            with torch.no_grad():
                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=None,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            self.scheduler.step_and_update_noise(noise_pred, t, approximated_z_tp1, z_t, return_dict=False, update_epsilon_type=self.cfg.update_epsilon_type)

        return approximated_z_tp1
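
    # The helpers below optionally refine the inverted latent: opt_z_tp1_single_step and
    # optimize_z_tp1 run SGD directly on z_{t+1}, minimizing an L1 + MSE reconstruction loss
    # between the re-denoised latent and z_t (optimize_z_tp1 also adds a patch-wise KL term
    # that keeps z_{t+1} close to a forward-noised version of z_0), and opt_inv chains
    # inversion_step with this optimization for the timesteps that are used at inference.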

    def detach_before_opt(self, z_tp1, t, prompt_embeds, added_cond_kwargs):
        z_tp1 = z_tp1.detach()
        t = t.detach()
        prompt_embeds = prompt_embeds.detach()
        return z_tp1, t, prompt_embeds, added_cond_kwargs

    def opt_z_tp1_single_step(
        self,
        z_tp1,
        z_t,
        t,
        prompt_embeds,
        added_cond_kwargs,
        lr=0.001,
        opt_loss_kl_lambda=10.0,
    ):
        l1_loss = torch.nn.L1Loss(reduction='sum')
        mse = torch.nn.MSELoss(reduction='sum')
        extra_step_kwargs = {}

        self.unet.requires_grad_(False)
        z_tp1, t, prompt_embeds, added_cond_kwargs = self.detach_before_opt(z_tp1, t, prompt_embeds, added_cond_kwargs)

        z_tp1 = torch.nn.Parameter(z_tp1, requires_grad=True)
        optimizer = torch.optim.SGD([z_tp1], lr=lr, momentum=0.9)

        optimizer.zero_grad()
        self.unet.zero_grad()
        latent_model_input = torch.cat([z_tp1] * 2) if self.do_classifier_free_guidance else z_tp1
        latent_model_input = self.scheduler_inference.scale_model_input(latent_model_input, t)

        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=None,
            cross_attention_kwargs=self.cross_attention_kwargs,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]

        # perform guidance
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        z_t_hat = self.scheduler_inference.step(noise_pred, t, z_tp1, **extra_step_kwargs, return_dict=False)[0]

        direct_loss = 0.5 * mse(z_t_hat, z_t.detach()) + 0.5 * l1_loss(z_t_hat, z_t.detach())
        kl_loss = torch.tensor([0]).to(z_t.device)
        loss = 1.0 * direct_loss + opt_loss_kl_lambda * kl_loss

        loss.backward()
        optimizer.step()
        print(f't: {t}\t total_loss: {format(loss.item(), ".3f")}\t\t direct_loss: {format(direct_loss.item(), ".3f")}\t\t kl_loss: {format(kl_loss.item(), ".3f")}')

        return z_tp1.detach()

    def optimize_z_tp1(
        self,
        z_tp1,
        z_t,
        t,
        prompt_embeds,
        added_cond_kwargs,
        nom_opt_iters=1,
        lr=0.001,
        opt_loss_kl_lambda=10.0,
    ):
        l1_loss = torch.nn.L1Loss(reduction='sum')
        mse = torch.nn.MSELoss(reduction='sum')
        extra_step_kwargs = {}

        self.unet.requires_grad_(False)
        z_tp1, t, prompt_embeds, added_cond_kwargs = self.detach_before_opt(z_tp1, t, prompt_embeds, added_cond_kwargs)

        z_tp1 = torch.nn.Parameter(z_tp1, requires_grad=True)
        optimizer = torch.optim.SGD([z_tp1], lr=lr, momentum=0.9)
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, verbose=True, patience=5, cooldown=3)
        max_loss = 99999999999999

        z_tp1_forward = self.scheduler.add_noise(self.z_0, self.noise, t.view((1))).detach()
        z_tp1_best = None
        for i in range(nom_opt_iters):
            optimizer.zero_grad()
            self.unet.zero_grad()
            latent_model_input = torch.cat([z_tp1] * 2) if self.do_classifier_free_guidance else z_tp1
            latent_model_input = self.scheduler_inference.scale_model_input(latent_model_input, t)

            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=None,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            z_t_hat = self.scheduler_inference.step(noise_pred, t, z_tp1, **extra_step_kwargs, return_dict=False)[0]

            direct_loss = 0.5 * mse(z_t_hat, z_t.detach()) + 0.5 * l1_loss(z_t_hat, z_t.detach())
            kl_loss = self.patchify_latents_kl_divergence(z_tp1, z_tp1_forward)
            loss = 1.0 * direct_loss + opt_loss_kl_lambda * kl_loss

            loss.backward()
            best = False
            if loss < max_loss:
                max_loss = loss
                z_tp1_best = torch.clone(z_tp1)
                best = True
            lr_scheduler.step(loss)
            if optimizer.param_groups[0]['lr'] < 9e-06:
                break
            optimizer.step()
            print(f't: {t}\t\t iter: {i}\t total_loss: {format(loss.item(), ".3f")}\t\t direct_loss: {format(direct_loss.item(), ".3f")}\t\t kl_loss: {format(kl_loss.item(), ".3f")}\t\t best: {best}')

        if z_tp1_best is not None:
            z_tp1 = z_tp1_best

        self.prev_z4 = torch.clone(z_tp1)

        return z_tp1.detach()

    def opt_inv(self,
                z_t,
                t,
                prompt_embeds,
                added_cond_kwargs,
                prev_timestep,
                nom_opt_iters=1,
                lr=0.001,
                opt_none_inference_steps=False,
                opt_loss_kl_lambda=10.0,
                num_aprox_steps=100):

        z_tp1 = self.inversion_step(z_t, t, prompt_embeds, added_cond_kwargs, num_aprox_steps=num_aprox_steps)

        if t in self.scheduler_inference.timesteps:
            z_tp1 = self.optimize_z_tp1(z_tp1, z_t, t, prompt_embeds, added_cond_kwargs, nom_opt_iters=nom_opt_iters, lr=lr, opt_loss_kl_lambda=opt_loss_kl_lambda)

        return z_tp1

    def latent2image(self, latents):
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        if needs_upcasting:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

        # cast back to fp16 if needed
        # if needs_upcasting:
        #     self.vae.to(dtype=torch.float16)

        return image
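
    # The two methods below implement the patch-wise KL term used by optimize_z_tp1:
    # the latents are split into 4x4 patches (per the 4 latent channels), each patch is
    # summarized by per-channel means and variances, and the divergences between
    # corresponding patches of the two latents are summed.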

    def patchify_latents_kl_divergence(self, x0, x1):
        # divide x0 and x1 into patches (4x64x64) -> (4x4x4)
        PATCH_SIZE = 4
        NUM_CHANNELS = 4

        def patchify_tensor(input_tensor):
            patches = input_tensor.unfold(1, PATCH_SIZE, PATCH_SIZE).unfold(2, PATCH_SIZE, PATCH_SIZE).unfold(3, PATCH_SIZE, PATCH_SIZE)
            patches = patches.contiguous().view(-1, NUM_CHANNELS, PATCH_SIZE, PATCH_SIZE)
            return patches

        x0 = patchify_tensor(x0)
        x1 = patchify_tensor(x1)

        kl = self.latents_kl_divergence(x0, x1).sum()
        # for i in range(x0.shape[0]):
        #     kl += self.latents_kl_divergence(x0[i], x1[i])
        return kl

    def latents_kl_divergence(self, x0, x1):
        EPSILON = 1e-6

        # KL(N_0 || N_1) = 1/2 * (tr(Sigma_1^-1 Sigma_0) - k
        #                         + (mu_1 - mu_0)^T Sigma_1^-1 (mu_1 - mu_0)
        #                         + ln(det Sigma_1 / det Sigma_0))
        x0 = x0.view(x0.shape[0], x0.shape[1], -1)
        x1 = x1.view(x1.shape[0], x1.shape[1], -1)
        mu0 = x0.mean(dim=-1)
        mu1 = x1.mean(dim=-1)
        var0 = x0.var(dim=-1)
        var1 = x1.var(dim=-1)
        kl = torch.log((var1 + EPSILON) / (var0 + EPSILON)) + (var0 + (mu0 - mu1)**2) / (var1 + EPSILON) - 1
        kl = torch.abs(kl).sum(dim=-1)
        # kl = torch.linalg.norm(mu0 - mu1) + torch.linalg.norm(var0 - var1)
        # kl *= 1000
        # sigma0 = torch.cov(x0)
        # sigma1 = torch.cov(x1)
        # inv_sigma1 = torch.inverse(sigma1.to(dtype=torch.float64)).to(dtype=x0.dtype)
        # k = x0.shape[1]
        # kl = 0.5 * (torch.trace(inv_sigma1 @ sigma0) - k + (mu1 - mu0).T @ inv_sigma1 @ (mu1 - mu0) + torch.log(torch.det(sigma1) / torch.det(sigma0)))
        return kl


class SpecifyGradient(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)

        # dummy loss value
        return torch.zeros([1], device=input_tensor.device, dtype=input_tensor.dtype)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        gt_grad, = ctx.saved_tensors
        batch_size = len(gt_grad)
        return gt_grad / batch_size, None
src/sdxl_inversion_pipeline.py
ADDED
@@ -0,0 +1,430 @@
# Plug&Play Feature Injection

import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from random import randrange
import PIL
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import custom_bwd, custom_fwd
import torch.nn.functional as F


from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    DDIMScheduler,
)
from diffusers.utils import deprecate  # needed for the deprecation warnings emitted in __call__
from diffusers.utils.torch_utils import randn_tensor

from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
    rescale_noise_cfg,
    StableDiffusionXLPipelineOutput,
    retrieve_timesteps,
    PipelineImageInput
)

from src.eunms import Scheduler_Type, Gradient_Averaging_Type, Epsilon_Update_Type
from src.inversion_utils import noise_regularization

def _backward_ddim(x_tm1, alpha_t, alpha_tm1, eps_xt):
    """
    let a = alpha_t, b = alpha_{t - 1}
    We have a > b,
    x_{t} - x_{t - 1} = sqrt(a) ((sqrt(1/b) - sqrt(1/a)) * x_{t-1} + (sqrt(1/a - 1) - sqrt(1/b - 1)) * eps_{t-1})
    From https://arxiv.org/pdf/2105.05233.pdf, section F.
    """

    a, b = alpha_t, alpha_tm1
    sa = a**0.5
    sb = b**0.5

    return sa * ((1 / sb) * x_tm1 + ((1 / a - 1) ** 0.5 - (1 / b - 1) ** 0.5) * eps_xt)

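
# The pipeline below reuses the StableDiffusionXLImg2ImgPipeline machinery but walks the
# chosen timesteps in reverse order, calling inversion_step at each one, and returns both
# the final inverted latent and the full latent trajectory (all_latents).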
class SDXLDDIMPipeline(StableDiffusionXLImg2ImgPipeline):
    # @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        strength: float = 0.3,
        num_inversion_steps: int = 50,
        timesteps: List[int] = None,
        denoising_start: Optional[float] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 1.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Tuple[int, int] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Tuple[int, int] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        aesthetic_score: float = 6.0,
        negative_aesthetic_score: float = 2.5,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        opt_lr: float = 0.001,
        opt_iters: int = 1,
        opt_none_inference_steps: bool = False,
        opt_loss_kl_lambda: float = 10.0,
        num_inference_steps: int = 50,
        num_aprox_steps: int = 100,
        **kwargs,
    ):
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            strength,
            num_inversion_steps,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        denoising_start_fr = 1.0 - denoising_start
        denoising_start = 0.0 if self.cfg.noise_friendly_inversion else denoising_start

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._denoising_start = denoising_start

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )

        # 4. Preprocess image
        image = self.image_processor.preprocess(image)

        # 5. Prepare timesteps
        def denoising_value_valid(dnv):
            return isinstance(self.denoising_end, float) and 0 < dnv < 1

        timesteps, num_inversion_steps = retrieve_timesteps(self.scheduler, num_inversion_steps, device, timesteps)
        timesteps_num_inference_steps, num_inference_steps = retrieve_timesteps(self.scheduler_inference, num_inference_steps, device, None)

        timesteps, num_inversion_steps = self.get_timesteps(
            num_inversion_steps,
            strength,
            device,
            denoising_start=self.denoising_start if denoising_value_valid else None,
        )
        # latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # add_noise = True if self.denoising_start is None else False
        # 6. Prepare latent variables
        with torch.no_grad():
            latents = self.prepare_latents(
                image,
                None,
                batch_size,
                num_images_per_prompt,
                prompt_embeds.dtype,
                device,
                generator,
                False,
            )
        # 7. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        height, width = latents.shape[-2:]
        height = height * self.vae_scale_factor
        width = width * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 8. Prepare added time ids & embeddings
        if negative_original_size is None:
            negative_original_size = original_size
        if negative_target_size is None:
            negative_target_size = target_size

        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            aesthetic_score,
            negative_aesthetic_score,
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device)

        if ip_adapter_image is not None:
            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
            if self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])
            image_embeds = image_embeds.to(device)

        # 9. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inversion_steps * self.scheduler.order, 0)
        prev_timestep = None

        self._num_timesteps = len(timesteps)
        self.prev_z = torch.clone(latents)
        self.prev_z4 = torch.clone(latents)
        self.z_0 = torch.clone(latents)
        g_cpu = torch.Generator().manual_seed(7865)
        self.noise = randn_tensor(self.z_0.shape, generator=g_cpu, device=self.z_0.device, dtype=self.z_0.dtype)

        # Friendly inversion params
        timesteps_for = timesteps if self.cfg.noise_friendly_inversion else reversed(timesteps)
        noise = randn_tensor(latents.shape, generator=g_cpu, device=latents.device, dtype=latents.dtype)
        latents = self.scheduler.add_noise(self.z_0, noise, timesteps_for[0].view((1))).detach() if self.cfg.noise_friendly_inversion else latents
        z_T = latents.clone()

        all_latents = [latents.clone()]
        with self.progress_bar(total=num_inversion_steps) as progress_bar:
            for i, t in enumerate(timesteps_for):

                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                if ip_adapter_image is not None:
                    added_cond_kwargs["image_embeds"] = image_embeds

                z_tp1 = self.inversion_step(latents,
                                            t,
                                            prompt_embeds,
                                            added_cond_kwargs,
                                            prev_timestep=prev_timestep,
                                            num_aprox_steps=num_aprox_steps)

                prev_timestep = t
                latents = z_tp1

                all_latents.append(latents.clone())

                if self.cfg.noise_friendly_inversion and t.item() > 1000 * denoising_start_fr:
                    z_T = latents.clone()

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                    )
                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if self.cfg.noise_friendly_inversion:
            latents = z_T

        image = latents

        # Offload all models
        self.maybe_free_model_hooks()

        return StableDiffusionXLPipelineOutput(images=image), all_latents
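
    # inversion_step implements the ReNoise iteration for a single timestep: starting from
    # z_t, the current estimate of z_{t+1} is repeatedly refined (the UNet predicts the noise
    # at the current estimate and backward_step re-applies the inverted scheduler step), with
    # optional averaging of the noise predictions over a configured iteration range and
    # optional noise regularization before the final backward step.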

    # @torch.no_grad()
    def inversion_step(
        self,
        z_t: torch.tensor,
        t: torch.tensor,
        prompt_embeds,
        added_cond_kwargs,
        prev_timestep: Optional[torch.tensor] = None,
        num_aprox_steps: int = 100
    ) -> torch.tensor:
        extra_step_kwargs = {}

        avg_range = self.cfg.gradient_averaging_first_step_range if t.item() < 250 else self.cfg.gradient_averaging_step_range
        num_aprox_steps = min(self.cfg.max_num_aprox_steps_first_step, num_aprox_steps) if t.item() < 250 else num_aprox_steps

        nosie_pred_avg = None
        z_tp1_forward = self.scheduler.add_noise(self.z_0, self.noise, t.view((1))).detach()
        noise_pred_optimal = None

        approximated_z_tp1 = z_t.clone()
        for i in range(num_aprox_steps + 1):

            with torch.no_grad():
                if self.cfg.num_reg_steps > 0 and i == 0:
                    approximated_z_tp1 = torch.cat([z_tp1_forward, approximated_z_tp1])
                    prompt_embeds_in = torch.cat([prompt_embeds, prompt_embeds])
                    added_cond_kwargs_in = {}
                    added_cond_kwargs_in['text_embeds'] = torch.cat([added_cond_kwargs['text_embeds'], added_cond_kwargs['text_embeds']])
                    added_cond_kwargs_in['time_ids'] = torch.cat([added_cond_kwargs['time_ids'], added_cond_kwargs['time_ids']])
                else:
                    prompt_embeds_in = prompt_embeds
                    added_cond_kwargs_in = added_cond_kwargs

                noise_pred = self.unet_pass(approximated_z_tp1, t, prompt_embeds_in, added_cond_kwargs_in)

                if self.cfg.num_reg_steps > 0 and i == 0:
                    noise_pred_optimal, noise_pred = noise_pred.chunk(2)
                    noise_pred_optimal = noise_pred_optimal.detach()

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Calculate average noise
            if i >= avg_range[0] and i < avg_range[1]:
                j = i - avg_range[0]
                if nosie_pred_avg is None:
                    nosie_pred_avg = noise_pred.clone()
                else:
                    nosie_pred_avg = j * nosie_pred_avg / (j + 1) + noise_pred / (j + 1)

            if i >= avg_range[0] or (self.cfg.gradient_averaging_type == Gradient_Averaging_Type.NONE and i > 0):
                noise_pred = noise_regularization(noise_pred, noise_pred_optimal, lambda_kl=self.cfg.lambda_kl, lambda_ac=self.cfg.lambda_ac, num_reg_steps=self.cfg.num_reg_steps, num_ac_rolls=self.cfg.num_ac_rolls)

            approximated_z_tp1 = self.backward_step(noise_pred, t, z_t, prev_timestep)

        if self.cfg.gradient_averaging_type == Gradient_Averaging_Type.ON_END and nosie_pred_avg is not None:
            nosie_pred_avg = noise_regularization(nosie_pred_avg, noise_pred_optimal, lambda_kl=self.cfg.lambda_kl, lambda_ac=self.cfg.lambda_ac, num_reg_steps=self.cfg.num_reg_steps, num_ac_rolls=self.cfg.num_ac_rolls)
            approximated_z_tp1 = self.backward_step(nosie_pred_avg, t, z_t, prev_timestep)

        if self.cfg.update_epsilon_type != Epsilon_Update_Type.NONE:
            noise_pred = self.unet_pass(approximated_z_tp1, t, prompt_embeds, added_cond_kwargs)

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            self.scheduler.step_and_update_noise(noise_pred, t, approximated_z_tp1, z_t, return_dict=False, update_epsilon_type=self.cfg.update_epsilon_type)

        return approximated_z_tp1

    @torch.no_grad()
    def unet_pass(self, z_t, t, prompt_embeds, added_cond_kwargs):
        latent_model_input = torch.cat([z_t] * 2) if self.do_classifier_free_guidance else z_t
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
        return self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=None,
            cross_attention_kwargs=self.cross_attention_kwargs,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]

    @torch.no_grad()
    def backward_step(self, nosie_pred, t, z_t, prev_timestep):
        extra_step_kwargs = {}
        if self.cfg.scheduler_type == Scheduler_Type.EULER or self.cfg.scheduler_type == Scheduler_Type.LCM:
            return self.scheduler.inv_step(nosie_pred, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach()
        else:
            alpha_prod_t = self.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[prev_timestep]
                if prev_timestep is not None
                else self.scheduler.final_alpha_cumprod
            )
            return _backward_ddim(
                x_tm1=z_t,
                alpha_t=alpha_prod_t,
                alpha_tm1=alpha_prod_t_prev,
                eps_xt=nosie_pred,
            )
style.css
ADDED
@@ -0,0 +1,4 @@
h1 {
  text-align: center;
}