Spaces: sincostanx (Runtime error)

Commit 2b4e6ab • committed by sincostanx
Parent(s): 8421ea7

add prototype

Files changed:
- app.py +217 -4
- momentum_scheduler.py +385 -0
- pipeline.py +236 -0
- requirements.txt +96 -0
app.py
CHANGED
@@ -1,7 +1,220 @@
import gradio as gr
import torch
from pipeline import CustomPipeline, setup_scheduler
from PIL import Image
# from easydict import EasyDict as edict

original_pipe = None
original_config = None
device = None


# def run_dpm_demo(id, prompt, beta, num_inference_steps, guidance_scale, seed, enable_token_merging):
def run_dpm_demo(prompt, beta, num_inference_steps, guidance_scale, seed):
    global original_pipe, original_config
    pipe = CustomPipeline(**original_pipe.components)

    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    scheduler = "DPM-Solver++"
    params = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "method": "dpm"
    }

    # without momentum (equivalent to DPM-Solver++)
    pipe = setup_scheduler(pipe, scheduler, beta=1.0, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    ori_image = pipe(**params).images[0]

    # with momentum
    pipe = setup_scheduler(pipe, scheduler, beta=beta, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    image = pipe(**params).images[0]

    ori_image.save("temp1.png")
    image.save("temp2.png")

    return [ori_image, image]

# def run_plms_demo(id, prompt, order, beta, momentum_type, num_inference_steps, guidance_scale, seed, enable_token_merging):
def run_plms_demo(prompt, order, beta, momentum_type, num_inference_steps, guidance_scale, seed):
    global original_pipe, original_config
    pipe = CustomPipeline(**original_pipe.components)

    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    scheduler = "PLMS"
    method = "hb" if momentum_type == "Polyak's heavy ball" else "nt"
    params = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "method": method
    }

    # without momentum (equivalent to PLMS)
    pipe = setup_scheduler(pipe, scheduler, momentum_type=momentum_type, order=order, beta=1.0, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    ori_image = pipe(**params).images[0]

    # with momentum
    pipe = setup_scheduler(pipe, scheduler, momentum_type=momentum_type, order=order, beta=beta, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    image = pipe(**params).images[0]

    return [ori_image, image]

# def run_ghvb_demo(id, prompt, order, beta, num_inference_steps, guidance_scale, seed, enable_token_merging):
def run_ghvb_demo(prompt, order, beta, num_inference_steps, guidance_scale, seed):
    global original_pipe, original_config
    pipe = CustomPipeline(**original_pipe.components)

    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    scheduler = "GHVB"
    params = {
        "prompt": prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "method": "ghvb"
    }

    # without momentum (equivalent to PLMS)
    pipe = setup_scheduler(pipe, scheduler, order=order, beta=1.0, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    ori_image = pipe(**params).images[0]

    # with momentum
    pipe = setup_scheduler(pipe, scheduler, order=order, beta=beta, original_config=original_config)
    params["generator"] = torch.Generator(device=device).manual_seed(seed)
    image = pipe(**params).images[0]

    return [ori_image, image]

if __name__ == "__main__":

    demo = gr.Blocks()

    inputs = {}
    outputs = {}
    buttons = {}

    list_models = [
    ]

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Momentum-Diffusion Demo

            A novel sampling method for diffusion models based on momentum to reduce artifacts

            """
        )
        id = gr.Dropdown(list_models, label="Model ID", value="Linaqruf/anything-v3.0", allow_custom_value=True)
        enable_token_merging = gr.Checkbox(label="Enable Token Merging", value=False)
        # output = gr.Textbox()
        buttons["select_model"] = gr.Button("Select")

        with gr.Tab("GHVB", visible=False) as tab3:
            prompt3 = gr.Textbox(label="Prompt", value="a cozy cafe", visible=False)

            with gr.Row(visible=False) as row31:
                order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order")
                beta = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.05, label="beta")
                num_inference_steps = gr.Number(label="Number of steps", value=12)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10)
                seed = gr.Number(label="Seed", value=42)

            with gr.Row(visible=False) as row32:
                out1 = gr.Image(label="PLMS", interactive=False)
                out2 = gr.Image(label="GHVB", interactive=False)

            inputs["GHVB"] = [prompt3, order, beta, num_inference_steps, guidance_scale, seed]
            outputs["GHVB"] = [out1, out2]
            buttons["GHVB"] = gr.Button("Sample", visible=False)

        with gr.Tab("PLMS", visible=False) as tab2:
            prompt2 = gr.Textbox(label="Prompt", value="1girl", visible=False)

            with gr.Row(visible=False) as row21:
                order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order")
                beta = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.05, label="beta")
                momentum_type = gr.Dropdown(["Polyak's heavy ball", "Nesterov"], label="Momentum Type", value="Polyak's heavy ball")
                num_inference_steps = gr.Number(label="Number of steps", value=10)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10)
                seed = gr.Number(label="Seed", value=42)

            with gr.Row(visible=False) as row22:
                out1 = gr.Image(label="Without momentum", interactive=False)
                out2 = gr.Image(label="With momentum", interactive=False)

            inputs["PLMS"] = [prompt2, order, beta, momentum_type, num_inference_steps, guidance_scale, seed]
            outputs["PLMS"] = [out1, out2]
            buttons["PLMS"] = gr.Button("Sample", visible=False)

        with gr.Tab("DPM-Solver++", visible=False) as tab1:
            prompt1 = gr.Textbox(label="Prompt", value="1girl", visible=False)

            with gr.Row(visible=False) as row11:
                beta = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="beta")
                num_inference_steps = gr.Number(label="Number of steps", value=15)
                guidance_scale = gr.Number(label="Guidance scale (cfg)", value=20)
                seed = gr.Number(label="Seed", value=0)

            with gr.Row(visible=False) as row12:
                out1 = gr.Image(label="Without momentum", interactive=False)
                out2 = gr.Image(label="With momentum", interactive=False)

            inputs["DPM-Solver++"] = [prompt1, beta, num_inference_steps, guidance_scale, seed]
            outputs["DPM-Solver++"] = [out1, out2]
            buttons["DPM-Solver++"] = gr.Button("Sample", visible=False)

        def prepare_model(id, enable_token_merging):
            global original_pipe, original_config, device

            if original_pipe is not None:
                del original_pipe

            original_pipe = CustomPipeline.from_pretrained(id)
            device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
            original_pipe = original_pipe.to(device)

            if enable_token_merging:
                import tomesd
                tomesd.apply_patch(original_pipe, ratio=0.5)
                print("Enabled Token merging.")

            original_config = original_pipe.scheduler.config
            print(type(original_pipe))
            print(original_config)

            return {
                row11: gr.update(visible=True),
                row12: gr.update(visible=True),
                row21: gr.update(visible=True),
                row22: gr.update(visible=True),
                row31: gr.update(visible=True),
                row32: gr.update(visible=True),
                prompt1: gr.update(visible=True),
                prompt2: gr.update(visible=True),
                prompt3: gr.update(visible=True),
                buttons["DPM-Solver++"]: gr.update(visible=True),
                buttons["PLMS"]: gr.update(visible=True),
                buttons["GHVB"]: gr.update(visible=True),
            }

        all_outputs = [row11, row12, row21, row22, row31, row32, prompt1, prompt2, prompt3, buttons["DPM-Solver++"], buttons["PLMS"], buttons["GHVB"]]
        buttons["select_model"].click(prepare_model, inputs=[id, enable_token_merging], outputs=all_outputs)
        buttons["DPM-Solver++"].click(run_dpm_demo, inputs=inputs["DPM-Solver++"], outputs=outputs["DPM-Solver++"])
        buttons["PLMS"].click(run_plms_demo, inputs=inputs["PLMS"], outputs=outputs["PLMS"])
        buttons["GHVB"].click(run_ghvb_demo, inputs=inputs["GHVB"], outputs=outputs["GHVB"])

    demo.launch(share=True)
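The UI above drives everything, but the same sampling path can be exercised headlessly. Below is a minimal sketch, not part of the commit: it assumes a single CUDA device (the app itself targets "cuda:1") plus network access, and it mirrors what the Select and Sample buttons do for the GHVB tab using the demo's default model ID.

import torch
from pipeline import CustomPipeline, setup_scheduler

# load the pipeline once, as prepare_model() does (assumes network access)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe = CustomPipeline.from_pretrained("Linaqruf/anything-v3.0").to(device)
original_config = pipe.scheduler.config

# beta < 1.0 turns momentum on; beta = 1.0 recovers plain PLMS
pipe = setup_scheduler(pipe, "GHVB", order=4, beta=0.4, original_config=original_config)
generator = torch.Generator(device=device).manual_seed(42)
image = pipe(prompt="a cozy cafe", num_inference_steps=12, guidance_scale=10,
             generator=generator, method="ghvb").images[0]
image.save("ghvb_sample.png")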
momentum_scheduler.py
ADDED
@@ -0,0 +1,385 @@
import torch
from diffusers import DPMSolverMultistepScheduler, UniPCMultistepScheduler
from typing import List

def AdamBmixer(order, ets, b=1):

    cur_order = min(order, len(ets))
    if cur_order == 1:
        prime = b * ets[-1]
    elif cur_order == 2:
        prime = ((2+b) * ets[-1] - (2-b) * ets[-2]) / 2
    elif cur_order == 3:
        prime = ((18+5*b) * ets[-1] - (24-8*b) * ets[-2] + (6-1*b) * ets[-3]) / 12
    elif cur_order == 4:
        prime = ((46+9*b) * ets[-1] - (78-19*b) * ets[-2] + (42-5*b) * ets[-3] - (10-b) * ets[-4]) / 24
    elif cur_order == 5:
        prime = ((1650+251*b) * ets[-1] - (3420-646*b) * ets[-2]
                 + (2880-264*b) * ets[-3] - (1380-106*b) * ets[-4]
                 + (270-19*b) * ets[-5]) / 720
    else:
        raise NotImplementedError

    prime = prime / b
    return prime

class PLMSWithHBScheduler():
    """
    PLMS with Polyak's Heavy Ball momentum (HB) for diffusion ODEs.
    We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers).

    When order is an integer, this method is equivalent to PLMS without momentum.
    """
    def __init__(self, scheduler, order):
        self.scheduler = scheduler
        self.ets = []
        self.update_order(order)
        self.mixer = AdamBmixer

    def update_order(self, order):
        self.order = order // 1 + 1 if order % 1 > 0 else order // 1
        self.beta = order % 1 if order % 1 > 0 else 1
        self.vel = None

    def clear(self):
        self.ets = []
        self.vel = None

    def update_ets(self, val):
        self.ets.append(val)
        if len(self.ets) > self.order:
            self.ets.pop(0)

    def _step_with_momentum(self, grads):
        self.update_ets(grads)
        prime = self.mixer(self.order, self.ets, 1.0)
        self.vel = (1 - self.beta) * self.vel + self.beta * prime
        return self.vel

    def step(
        self,
        grads: torch.FloatTensor,
        timestep: int,
        latents: torch.FloatTensor,
        output_mode: str = "scale",
    ):
        if self.vel is None: self.vel = grads

        if hasattr(self.scheduler, 'sigmas'):
            step_index = (self.scheduler.timesteps == timestep).nonzero().item()
            sigma = self.scheduler.sigmas[step_index]
            sigma_next = self.scheduler.sigmas[step_index + 1]
            del_g = sigma_next - sigma

            update_val = self._step_with_momentum(grads)
            return latents + del_g * update_val

        elif isinstance(self.scheduler, DPMSolverMultistepScheduler):
            step_index = (self.scheduler.timesteps == timestep).nonzero().item()
            current_timestep = self.scheduler.timesteps[step_index]
            prev_timestep = 0 if step_index == len(self.scheduler.timesteps) - 1 else self.scheduler.timesteps[step_index + 1]

            alpha_prod_t = self.scheduler.alphas_cumprod[current_timestep]
            alpha_bar_prev = self.scheduler.alphas_cumprod[prev_timestep]

            s0 = torch.sqrt(alpha_prod_t)
            s_1 = torch.sqrt(alpha_bar_prev)
            g0 = torch.sqrt(1 - alpha_prod_t) / s0
            g_1 = torch.sqrt(1 - alpha_bar_prev) / s_1
            del_g = g_1 - g0

            update_val = self._step_with_momentum(grads)
            if output_mode in ["scale"]:
                return (latents / s0 + del_g * update_val) * s_1
            elif output_mode in ["back"]:
                return latents + del_g * update_val * s_1
            elif output_mode in ["front"]:
                return latents + del_g * update_val * s0
            else:
                return latents + del_g * update_val
        else:
            raise NotImplementedError

class GHVBScheduler(PLMSWithHBScheduler):
    """
    Generalizing Polyak's Heavy Ball (GHVB) for diffusion ODEs.
    We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers).

    When order is an integer, this method is equivalent to PLMS without momentum.
    """
    def _step_with_momentum(self, grads):
        self.vel = (1 - self.beta) * self.vel + self.beta * grads
        self.update_ets(self.vel)
        prime = self.mixer(self.order, self.ets, self.beta)
        return prime

class PLMSWithNTScheduler(PLMSWithHBScheduler):
    """
    PLMS with Nesterov momentum (NT) for diffusion ODEs.
    We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers).

    When order is an integer, this method is equivalent to PLMS without momentum.
    """
    def _step_with_momentum(self, grads):
        self.update_ets(grads)
        prime = self.mixer(self.order, self.ets, 1.0)  # update v^{(2)}
        self.vel = (1 - self.beta) * self.vel + self.beta * prime  # update v^{(1)}
        update_val = (1 - self.beta) * self.vel + self.beta * prime  # update x
        return update_val

class MomentumDPMSolverMultistepScheduler(DPMSolverMultistepScheduler):
    """
    DPM-Solver++2M with HB momentum.
    Currently supports only algorithm_type = "dpmsolver++" and solver_type = "midpoint".

    When beta = 1.0, this method is equivalent to DPM-Solver++2M without momentum.
    """
    def initialize_momentum(self, beta):
        self.vel = None
        self.beta = beta

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.FloatTensor],
        timestep_list: List[int],
        prev_timestep: int,
        sample: torch.FloatTensor,
    ) -> torch.FloatTensor:

        t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
        m0, m1 = model_output_list[-1], model_output_list[-2]
        lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                diff = (D0 + 0.5 * D1)

                # HB momentum: exponential moving average of the solver increment
                if self.vel is None:
                    self.vel = diff
                else:
                    self.vel = (1 - self.beta) * self.vel + self.beta * diff

                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * self.vel
                )
            elif self.config.solver_type == "heun":
                raise NotImplementedError(
                    f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
                )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            if self.config.solver_type == "midpoint":
                raise NotImplementedError(
                    f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
                )
            elif self.config.solver_type == "heun":
                raise NotImplementedError(
                    f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
                )
        return x_t

class MomentumUniPCMultistepScheduler(UniPCMultistepScheduler):
    """
    UniPC with HB momentum.
    Currently supports only self.predict_x0 = True.

    When beta = 1.0, this method is equivalent to UniPC without momentum.
    """
    def initialize_momentum(self, beta):
        self.vel_p = None
        self.vel_c = None
        self.beta = beta

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.FloatTensor,
        prev_timestep: int,
        sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:

        timestep_list = self.timestep_list
        model_output_list = self.model_outputs

        s0, t = self.timestep_list[-1], prev_timestep
        m0 = model_output_list[-1]
        x = sample

        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]

        h = lambda_t - lambda_s0
        device = sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if self.predict_x0:
            if D1s is not None:
                pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
            else:
                pred_res = 0

            val = (h_phi_1 * m0 + B_h * pred_res) / sigma_t / h_phi_1
            # HB momentum on the predictor increment
            if self.vel_p is None:
                self.vel_p = val
            else:
                self.vel_p = (1 - self.beta) * self.vel_p + self.beta * val

            x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_p * h_phi_1)
        else:
            raise NotImplementedError

        x_t = x_t.to(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.FloatTensor,
        this_timestep: int,
        last_sample: torch.FloatTensor,
        this_sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:

        timestep_list = self.timestep_list
        model_output_list = self.model_outputs

        s0, t = timestep_list[-1], this_timestep
        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output

        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]

        h = lambda_t - lambda_s0
        device = this_sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b)

        if self.predict_x0:
            if D1s is not None:
                corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0

            val = (h_phi_1 * m0 + B_h * (corr_res + rhos_c[-1] * D1_t)) / sigma_t / h_phi_1
            # HB momentum on the corrector increment
            if self.vel_c is None:
                self.vel_c = val
            else:
                self.vel_c = (1 - self.beta) * self.vel_c + self.beta * val

            x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_c * h_phi_1)
        else:
            raise NotImplementedError

        x_t = x_t.to(x.dtype)
        return x_t
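A quick way to see that AdamBmixer generalizes the usual linear multistep weights: at b = 1 it should collapse to the classical Adams-Bashforth coefficients, e.g. (55, -59, 37, -9)/24 at order 4, since (46+9, 78-19, 42-5, 10-1) = (55, 59, 37, 9). A small sanity check, not part of the commit; the scalar history is a made-up toy sequence:

import torch
from momentum_scheduler import AdamBmixer

ets = [torch.tensor(v) for v in (1.0, 2.0, 3.0, 4.0)]  # oldest ... newest
prime = AdamBmixer(order=4, ets=ets, b=1)
expected = (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) / 24
assert torch.isclose(prime, expected)
print(prime.item())  # 4.5 for this toy history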
pipeline.py
ADDED
@@ -0,0 +1,236 @@
import torch
import math
import numpy as np
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, UniPCMultistepScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from typing import Union, Optional, List, Callable, Dict, Any, Tuple
from momentum_scheduler import (
    GHVBScheduler,
    PLMSWithHBScheduler,
    PLMSWithNTScheduler,
    MomentumDPMSolverMultistepScheduler,
    MomentumUniPCMultistepScheduler,
)

available_solvers = {
    "GHVB": GHVBScheduler,
    "PLMS_HB": PLMSWithHBScheduler,
    "PLMS_NT": PLMSWithNTScheduler,
    "DPM-Solver++": MomentumDPMSolverMultistepScheduler,
    "UniPC": MomentumUniPCMultistepScheduler,
}

def get_momentum_number(order, beta):
    out = order if beta == 1.0 else order - (1 - beta)
    return out

def setup_scheduler(pipe, scheduler, momentum_type="Polyak's heavy ball", order=4.0, beta=1.0, original_config=None):
    assert original_config is not None

    if scheduler in ["DPM-Solver++", "UniPC"]:
        if momentum_type in ["Nesterov"]:
            raise NotImplementedError(f"{scheduler} w/ Nesterov is not implemented.")

        pipe.scheduler = available_solvers[scheduler].from_config(original_config)
        pipe.scheduler.initialize_momentum(beta=beta)

    elif scheduler in ["PLMS"]:
        momentum_number = get_momentum_number(order, beta)
        method = "PLMS_HB" if momentum_type == "Polyak's heavy ball" else "PLMS_NT"
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(original_config)
        pipe.init_scheduler(method=method, order=momentum_number)
        pipe.clear_scheduler()

    elif scheduler in ["GHVB"]:
        momentum_number = get_momentum_number(order, beta)
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(original_config)
        pipe.init_scheduler(method="GHVB", order=momentum_number)
        pipe.clear_scheduler()

    return pipe

class CustomPipeline(StableDiffusionPipeline):
    def clear_scheduler(self):
        self.scheduler_uncond.clear()
        self.scheduler_text.clear()

    def init_scheduler(self, method, order):
        # using the same order for both wrappers is equivalent to not applying numerical operator splitting
        self.scheduler_uncond = available_solvers[method](self.scheduler, order)
        self.scheduler_text = available_solvers[method](self.scheduler, order)

    def get_noise(self, latents, prompt_embeds, guidance_scale, t, do_classifier_free_guidance):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            grads_a = guidance_scale * (noise_pred_text - noise_pred_uncond)
        else:
            # no guidance term when classifier-free guidance is disabled
            noise_pred_uncond, grads_a = noise_pred, 0.0

        return noise_pred_uncond, grads_a

    def denoising_step(
        self,
        latents,
        prompt_embeds,
        guidance_scale,
        t,
        do_classifier_free_guidance,
        method,
        extra_step_kwargs,
    ):
        noise_pred_uncond, grads_a = self.get_noise(
            latents, prompt_embeds, guidance_scale, t, do_classifier_free_guidance
        )
        if method in ["dpm", "unipc"]:
            latents = self.scheduler.step(noise_pred_uncond + grads_a, t, latents, **extra_step_kwargs).prev_sample

        elif method in ["hb", "ghvb", "nt"]:
            latents = self.scheduler_uncond.step(noise_pred_uncond, t, latents, output_mode="scale")
            latents = self.scheduler_text.step(grads_a, t, latents, output_mode="back")
        else:
            raise NotImplementedError

        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        method="ghvb",
    ):
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        # print(timesteps)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latents = self.denoising_step(
                    latents,
                    prompt_embeds,
                    guidance_scale,
                    t,
                    do_classifier_free_guidance,
                    method,
                    extra_step_kwargs,
                )

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            # image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
            has_nsfw_concept = False

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            # image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
            has_nsfw_concept = False

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def generate(self, params):
        params["output_type"] = "latent"
        ori_latents = self.__call__(**params)["images"]

        with torch.no_grad():
            latents = torch.clone(ori_latents)
            image = self.decode_latents(latents)
            image = self.numpy_to_pil(image)[0]

        return image, ori_latents
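One subtlety worth noting: setup_scheduler folds order and beta into a single fractional "momentum number" (e.g. order 4 with beta 0.4 becomes 3.4), and PLMSWithHBScheduler.update_order splits it back apart. A sketch of the round trip, not part of the commit; split_momentum_number is a hypothetical helper that just mirrors the arithmetic in update_order:

from pipeline import get_momentum_number

def split_momentum_number(m):
    # same arithmetic as PLMSWithHBScheduler.update_order
    order = m // 1 + 1 if m % 1 > 0 else m // 1
    beta = m % 1 if m % 1 > 0 else 1
    return int(order), beta

m = get_momentum_number(order=4, beta=0.4)  # 4 - (1 - 0.4) = 3.4
print(m, split_momentum_number(m))          # 3.4 (4, ~0.4 up to float rounding)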
requirements.txt
ADDED
@@ -0,0 +1,96 @@
accelerate==0.21.0
aiofiles==23.1.0
aiohttp==3.8.4
aiosignal==1.3.1
altair==5.0.1
annotated-types==0.5.0
anyio==3.7.1
async-timeout==4.0.2
attrs==23.1.0
certifi==2023.5.7
charset-normalizer==3.2.0
click==8.1.5
cmake==3.26.4
contourpy==1.1.0
cycler==0.11.0
diffusers==0.15.0
exceptiongroup==1.1.2
fastapi==0.100.0
ffmpy==0.3.0
filelock==3.12.2
fonttools==4.41.0
frozenlist==1.4.0
fsspec==2023.6.0
gradio==3.36.1
gradio_client==0.2.9
h11==0.14.0
httpcore==0.17.3
httpx==0.24.1
huggingface-hub==0.16.4
idna==3.4
importlib-metadata==6.8.0
importlib-resources==6.0.0
Jinja2==3.1.2
jsonschema==4.18.3
jsonschema-specifications==2023.6.1
kiwisolver==1.4.4
linkify-it-py==2.0.2
lit==16.0.6
markdown-it-py==2.2.0
MarkupSafe==2.1.3
matplotlib==3.7.2
mdit-py-plugins==0.3.3
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.4
networkx==3.1
numpy==1.25.1
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
orjson==3.9.2
packaging==23.1
pandas==2.0.3
Pillow==10.0.0
psutil==5.9.5
pydantic==2.0.2
pydantic_core==2.1.2
pydub==0.25.1
Pygments==2.15.1
pyparsing==3.0.9
python-dateutil==2.8.2
python-multipart==0.0.6
pytz==2023.3
PyYAML==6.0
referencing==0.29.1
regex==2023.6.3
requests==2.31.0
rpds-py==0.8.10
semantic-version==2.10.0
six==1.16.0
sniffio==1.3.0
starlette==0.27.0
sympy==1.12
tokenizers==0.13.3
tomesd==0.1.3
toolz==0.12.0
torch==2.0.1
tqdm==4.65.0
transformers==4.28.1
triton==2.0.0
typing_extensions==4.7.1
tzdata==2023.3
uc-micro-py==1.0.2
urllib3==2.0.3
uvicorn==0.22.0
websockets==11.0.3
yarl==1.9.2
zipp==3.16.1