BILLY12138 committed on
Commit
1fb9388
1 Parent(s): b7362f5

Create tdd_scheduler.py

Files changed (1)
  1. tdd_scheduler.py +515 -0
tdd_scheduler.py ADDED
@@ -0,0 +1,515 @@
from diffusers import TCDScheduler, DPMSolverSinglestepScheduler
from diffusers.schedulers.scheduling_tcd import *
from diffusers.schedulers.scheduling_dpmsolver_singlestep import *

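# TDDScheduler subclasses DPMSolverSinglestepScheduler and replaces its timestep
# schedule with a TCD-style grid derived from `tdd_train_step`, plus an optional
# eta-controlled stochastic step (see `set_timesteps_s` and `step` below).
# Note: the wildcard imports above are relied on to provide torch, numpy (np),
# the typing helpers, and utilities such as register_to_config, deprecate, logger,
# randn_tensor, SchedulerOutput and betas_for_alpha_bar in this namespace.
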
class TDDScheduler(DPMSolverSinglestepScheduler):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        solver_order: int = 1,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = False,
        use_karras_sigmas: Optional[bool] = False,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
        tdd_train_step: int = 250,
        special_jump: bool = False,
        t_l: int = -1,
    ):
        self.t_l = t_l
        self.special_jump = special_jump
        self.tdd_train_step = tdd_train_step
        if algorithm_type == "dpmsolver":
            deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
            deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)

        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # settings for DPM-Solver
        if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
            if algorithm_type == "deis":
                self.register_to_config(algorithm_type="dpmsolver++")
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
        if solver_type not in ["midpoint", "heun"]:
            if solver_type in ["logrho", "bh1", "bh2"]:
                self.register_to_config(solver_type="midpoint")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
            raise ValueError(
                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
            )

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.sample = None
        self.order_list = self.get_order_list(num_train_timesteps)
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

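    # For reference, the quantities set up in __init__ above follow the VP
    # parameterization:
    #   alpha_t  = sqrt(alphas_cumprod[t])
    #   sigma_t  = sqrt(1 - alphas_cumprod[t])
    #   lambda_t = log(alpha_t) - log(sigma_t)
    # and the solver's sigma grid is sigmas[t] = sigma_t / alpha_t
    # = sqrt((1 - alphas_cumprod[t]) / alphas_cumprod[t]).
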
    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        self.num_inference_steps = num_inference_steps
        # Clipping the minimum of all lambda(t) for numerical stability.
        # This is critical for cosine (squaredcos_cap_v2) noise schedule.
        # original_steps = self.config.original_inference_steps
        original_steps = self.tdd_train_step
        k = 1000 / original_steps
        tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps) + 1))) * k - 1
        # (alternative grid: tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps)))))
        # TCD Inference Steps Schedule
        tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy()
        # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
        inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
        inference_indices = np.floor(inference_indices).astype(np.int64)
        timesteps = tcd_origin_timesteps[inference_indices]
        if self.special_jump:
            if self.tdd_train_step == 50:
                # timesteps = np.array([999., 879., 759., 499., 259.])
                print(timesteps)
            elif self.tdd_train_step == 250:
                if num_inference_steps == 5:
                    timesteps = np.array([999., 875., 751., 499., 251.])
                elif num_inference_steps == 6:
                    timesteps = np.array([999., 875., 751., 627., 499., 251.])
                elif num_inference_steps == 7:
                    timesteps = np.array([999., 875., 751., 627., 499., 375., 251.])

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.use_karras_sigmas:
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}"
            )
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas).to(device=device)

        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
        self.model_outputs = [None] * self.config.solver_order
        self.sample = None

        if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
            logger.warning(
                f"Changing scheduler {self.config} to have `lower_order_final` set to True to handle an uneven number of inference steps. Please make sure to use a `num_inference_steps` that is divisible by `solver_order` when using `lower_order_final=False`."
            )
            self.register_to_config(lower_order_final=True)

        if not self.config.lower_order_final and self.config.final_sigmas_type == "zero":
            logger.warning(
                f"`final_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True."
            )
            self.register_to_config(lower_order_final=True)

        self.order_list = self.get_order_list(num_inference_steps)

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

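    # Worked example of the schedule above (tdd_train_step=250, num_inference_steps=4,
    # special_jump disabled): k = 1000 / 250 = 4, so
    #   tcd_origin_timesteps = [3, 7, ..., 999]   -> reversed: [999, 995, ..., 3]
    #   inference_indices    = floor(linspace(0, 250, 4, endpoint=False)) = [0, 62, 125, 187]
    #   timesteps            = [999., 751., 499., 251.]
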
    def set_timesteps_s(self, eta: float = 0.0):
        # Clipping the minimum of all lambda(t) for numerical stability.
        # This is critical for cosine (squaredcos_cap_v2) noise schedule.
        num_inference_steps = self.num_inference_steps
        device = self.timesteps.device
        original_steps = self.tdd_train_step
        k = 1000 / original_steps
        tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps) + 1))) * k - 1
        # (alternative grid: tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps)))))
        # TCD Inference Steps Schedule
        tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy()
        # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
        inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
        inference_indices = np.floor(inference_indices).astype(np.int64)
        timesteps = tcd_origin_timesteps[inference_indices]
        if self.special_jump:
            if self.tdd_train_step == 50:
                timesteps = np.array([999., 879., 759., 499., 259.])
            elif self.tdd_train_step == 250:
                if num_inference_steps == 5:
                    timesteps = np.array([999., 875., 751., 499., 251.])
                elif num_inference_steps == 6:
                    timesteps = np.array([999., 875., 751., 627., 499., 251.])
                elif num_inference_steps == 7:
                    timesteps = np.array([999., 875., 751., 627., 499., 375., 251.])

        timesteps_s = np.floor((1 - eta) * timesteps).astype(np.int64)

        sigmas_s = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.use_karras_sigmas:
            raise NotImplementedError("`use_karras_sigmas` is not implemented for the shifted (eta) schedule")
        else:
            sigmas_s = np.interp(timesteps_s, np.arange(0, len(sigmas_s)), sigmas_s)

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}"
            )

        sigmas_s = np.concatenate([sigmas_s, [sigma_last]]).astype(np.float32)
        self.sigmas_s = torch.from_numpy(sigmas_s).to(device=device)
        self.timesteps_s = torch.from_numpy(timesteps_s).to(device=device, dtype=torch.int64)

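    # Worked example continuing the 4-step schedule above with eta = 0.3:
    #   timesteps   = [999, 751, 499, 251]
    #   timesteps_s = floor(0.7 * timesteps) = [699, 525, 349, 175]
    # Each deterministic solver update targets the shifted (less noisy) timestep
    # s_{i+1}, and `step` re-noises the result back up to t_{i+1} when eta > 0.
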
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        if self.step_index == 0:
            self.set_timesteps_s(eta)

        model_output = self.convert_model_output(model_output, sample=sample)
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        order = self.order_list[self.step_index]

        # For img2img, denoising might start with order>1, which is not possible.
        # In this case make sure that the first two steps are both order=1.
        while self.model_outputs[-order] is None:
            order -= 1

        # For single-step solvers, we use the initial value at each time with order = 1.
        if order == 1:
            self.sample = sample

        prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order)

        if eta > 0:
            if self.step_index != self.num_inference_steps - 1:
                alpha_prod_s = self.alphas_cumprod[self.timesteps_s[self.step_index + 1]]
                alpha_prod_t_prev = self.alphas_cumprod[self.timesteps[self.step_index + 1]]

                noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=prev_sample.dtype
                )
                prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * prev_sample + (
                    1 - alpha_prod_t_prev / alpha_prod_s
                ).sqrt() * noise

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

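    # Stochastic step (eta > 0), as implemented above: the DPM-Solver update targets
    # the shifted timestep s_{i+1} = floor((1 - eta) * t_{i+1}) through `sigmas_s`,
    # and the sample is then re-noised back to t_{i+1} with the forward kernel
    #   x_t = sqrt(abar_t / abar_s) * x_s + sqrt(1 - abar_t / abar_s) * noise,
    # where abar denotes alphas_cumprod. With eta = 0 the sampler is deterministic.
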
    def dpm_solver_first_order_update(
        self,
        model_output: torch.FloatTensor,
        *args,
        sample: torch.FloatTensor = None,
        **kwargs,
    ) -> torch.FloatTensor:
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        sigma_t, sigma_s = self.sigmas_s[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
        h = lambda_t - lambda_s
        if self.config.algorithm_type == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
        elif self.config.algorithm_type == "dpmsolver":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
        return x_t

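    # Note: the target sigma (sigma_t) above is read from self.sigmas_s, i.e. the
    # eta-shifted schedule, while the source sigma (sigma_s) comes from self.sigmas.
    # For "dpmsolver++" the update computes
    #   x_t = (sigma_t / sigma_s) * x - alpha_t * (exp(-h) - 1) * x0_pred,   h = lambda_t - lambda_s,
    # with x0_pred the data prediction produced by convert_model_output.
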
    def singlestep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.FloatTensor],
        *args,
        sample: torch.FloatTensor = None,
        **kwargs,
    ) -> torch.FloatTensor:
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas_s[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

        m0, m1 = model_output_list[-1], model_output_list[-2]

        h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m1, (1.0 / r0) * (m0 - m1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s1) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s1) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s1) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s1) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                )
        return x_t

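    # Here m0 is the most recent converted model output and m1 the one before it;
    # D1 = (m0 - m1) / r0, with r0 = h_0 / h, supplies the second-order correction
    # used in the updates above. As in the first-order update, the target sigma is
    # taken from the eta-shifted schedule self.sigmas_s.
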
    def singlestep_dpm_solver_update(
        self,
        model_output_list: List[torch.FloatTensor],
        *args,
        sample: torch.FloatTensor = None,
        order: int = None,
        **kwargs,
    ) -> torch.FloatTensor:
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError("missing `order` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if order == 1:
            return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
        elif order == 2:
            return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
        else:
            raise ValueError(f"Order must be 1 or 2, got {order}")

    def convert_model_output(
        self,
        model_output: torch.FloatTensor,
        *args,
        sample: torch.FloatTensor = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
        designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
        integral of the data prediction model.

        <Tip>

        The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
        prediction and data prediction models.

        </Tip>

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.FloatTensor`:
                The converted model output.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        # DPM-Solver++ needs to solve an integral of the data prediction model.
        if self.config.algorithm_type == "dpmsolver++":
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if self.config.variance_type in ["learned_range"]:
                    model_output = model_output[:, :3]
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = alpha_t * sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the DPMSolverSinglestepScheduler."
                )

            # thresholding is only applied for the first (t_l + 1) steps; the default t_l = -1 disables it
            if self.step_index <= self.t_l:
                if self.config.thresholding:
                    x0_pred = self._threshold_sample(x0_pred)

            return x0_pred
        # DPM-Solver needs to solve an integral of the noise prediction model.
        elif self.config.algorithm_type == "dpmsolver":
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if self.config.variance_type in ["learned_range"]:
                    model_output = model_output[:, :3]
                return model_output
            elif self.config.prediction_type == "sample":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = (sample - alpha_t * model_output) / sigma_t
                return epsilon
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = alpha_t * model_output + sigma_t * sample
                return epsilon
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the DPMSolverSinglestepScheduler."
                )
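
A minimal usage sketch (not part of the committed file): it assumes an SDXL base checkpoint and a TDD-distilled LoRA, both written below as placeholder identifiers, and relies on diffusers pipelines forwarding `eta` and `generator` to `scheduler.step` when the scheduler's signature accepts them.

import torch
from diffusers import StableDiffusionXLPipeline
from tdd_scheduler import TDDScheduler

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipe.scheduler = TDDScheduler.from_config(pipe.scheduler.config, tdd_train_step=250)
pipe.load_lora_weights("path/to/tdd-lora")  # placeholder: a TDD-distilled LoRA checkpoint

image = pipe(
    prompt="a photo of a cat wearing sunglasses",
    num_inference_steps=4,
    guidance_scale=1.0,  # distilled samplers are usually run with little or no CFG
    eta=0.2,             # eta > 0 enables the stochastic re-noising in TDDScheduler.step
).images[0]
image.save("tdd_sample.png")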