sincostanx committed on
Commit
2b4e6ab
1 Parent(s): 8421ea7

add prototype

Files changed (4)
  1. app.py +217 -4
  2. momentum_scheduler.py +385 -0
  3. pipeline.py +236 -0
  4. requirements.txt +96 -0
app.py CHANGED
@@ -1,7 +1,220 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- interface = gr.Interface(fn=greet, inputs="text", outputs="image")
- interface.launch()
  import gradio as gr
+ import torch
+ from pipeline import CustomPipeline, setup_scheduler
+ from PIL import Image
+ # from easydict import EasyDict as edict
+
+ original_pipe = None
+ original_config = None
+ device = None
+
+ # def run_dpm_demo(id, prompt, beta, num_inference_steps, guidance_scale, seed, enable_token_merging):
+ def run_dpm_demo(prompt, beta, num_inference_steps, guidance_scale, seed):
+     global original_pipe, original_config
+     pipe = CustomPipeline(**original_pipe.components)
+
+     seed = int(seed)
+     num_inference_steps = int(num_inference_steps)
+
+     scheduler = "DPM-Solver++"
+     params = {
+         "prompt": prompt,
+         "num_inference_steps": num_inference_steps,
+         "guidance_scale": guidance_scale,
+         "method": "dpm"
+     }
+
+     # without momentum (equivalent to DPM-Solver++)
+     pipe = setup_scheduler(pipe, scheduler, beta=1.0, original_config=original_config)
+     params["generator"] = torch.Generator(device=device).manual_seed(seed)
+     ori_image = pipe(**params).images[0]
+
+     # with momentum
+     pipe = setup_scheduler(pipe, scheduler, beta=beta, original_config=original_config)
+     params["generator"] = torch.Generator(device=device).manual_seed(seed)
+     image = pipe(**params).images[0]
+
+     ori_image.save("temp1.png")
+     image.save("temp2.png")
+
+     return [ori_image, image]
+
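+ # Note: each demo renders a baseline/momentum pair from the same seed.
+ # beta=1.0 in setup_scheduler disables momentum, so the first image is the
+ # plain sampler's output and the second differs only by the chosen beta.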
+ # def run_plms_demo(id, prompt, order, beta, momentum_type, num_inference_steps, guidance_scale, seed, enable_token_merging):
+ def run_plms_demo(prompt, order, beta, momentum_type, num_inference_steps, guidance_scale, seed):
+     global original_pipe, original_config
+     pipe = CustomPipeline(**original_pipe.components)
+
+     seed = int(seed)
+     num_inference_steps = int(num_inference_steps)
+
+     scheduler = "PLMS"
+     method = "hb" if momentum_type == "Polyak's heavy ball" else "nt"
+     params = {
+         "prompt": prompt,
+         "num_inference_steps": num_inference_steps,
+         "guidance_scale": guidance_scale,
+         "method": method
+     }
+
+     # without momentum (equivalent to PLMS)
+     pipe = setup_scheduler(pipe, scheduler, momentum_type=momentum_type, order=order, beta=1.0, original_config=original_config)
+     params["generator"] = torch.Generator(device=device).manual_seed(seed)
+     ori_image = pipe(**params).images[0]
+
+     # with momentum
+     pipe = setup_scheduler(pipe, scheduler, momentum_type=momentum_type, order=order, beta=beta, original_config=original_config)
+     params["generator"] = torch.Generator(device=device).manual_seed(seed)
+     image = pipe(**params).images[0]
+
+     return [ori_image, image]
+
+ # def run_ghvb_demo(id, prompt, order, beta, num_inference_steps, guidance_scale, seed, enable_token_merging):
+ def run_ghvb_demo(prompt, order, beta, num_inference_steps, guidance_scale, seed):
+     global original_pipe, original_config
+     pipe = CustomPipeline(**original_pipe.components)
+
+     seed = int(seed)
+     num_inference_steps = int(num_inference_steps)
+
+     scheduler = "GHVB"
+     params = {
+         "prompt": prompt,
+         "num_inference_steps": num_inference_steps,
+         "guidance_scale": guidance_scale,
+         "method": "ghvb"
+     }
+
+     # without momentum (equivalent to PLMS)
+     pipe = setup_scheduler(pipe, scheduler, order=order, beta=1.0, original_config=original_config)
+     params["generator"] = torch.Generator(device=device).manual_seed(seed)
+     ori_image = pipe(**params).images[0]
+
+     # with momentum
+     pipe = setup_scheduler(pipe, scheduler, order=order, beta=beta, original_config=original_config)
+     params["generator"] = torch.Generator(device=device).manual_seed(seed)
+     image = pipe(**params).images[0]
+
+     return [ori_image, image]
+
+ if __name__ == "__main__":
+
+     demo = gr.Blocks()
+
+     inputs = {}
+     outputs = {}
+     buttons = {}
+
+     list_models = [
+     ]
+
+     with gr.Blocks() as demo:
+         gr.Markdown(
+             """
+             # Momentum-Diffusion Demo
+
+             A momentum-based sampling method for diffusion models that reduces artifacts.
+
+             """
+         )
+         id = gr.Dropdown(list_models, label="Model ID", value="Linaqruf/anything-v3.0", allow_custom_value=True)
+         enable_token_merging = gr.Checkbox(label="Enable Token Merging", value=False)
+         # output = gr.Textbox()
+         buttons["select_model"] = gr.Button("Select")
+
+         with gr.Tab("GHVB", visible=False) as tab3:
+             prompt3 = gr.Textbox(label="Prompt", value="a cozy cafe", visible=False)
+
+             with gr.Row(visible=False) as row31:
+                 order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order")
+                 beta = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.05, label="beta")
+                 num_inference_steps = gr.Number(label="Number of steps", value=12)
+                 guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10)
+                 seed = gr.Number(label="Seed", value=42)
+
+             with gr.Row(visible=False) as row32:
+                 out1 = gr.Image(label="PLMS", interactive=False)
+                 out2 = gr.Image(label="GHVB", interactive=False)
+
+             inputs["GHVB"] = [prompt3, order, beta, num_inference_steps, guidance_scale, seed]
+             outputs["GHVB"] = [out1, out2]
+             buttons["GHVB"] = gr.Button("Sample", visible=False)
+
+         with gr.Tab("PLMS", visible=False) as tab2:
+             prompt2 = gr.Textbox(label="Prompt", value="1girl", visible=False)
+
+             with gr.Row(visible=False) as row21:
+                 order = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="order")
+                 beta = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.05, label="beta")
+                 momentum_type = gr.Dropdown(["Polyak's heavy ball", "Nesterov"], label="Momentum Type", value="Polyak's heavy ball")
+                 num_inference_steps = gr.Number(label="Number of steps", value=10)
+                 guidance_scale = gr.Number(label="Guidance scale (cfg)", value=10)
+                 seed = gr.Number(label="Seed", value=42)
+
+             with gr.Row(visible=False) as row22:
+                 out1 = gr.Image(label="Without momentum", interactive=False)
+                 out2 = gr.Image(label="With momentum", interactive=False)
+
+             inputs["PLMS"] = [prompt2, order, beta, momentum_type, num_inference_steps, guidance_scale, seed]
+             outputs["PLMS"] = [out1, out2]
+             buttons["PLMS"] = gr.Button("Sample", visible=False)
+
+         with gr.Tab("DPM-Solver++", visible=False) as tab1:
+             prompt1 = gr.Textbox(label="Prompt", value="1girl", visible=False)
+
+             with gr.Row(visible=False) as row11:
+                 beta = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.05, label="beta")
+                 num_inference_steps = gr.Number(label="Number of steps", value=15)
+                 guidance_scale = gr.Number(label="Guidance scale (cfg)", value=20)
+                 seed = gr.Number(label="Seed", value=0)
+
+             with gr.Row(visible=False) as row12:
+                 out1 = gr.Image(label="Without momentum", interactive=False)
+                 out2 = gr.Image(label="With momentum", interactive=False)
+
+             inputs["DPM-Solver++"] = [prompt1, beta, num_inference_steps, guidance_scale, seed]
+             outputs["DPM-Solver++"] = [out1, out2]
+             buttons["DPM-Solver++"] = gr.Button("Sample", visible=False)
+
+         def prepare_model(id, enable_token_merging):
+             global original_pipe, original_config, device
+
+             if original_pipe is not None:
+                 del original_pipe
+
+             original_pipe = CustomPipeline.from_pretrained(id)
+             # "cuda:1" pins the pipeline to a second GPU; use "cuda" on single-GPU machines
+             device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
+             original_pipe = original_pipe.to(device)
+
+             if enable_token_merging:
+                 import tomesd
+                 tomesd.apply_patch(original_pipe, ratio=0.5)
+                 print("Enabled Token merging.")
+
+             original_config = original_pipe.scheduler.config
+             print(type(original_pipe))
+             print(original_config)
+
+             return {
+                 row11: gr.update(visible=True),
+                 row12: gr.update(visible=True),
+                 row21: gr.update(visible=True),
+                 row22: gr.update(visible=True),
+                 row31: gr.update(visible=True),
+                 row32: gr.update(visible=True),
+                 prompt1: gr.update(visible=True),
+                 prompt2: gr.update(visible=True),
+                 prompt3: gr.update(visible=True),
+                 buttons["DPM-Solver++"]: gr.update(visible=True),
+                 buttons["PLMS"]: gr.update(visible=True),
+                 buttons["GHVB"]: gr.update(visible=True),
+             }
+
+         all_outputs = [row11, row12, row21, row22, row31, row32, prompt1, prompt2, prompt3, buttons["DPM-Solver++"], buttons["PLMS"], buttons["GHVB"]]
+         buttons["select_model"].click(prepare_model, inputs=[id, enable_token_merging], outputs=all_outputs)
+         buttons["DPM-Solver++"].click(run_dpm_demo, inputs=inputs["DPM-Solver++"], outputs=outputs["DPM-Solver++"])
+         buttons["PLMS"].click(run_plms_demo, inputs=inputs["PLMS"], outputs=outputs["PLMS"])
+         buttons["GHVB"].click(run_ghvb_demo, inputs=inputs["GHVB"], outputs=outputs["GHVB"])
+
+     demo.launch(share=True)
momentum_scheduler.py ADDED
@@ -0,0 +1,385 @@
+ import torch
+ from diffusers import DPMSolverMultistepScheduler, UniPCMultistepScheduler
+ from typing import List
+
+ def AdamBmixer(order, ets, b=1):
+     cur_order = min(order, len(ets))
+     if cur_order == 1:
+         prime = b * ets[-1]
+     elif cur_order == 2:
+         prime = ((2 + b) * ets[-1] - (2 - b) * ets[-2]) / 2
+     elif cur_order == 3:
+         prime = ((18 + 5*b) * ets[-1] - (24 - 8*b) * ets[-2] + (6 - b) * ets[-3]) / 12
+     elif cur_order == 4:
+         prime = ((46 + 9*b) * ets[-1] - (78 - 19*b) * ets[-2] + (42 - 5*b) * ets[-3] - (10 - b) * ets[-4]) / 24
+     elif cur_order == 5:
+         prime = ((1650 + 251*b) * ets[-1] - (3420 - 646*b) * ets[-2]
+                  + (2880 - 264*b) * ets[-3] - (1380 - 106*b) * ets[-4]
+                  + (270 - 19*b) * ets[-5]) / 720
+     else:
+         raise NotImplementedError
+
+     prime = prime / b
+     return prime
+
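+ # With b = 1 the branches above reduce to the classical Adams-Bashforth
+ # coefficients, e.g. cur_order == 4 gives
+ #     (55*ets[-1] - 59*ets[-2] + 37*ets[-3] - 9*ets[-4]) / 24,
+ # which is why beta = 1 recovers a plain linear-multistep (PLMS) update.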
+ class PLMSWithHBScheduler():
+     """
+     PLMS with Polyak's Heavy Ball momentum (HB) for diffusion ODEs.
+     We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers).
+
+     When order is an integer, this method is equivalent to PLMS without momentum.
+     """
+     def __init__(self, scheduler, order):
+         self.scheduler = scheduler
+         self.ets = []
+         self.update_order(order)
+         self.mixer = AdamBmixer
+
+     def update_order(self, order):
+         # a fractional order, e.g. 3.4, is split into integer order 4 and momentum beta 0.4
+         self.order = order // 1 + 1 if order % 1 > 0 else order // 1
+         self.beta = order % 1 if order % 1 > 0 else 1
+         self.vel = None
+
+     def clear(self):
+         self.ets = []
+         self.vel = None
+
+     def update_ets(self, val):
+         self.ets.append(val)
+         if len(self.ets) > self.order:
+             self.ets.pop(0)
+
+     def _step_with_momentum(self, grads):
+         self.update_ets(grads)
+         prime = self.mixer(self.order, self.ets, 1.0)
+         self.vel = (1 - self.beta) * self.vel + self.beta * prime
+         return self.vel
+
+     def step(
+         self,
+         grads: torch.FloatTensor,
+         timestep: int,
+         latents: torch.FloatTensor,
+         output_mode: str = "scale",
+     ):
+         if self.vel is None:
+             self.vel = grads
+
+         if hasattr(self.scheduler, 'sigmas'):
+             step_index = (self.scheduler.timesteps == timestep).nonzero().item()
+             sigma = self.scheduler.sigmas[step_index]
+             sigma_next = self.scheduler.sigmas[step_index + 1]
+             del_g = sigma_next - sigma
+
+             update_val = self._step_with_momentum(grads)
+             return latents + del_g * update_val
+
+         elif isinstance(self.scheduler, DPMSolverMultistepScheduler):
+             step_index = (self.scheduler.timesteps == timestep).nonzero().item()
+             current_timestep = self.scheduler.timesteps[step_index]
+             prev_timestep = 0 if step_index == len(self.scheduler.timesteps) - 1 else self.scheduler.timesteps[step_index + 1]
+
+             alpha_prod_t = self.scheduler.alphas_cumprod[current_timestep]
+             alpha_bar_prev = self.scheduler.alphas_cumprod[prev_timestep]
+
+             s0 = torch.sqrt(alpha_prod_t)
+             s_1 = torch.sqrt(alpha_bar_prev)
+             g0 = torch.sqrt(1 - alpha_prod_t) / s0
+             g_1 = torch.sqrt(1 - alpha_bar_prev) / s_1
+             del_g = g_1 - g0
+
+             update_val = self._step_with_momentum(grads)
+             if output_mode in ["scale"]:
+                 return (latents / s0 + del_g * update_val) * s_1
+             elif output_mode in ["back"]:
+                 return latents + del_g * update_val * s_1
+             elif output_mode in ["front"]:
+                 return latents + del_g * update_val * s0
+             else:
+                 return latents + del_g * update_val
+         else:
+             raise NotImplementedError
+
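+ # Usage sketch (mirrors CustomPipeline.init_scheduler in pipeline.py; the
+ # fractional order is illustrative): wrap an existing diffusers scheduler and
+ # drive it with the predicted noise at each step:
+ #     hb = PLMSWithHBScheduler(pipe.scheduler, order=3.7)  # order 4, beta 0.7
+ #     latents = hb.step(noise_pred, t, latents, output_mode="scale")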
+ class GHVBScheduler(PLMSWithHBScheduler):
+     """
+     Generalizing Polyak's Heavy Ball (GHVB) for diffusion ODEs.
+     We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers).
+
+     When order is an integer, this method is equivalent to PLMS without momentum.
+     """
+     def _step_with_momentum(self, grads):
+         # damp the raw gradients first, then mix with beta-corrected coefficients
+         self.vel = (1 - self.beta) * self.vel + self.beta * grads
+         self.update_ets(self.vel)
+         prime = self.mixer(self.order, self.ets, self.beta)
+         return prime
+
+ class PLMSWithNTScheduler(PLMSWithHBScheduler):
+     """
+     PLMS with Nesterov momentum (NT) for diffusion ODEs.
+     We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers).
+
+     When order is an integer, this method is equivalent to PLMS without momentum.
+     """
+     def _step_with_momentum(self, grads):
+         self.update_ets(grads)
+         prime = self.mixer(self.order, self.ets, 1.0)                # update v^{(2)}
+         self.vel = (1 - self.beta) * self.vel + self.beta * prime    # update v^{(1)}
+         update_val = (1 - self.beta) * self.vel + self.beta * prime  # update x
+         return update_val
+
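+ # The three variants differ only in where the damping sits: HB mixes the raw
+ # gradients with Adams-Bashforth coefficients and then damps the result; GHVB
+ # damps the gradients first and mixes with beta-corrected coefficients; NT
+ # applies a second (1 - beta)*vel + beta*prime blend on top of the velocity
+ # update, a Nesterov-style look-ahead.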
+ class MomentumDPMSolverMultistepScheduler(DPMSolverMultistepScheduler):
+     """
+     DPM-Solver++2M with HB momentum.
+     Currently supports only algorithm_type = "dpmsolver++" and solver_type = "midpoint".
+
+     When beta = 1.0, this method is equivalent to DPM-Solver++2M without momentum.
+     """
+     def initialize_momentum(self, beta):
+         self.vel = None
+         self.beta = beta
+
+     def multistep_dpm_solver_second_order_update(
+         self,
+         model_output_list: List[torch.FloatTensor],
+         timestep_list: List[int],
+         prev_timestep: int,
+         sample: torch.FloatTensor,
+     ) -> torch.FloatTensor:
+         t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
+         m0, m1 = model_output_list[-1], model_output_list[-2]
+         lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
+         alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+         sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+         h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
+         r0 = h_0 / h
+         D0, D1 = m0, (1.0 / r0) * (m0 - m1)
+         if self.config.algorithm_type == "dpmsolver++":
+             # See https://arxiv.org/abs/2211.01095 for detailed derivations
+             if self.config.solver_type == "midpoint":
+                 diff = D0 + 0.5 * D1
+
+                 # heavy-ball momentum on the multistep difference term
+                 if self.vel is None:
+                     self.vel = diff
+                 else:
+                     self.vel = (1 - self.beta) * self.vel + self.beta * diff
+
+                 x_t = (
+                     (sigma_t / sigma_s0) * sample
+                     - (alpha_t * (torch.exp(-h) - 1.0)) * self.vel
+                 )
+             elif self.config.solver_type == "heun":
+                 raise NotImplementedError(
+                     f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
+                 )
+         elif self.config.algorithm_type == "dpmsolver":
+             # See https://arxiv.org/abs/2206.00927 for detailed derivations
+             if self.config.solver_type == "midpoint":
+                 raise NotImplementedError(
+                     f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
+                 )
+             elif self.config.solver_type == "heun":
+                 raise NotImplementedError(
+                     f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
+                 )
+         return x_t
+
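+ # Usage sketch (this is what setup_scheduler in pipeline.py does):
+ #     sched = MomentumDPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+ #     sched.initialize_momentum(beta=0.5)
+ #     pipe.scheduler = sched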
+ class MomentumUniPCMultistepScheduler(UniPCMultistepScheduler):
+     """
+     UniPC with HB momentum.
+     Currently supports only self.predict_x0 = True.
+
+     When beta = 1.0, this method is equivalent to UniPC without momentum.
+     """
+     def initialize_momentum(self, beta):
+         self.vel_p = None
+         self.vel_c = None
+         self.beta = beta
+
+     def multistep_uni_p_bh_update(
+         self,
+         model_output: torch.FloatTensor,
+         prev_timestep: int,
+         sample: torch.FloatTensor,
+         order: int,
+     ) -> torch.FloatTensor:
+         timestep_list = self.timestep_list
+         model_output_list = self.model_outputs
+
+         s0, t = self.timestep_list[-1], prev_timestep
+         m0 = model_output_list[-1]
+         x = sample
+
+         if self.solver_p:
+             x_t = self.solver_p.step(model_output, s0, x).prev_sample
+             return x_t
+
+         lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
+         alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+         sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+
+         h = lambda_t - lambda_s0
+         device = sample.device
+
+         rks = []
+         D1s = []
+         for i in range(1, order):
+             si = timestep_list[-(i + 1)]
+             mi = model_output_list[-(i + 1)]
+             lambda_si = self.lambda_t[si]
+             rk = (lambda_si - lambda_s0) / h
+             rks.append(rk)
+             D1s.append((mi - m0) / rk)
+
+         rks.append(1.0)
+         rks = torch.tensor(rks, device=device)
+
+         R = []
+         b = []
+
+         hh = -h if self.predict_x0 else h
+         h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
+         h_phi_k = h_phi_1 / hh - 1
+
+         factorial_i = 1
+
+         if self.config.solver_type == "bh1":
+             B_h = hh
+         elif self.config.solver_type == "bh2":
+             B_h = torch.expm1(hh)
+         else:
+             raise NotImplementedError()
+
+         for i in range(1, order + 1):
+             R.append(torch.pow(rks, i - 1))
+             b.append(h_phi_k * factorial_i / B_h)
+             factorial_i *= i + 1
+             h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+         R = torch.stack(R)
+         b = torch.tensor(b, device=device)
+
+         if len(D1s) > 0:
+             D1s = torch.stack(D1s, dim=1)  # (B, K)
+             # for order 2, we use a simplified version
+             if order == 2:
+                 rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
+             else:
+                 rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+         else:
+             D1s = None
+
+         if self.predict_x0:
+             if D1s is not None:
+                 pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
+             else:
+                 pred_res = 0
+
+             # heavy-ball momentum on the predictor update direction
+             val = (h_phi_1 * m0 + B_h * pred_res) / sigma_t / h_phi_1
+             if self.vel_p is None:
+                 self.vel_p = val
+             else:
+                 self.vel_p = (1 - self.beta) * self.vel_p + self.beta * val
+
+             x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_p * h_phi_1)
+         else:
+             raise NotImplementedError
+
+         x_t = x_t.to(x.dtype)
+         return x_t
+
+     def multistep_uni_c_bh_update(
+         self,
+         this_model_output: torch.FloatTensor,
+         this_timestep: int,
+         last_sample: torch.FloatTensor,
+         this_sample: torch.FloatTensor,
+         order: int,
+     ) -> torch.FloatTensor:
+         timestep_list = self.timestep_list
+         model_output_list = self.model_outputs
+
+         s0, t = timestep_list[-1], this_timestep
+         m0 = model_output_list[-1]
+         x = last_sample
+         x_t = this_sample
+         model_t = this_model_output
+
+         lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
+         alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+         sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+
+         h = lambda_t - lambda_s0
+         device = this_sample.device
+
+         rks = []
+         D1s = []
+         for i in range(1, order):
+             si = timestep_list[-(i + 1)]
+             mi = model_output_list[-(i + 1)]
+             lambda_si = self.lambda_t[si]
+             rk = (lambda_si - lambda_s0) / h
+             rks.append(rk)
+             D1s.append((mi - m0) / rk)
+
+         rks.append(1.0)
+         rks = torch.tensor(rks, device=device)
+
+         R = []
+         b = []
+
+         hh = -h if self.predict_x0 else h
+         h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
+         h_phi_k = h_phi_1 / hh - 1
+
+         factorial_i = 1
+
+         if self.config.solver_type == "bh1":
+             B_h = hh
+         elif self.config.solver_type == "bh2":
+             B_h = torch.expm1(hh)
+         else:
+             raise NotImplementedError()
+
+         for i in range(1, order + 1):
+             R.append(torch.pow(rks, i - 1))
+             b.append(h_phi_k * factorial_i / B_h)
+             factorial_i *= i + 1
+             h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+         R = torch.stack(R)
+         b = torch.tensor(b, device=device)
+
+         if len(D1s) > 0:
+             D1s = torch.stack(D1s, dim=1)
+         else:
+             D1s = None
+
+         # for order 1, we use a simplified version
+         if order == 1:
+             rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
+         else:
+             rhos_c = torch.linalg.solve(R, b)
+
+         if self.predict_x0:
+             if D1s is not None:
+                 corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
+             else:
+                 corr_res = 0
+             D1_t = model_t - m0
+
+             # heavy-ball momentum on the corrector update direction
+             val = (h_phi_1 * m0 + B_h * (corr_res + rhos_c[-1] * D1_t)) / sigma_t / h_phi_1
+             if self.vel_c is None:
+                 self.vel_c = val
+             else:
+                 self.vel_c = (1 - self.beta) * self.vel_c + self.beta * val
+
+             x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_c * h_phi_1)
+         else:
+             raise NotImplementedError
+
+         x_t = x_t.to(x.dtype)
+         return x_t
pipeline.py ADDED
@@ -0,0 +1,236 @@
+ import torch
+ import math
+ import numpy as np
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, UniPCMultistepScheduler
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+ from typing import Union, Optional, List, Callable, Dict, Any, Tuple
+ from momentum_scheduler import (
+     GHVBScheduler,
+     PLMSWithHBScheduler,
+     PLMSWithNTScheduler,
+     MomentumDPMSolverMultistepScheduler,
+     MomentumUniPCMultistepScheduler,
+ )
+
+ available_solvers = {
+     "GHVB": GHVBScheduler,
+     "PLMS_HB": PLMSWithHBScheduler,
+     "PLMS_NT": PLMSWithNTScheduler,
+     "DPM-Solver++": MomentumDPMSolverMultistepScheduler,
+     "UniPC": MomentumUniPCMultistepScheduler,
+ }
+
+ def get_momentum_number(order, beta):
+     out = order if beta == 1.0 else order - (1 - beta)
+     return out
+
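+ # Example: order=4, beta=0.4 yields a momentum number of 3.4; update_order in
+ # momentum_scheduler.py then splits it back into integer order 4 with beta 0.4.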
+ def setup_scheduler(pipe, scheduler, momentum_type="Polyak's heavy ball", order=4.0, beta=1.0, original_config=None):
+     assert original_config is not None
+
+     if scheduler in ["DPM-Solver++", "UniPC"]:
+         if momentum_type in ["Nesterov"]:
+             raise NotImplementedError(f"{scheduler} w/ Nesterov is not implemented.")
+
+         pipe.scheduler = available_solvers[scheduler].from_config(original_config)
+         pipe.scheduler.initialize_momentum(beta=beta)
+
+     elif scheduler in ["PLMS"]:
+         momentum_number = get_momentum_number(order, beta)
+         method = "PLMS_HB" if momentum_type == "Polyak's heavy ball" else "PLMS_NT"
+         pipe.scheduler = DPMSolverMultistepScheduler.from_config(original_config)
+         pipe.init_scheduler(method=method, order=momentum_number)
+         pipe.clear_scheduler()
+
+     elif scheduler in ["GHVB"]:
+         momentum_number = get_momentum_number(order, beta)
+         pipe.scheduler = DPMSolverMultistepScheduler.from_config(original_config)
+         pipe.init_scheduler(method="GHVB", order=momentum_number)
+         pipe.clear_scheduler()
+
+     return pipe
+
+ class CustomPipeline(StableDiffusionPipeline):
+     def clear_scheduler(self):
+         self.scheduler_uncond.clear()
+         self.scheduler_text.clear()
+
+     def init_scheduler(self, method, order):
+         # equivalent to not applying numerical operator splitting, since both orders are the same
+         self.scheduler_uncond = available_solvers[method](self.scheduler, order)
+         self.scheduler_text = available_solvers[method](self.scheduler, order)
+
+     def get_noise(self, latents, prompt_embeds, guidance_scale, t, do_classifier_free_guidance):
+         # expand the latents if we are doing classifier free guidance
+         latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+         latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+         # predict the noise residual
+         noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+
+         # assumes classifier-free guidance (guidance_scale > 1); the demo always satisfies this
+         if do_classifier_free_guidance:
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             grads_a = guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+         return noise_pred_uncond, grads_a
+
+     def denoising_step(
+         self,
+         latents,
+         prompt_embeds,
+         guidance_scale,
+         t,
+         do_classifier_free_guidance,
+         method,
+         extra_step_kwargs,
+     ):
+         noise_pred_uncond, grads_a = self.get_noise(
+             latents, prompt_embeds, guidance_scale, t, do_classifier_free_guidance
+         )
+         if method in ["dpm", "unipc"]:
+             latents = self.scheduler.step(noise_pred_uncond + grads_a, t, latents, **extra_step_kwargs).prev_sample
+
+         elif method in ["hb", "ghvb", "nt"]:
+             latents = self.scheduler_uncond.step(noise_pred_uncond, t, latents, output_mode="scale")
+             latents = self.scheduler_text.step(grads_a, t, latents, output_mode="back")
+         else:
+             raise NotImplementedError
+
+         return latents
+
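+     # For the momentum methods, the unconditional noise and the guidance term are
+     # integrated by two separate wrapper schedulers ("scale" then "back" output
+     # modes), so each keeps its own velocity buffer and gradient history.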
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 7.5,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: float = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         callback_steps: int = 1,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         method="ghvb",
+     ):
+         # 0. Default height and width to unet
+         height = height or self.unet.config.sample_size * self.vae_scale_factor
+         width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+         )
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+         # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+
+         # 3. Encode input prompt
+         prompt_embeds = self._encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             do_classifier_free_guidance,
+             negative_prompt,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+         )
+
+         # 4. Prepare timesteps
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+         # print(timesteps)
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.unet.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+         # 7. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 latents = self.denoising_step(
+                     latents,
+                     prompt_embeds,
+                     guidance_scale,
+                     t,
+                     do_classifier_free_guidance,
+                     method,
+                     extra_step_kwargs,
+                 )
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+                     if callback is not None and i % callback_steps == 0:
+                         callback(i, t, latents)
+
+         if output_type == "latent":
+             image = latents
+             has_nsfw_concept = None
+         elif output_type == "pil":
+             # 8. Post-processing
+             image = self.decode_latents(latents)
+
+             # 9. Run safety checker
+             # image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+             has_nsfw_concept = False
+
+             # 10. Convert to PIL
+             image = self.numpy_to_pil(image)
+         else:
+             # 8. Post-processing
+             image = self.decode_latents(latents)
+
+             # 9. Run safety checker
+             # image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+             has_nsfw_concept = False
+
+         # Offload last model to CPU
+         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+             self.final_offload_hook.offload()
+
+         if not return_dict:
+             return (image, has_nsfw_concept)
+
+         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+     def generate(self, params):
+         params["output_type"] = "latent"
+         ori_latents = self.__call__(**params)["images"]
+
+         with torch.no_grad():
+             latents = torch.clone(ori_latents)
+             image = self.decode_latents(latents)
+             image = self.numpy_to_pil(image)[0]
+
+         return image, ori_latents
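+ # End-to-end usage sketch (assumes a CUDA device; model id and settings are the
+ # GHVB defaults from app.py):
+ #     pipe = CustomPipeline.from_pretrained("Linaqruf/anything-v3.0").to("cuda")
+ #     pipe = setup_scheduler(pipe, "GHVB", order=4, beta=0.4,
+ #                            original_config=pipe.scheduler.config)
+ #     image = pipe(prompt="a cozy cafe", num_inference_steps=12,
+ #                  guidance_scale=10, method="ghvb").images[0]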
requirements.txt ADDED
@@ -0,0 +1,96 @@
+ accelerate==0.21.0
+ aiofiles==23.1.0
+ aiohttp==3.8.4
+ aiosignal==1.3.1
+ altair==5.0.1
+ annotated-types==0.5.0
+ anyio==3.7.1
+ async-timeout==4.0.2
+ attrs==23.1.0
+ certifi==2023.5.7
+ charset-normalizer==3.2.0
+ click==8.1.5
+ cmake==3.26.4
+ contourpy==1.1.0
+ cycler==0.11.0
+ diffusers==0.15.0
+ exceptiongroup==1.1.2
+ fastapi==0.100.0
+ ffmpy==0.3.0
+ filelock==3.12.2
+ fonttools==4.41.0
+ frozenlist==1.4.0
+ fsspec==2023.6.0
+ gradio==3.36.1
+ gradio_client==0.2.9
+ h11==0.14.0
+ httpcore==0.17.3
+ httpx==0.24.1
+ huggingface-hub==0.16.4
+ idna==3.4
+ importlib-metadata==6.8.0
+ importlib-resources==6.0.0
+ Jinja2==3.1.2
+ jsonschema==4.18.3
+ jsonschema-specifications==2023.6.1
+ kiwisolver==1.4.4
+ linkify-it-py==2.0.2
+ lit==16.0.6
+ markdown-it-py==2.2.0
+ MarkupSafe==2.1.3
+ matplotlib==3.7.2
+ mdit-py-plugins==0.3.3
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.0.4
+ networkx==3.1
+ numpy==1.25.1
+ nvidia-cublas-cu11==11.10.3.66
+ nvidia-cuda-cupti-cu11==11.7.101
+ nvidia-cuda-nvrtc-cu11==11.7.99
+ nvidia-cuda-runtime-cu11==11.7.99
+ nvidia-cudnn-cu11==8.5.0.96
+ nvidia-cufft-cu11==10.9.0.58
+ nvidia-curand-cu11==10.2.10.91
+ nvidia-cusolver-cu11==11.4.0.1
+ nvidia-cusparse-cu11==11.7.4.91
+ nvidia-nccl-cu11==2.14.3
+ nvidia-nvtx-cu11==11.7.91
+ orjson==3.9.2
+ packaging==23.1
+ pandas==2.0.3
+ Pillow==10.0.0
+ psutil==5.9.5
+ pydantic==2.0.2
+ pydantic_core==2.1.2
+ pydub==0.25.1
+ Pygments==2.15.1
+ pyparsing==3.0.9
+ python-dateutil==2.8.2
+ python-multipart==0.0.6
+ pytz==2023.3
+ PyYAML==6.0
+ referencing==0.29.1
+ regex==2023.6.3
+ requests==2.31.0
+ rpds-py==0.8.10
+ semantic-version==2.10.0
+ six==1.16.0
+ sniffio==1.3.0
+ starlette==0.27.0
+ sympy==1.12
+ tokenizers==0.13.3
+ tomesd==0.1.3
+ toolz==0.12.0
+ torch==2.0.1
+ tqdm==4.65.0
+ transformers==4.28.1
+ triton==2.0.0
+ typing_extensions==4.7.1
+ tzdata==2023.3
+ uc-micro-py==1.0.2
+ urllib3==2.0.3
+ uvicorn==0.22.0
+ websockets==11.0.3
+ yarl==1.9.2
+ zipp==3.16.1