File size: 27,974 Bytes
df4ab84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 |
import torch
import os
import PIL
from typing import List, Optional, Union
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from PIL import Image
from diffusers.utils import logging
# Name of the sub-folder (under the scheduler's ``folder_name``) into which
# ``step_use_latents`` writes its per-image vector-breakdown dump.
VECTOR_DATA_FOLDER = "vector_data"
# Used twice by ``step_use_latents``: as the attribute name probed with
# ``hasattr`` (so it must match the ``self.vector_data`` attribute it creates)
# and as the stem of the saved ``.pt`` file.
VECTOR_DATA_DICT = "vector_data"
# Module-level logger from diffusers' logging utilities.
logger = logging.get_logger(__name__)
def get_ddpm_inversion_scheduler(
    scheduler,
    step_function,
    config,
    timesteps,
    save_timesteps,
    latents,
    x_ts,
    x_ts_c_hat,
    save_intermediate_results,
    pipe,
    x_0,
    v1s_images,
    v2s_images,
    deltas_images,
    v1_x0s,
    v2_x0s,
    deltas_x0s,
    folder_name,
    image_name,
    time_measure_n,
):
    """Patch ``scheduler`` in place so one ``step`` call runs inversion and editing together.

    The installed ``step`` splits the batch along dim 0. Row 0 goes through
    :func:`step_save_latents`, which appends to ``latents`` and ``x_ts_c_hat``.
    Rows ``1:`` go through :func:`step_use_latents`, which reads them back.
    The two results are concatenated along dim 0.

    Args:
        scheduler: scheduler instance to patch; it is mutated and returned.
        step_function: single-step update called by both halves with
            ``scheduler=`` set (e.g. one of the ``deterministic_*_step`` helpers).
        config: read here for ``noise_shift_delta``, ``noise_timesteps`` and
            ``clean_step_timestep``; the step helpers read further fields.
        timesteps: timesteps of the editing branch.
        save_timesteps: timesteps of the inversion branch; ``timesteps`` is
            used when this is falsy.
        latents, x_ts, x_ts_c_hat: lists shared with the step helpers.
        save_intermediate_results: when true, the editing half decodes and
            appends intermediate images to ``v1s_images`` ... ``deltas_x0s``.
        pipe: pipeline whose ``scheduler`` builds ``scheduler.x_0s`` and whose
            VAE decodes intermediate images.
        x_0: clean latent used to build ``scheduler.x_0s`` (no noise added).
        folder_name, image_name: output location for vector-data dumps.
        time_measure_n: stored on the scheduler; read by ``step_use_latents``.

    Returns:
        The same ``scheduler`` object, patched.
    """

    def step(
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ):
        # Row 0: inversion branch (records the noise maps).
        res_inv = step_save_latents(
            scheduler,
            model_output[:1, :, :, :],
            timestep,
            sample[:1, :, :, :],
            eta,
            use_clipped_model_output,
            generator,
            variance_noise,
            return_dict,
        )
        # Remaining rows: editing branch (consumes the noise maps).
        res_inf = step_use_latents(
            scheduler,
            model_output[1:, :, :, :],
            timestep,
            sample[1:, :, :, :],
            eta,
            use_clipped_model_output,
            generator,
            variance_noise,
            return_dict,
        )
        # Each helper returns a 1-tuple or a DDIMSchedulerOutput; both expose
        # the previous sample at index 0.
        prev_sample = torch.cat((res_inv[0], res_inf[0]), dim=0)
        if not return_dict:
            return (prev_sample,)
        # Honour ``return_dict`` like a regular scheduler step. The output
        # object still supports ``[0]`` indexing, so index-based callers keep working.
        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=None)

    scheduler.step_function = step_function
    scheduler.is_save = True
    scheduler._timesteps = timesteps
    scheduler._save_timesteps = save_timesteps if save_timesteps else timesteps
    scheduler._config = config
    scheduler.latents = latents
    scheduler.x_ts = x_ts
    scheduler.x_ts_c_hat = x_ts_c_hat
    scheduler.step = step
    scheduler.save_intermediate_results = save_intermediate_results
    scheduler.pipe = pipe
    scheduler.v1s_images = v1s_images
    scheduler.v2s_images = v2s_images
    scheduler.deltas_images = deltas_images
    scheduler.v1_x0s = v1_x0s
    scheduler.v2_x0s = v2_x0s
    scheduler.deltas_x0s = deltas_x0s
    scheduler.clean_step_run = False
    # Clean (noise-free) counterparts of each step, used for visualisation
    # and for the vector-data dump.
    scheduler.x_0s = create_xts(
        config.noise_shift_delta,
        config.noise_timesteps,
        config.clean_step_timestep,
        None,
        pipe.scheduler,
        timesteps,
        x_0,
        no_add_noise=True,
    )
    scheduler.folder_name = folder_name
    scheduler.image_name = image_name
    scheduler.p_to_p = False
    scheduler.p_to_p_replace = False
    scheduler.time_measure_n = time_measure_n
    return scheduler
def step_save_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
):
    """Inversion step: record the residual that maps the model's prediction onto ``self.x_ts``.

    Runs ``self.step_function`` on ``sample`` to get the one-step prediction
    ``u_hat_t``, appends it to ``self.x_ts_c_hat``, and appends the raw
    residual ``z_t = x_ts[next] - u_hat_t`` to ``self.latents``.

    The returned previous sample is ``u_hat_t + normalize(z_t)``. Up to
    floating-point rounding, this equals ``x_ts[next]`` unless
    ``self._config.max_norm_zs`` clipped the residual.

    Returns:
        ``(prev_sample,)`` when ``return_dict`` is false, otherwise a
        ``DDIMSchedulerOutput`` carrying the same ``prev_sample``.
    """
    if self.clean_step_run:
        # The extra clean step uses the last recorded entries.
        timestep_index = -1
        next_timestep_index = -1
    else:
        timestep_index = self._save_timesteps.index(timestep)
        next_timestep_index = timestep_index + 1

    u_hat_t = self.step_function(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        eta=eta,
        use_clipped_model_output=use_clipped_model_output,
        generator=generator,
        variance_noise=variance_noise,
        return_dict=False,
        scheduler=self,
    )
    x_t_minus_1 = self.x_ts[next_timestep_index]
    self.x_ts_c_hat.append(u_hat_t)

    z_t = x_t_minus_1 - u_hat_t
    # The raw (un-normalised) residual is stored; normalisation below only
    # affects the sample returned from this step.
    self.latents.append(z_t)
    z_t, _ = normalize(z_t, timestep_index, self._config.max_norm_zs)

    x_t_minus_1_predicted = u_hat_t + z_t
    if not return_dict:
        return (x_t_minus_1_predicted,)
    # Return the same tensor as the tuple path, so both call styles advance
    # the inversion identically.
    return DDIMSchedulerOutput(
        prev_sample=x_t_minus_1_predicted, pred_original_sample=None
    )
def step_use_latents(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
):
    """Editing step: denoise the edit rows and steer them with data recorded during inversion.

    The previous sample is
    ``coeff * x_ts[next] + ws1[i] * v1 + ws2[i] * v2``. Here ``coeff`` is the
    normalisation coefficient of the recorded residual. The ``v1``/``v2``
    split depends on ``self._config.breakdown``:

    * ``"x_t_hat_c"`` / ``"x_t_hat_c_with_zeros"``: the first block of rows'
      prediction is copied onto the edit rows (``x_t_hat_c``). Then
      ``v1 = x_t_hat_c_hat - x_t_hat_c`` and ``v2 = x_t_hat_c - coeff * x_t_c``.
    * ``"x_t_c_hat"``: raises ``NotImplementedError``.
    * anything else: ``v1 = x_t_hat_c_hat - coeff * x_t_c`` and ``v2 = 0``.

    Batch layout, inferred from the index arithmetic (confirm against the
    calling pipeline):
    ``[reconstruction row?][n rows][n edit rows][n empty-prompt rows?]``.
    The reconstruction row is absent when ``self.time_measure_n`` is truthy.
    The empty-prompt rows exist only for ``"x_t_hat_c_with_zeros"`` without
    ``p_to_p``.

    Side effects:

    * For ``"x_t_hat_c_with_zeros"`` without ``p_to_p``, per-timestep tensors
      are collected in ``self.vector_data``. At the last timestep they are
      saved to ``<folder_name>/vector_data/<image_name>/vector_data.pt``.
    * With ``save_intermediate_results`` (and not ``p_to_p``), decoded images
      are appended to the ``v1s_images`` ... ``deltas_x0s`` lists.

    Returns:
        ``(prev_sample,)`` when ``return_dict`` is false, otherwise a
        ``DDIMSchedulerOutput``.

    Raises:
        NotImplementedError: for the ``"x_t_c_hat"`` breakdown.
    """
    breakdown = self._config.breakdown
    if self.clean_step_run:
        timestep_index = -1
        next_timestep_index = -1
    else:
        timestep_index = self._timesteps.index(timestep)
        next_timestep_index = timestep_index + 1

    # Residual recorded by the inversion branch (+1 because latents[0] is X_T).
    # Only its normalisation coefficient is needed here.
    z_t = self.latents[next_timestep_index]
    _, normalize_coefficient = normalize(
        z_t[0] if breakdown == "x_t_hat_c_with_zeros" else z_t,
        timestep_index,
        self._config.max_norm_zs,
    )
    if normalize_coefficient == 0:
        # The recorded noise is fully suppressed; run the model step with eta=0.
        eta = 0

    x_t_hat_c_hat = self.step_function(
        model_output=model_output,
        timestep=timestep,
        sample=sample,
        eta=eta,
        use_clipped_model_output=use_clipped_model_output,
        generator=generator,
        variance_noise=variance_noise,
        return_dict=False,
        scheduler=self,
    )

    w1 = self._config.ws1[timestep_index]
    w2 = self._config.ws2[timestep_index]
    x_t_minus_1_exact = self.x_ts[next_timestep_index].expand_as(x_t_hat_c_hat)
    x_t_c_hat: torch.Tensor = self.x_ts_c_hat[next_timestep_index]
    if breakdown == "x_t_c_hat":
        raise NotImplementedError("breakdown x_t_c_hat not implemented yet")
    x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat)

    splits_prediction = breakdown in ("x_t_hat_c", "x_t_hat_c_with_zeros")
    records_vectors = breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p

    if splits_prediction:
        batch_size = model_output.size(0)
        # Offset of the first edit-related row: 1 when row 0 is the
        # reconstruction row, 0 otherwise.
        zero_index_reconstruction = 1 if not self.time_measure_n else 0
        rows_per_group = 3 if records_vectors else 2
        edit_prompts_num = (batch_size - zero_index_reconstruction) // rows_per_group
        x_t_hat_c_indices = (
            zero_index_reconstruction,
            edit_prompts_num + zero_index_reconstruction,
        )
        edit_images_indices = (
            edit_prompts_num + zero_index_reconstruction,
            (
                batch_size
                if breakdown == "x_t_hat_c"
                else zero_index_reconstruction + 2 * edit_prompts_num
            ),
        )
        x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
        x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[
            x_t_hat_c_indices[0] : x_t_hat_c_indices[1]
        ]
        v1 = x_t_hat_c_hat - x_t_hat_c
        v2 = x_t_hat_c - normalize_coefficient * x_t_c

        if records_vectors:
            path = os.path.join(
                self.folder_name,
                VECTOR_DATA_FOLDER,
                self.image_name,
            )
            if not hasattr(self, VECTOR_DATA_DICT):
                os.makedirs(path, exist_ok=True)
                self.vector_data = dict()
            x_t_0 = x_t_c_hat[1]
            # The empty-prompt rows follow the edit rows. Offset them by the
            # same reconstruction-row shift as the other index ranges (a
            # hard-coded 1 misaligned them when time_measure_n is set).
            empty_prompt_indices = (
                zero_index_reconstruction + 2 * edit_prompts_num,
                zero_index_reconstruction + 3 * edit_prompts_num,
            )
            x_t_hat_0 = x_t_hat_c_hat[empty_prompt_indices[0] : empty_prompt_indices[1]]
            step_data = dict()
            self.vector_data[timestep.item()] = step_data
            step_data["x_t_hat_c"] = x_t_hat_c[
                edit_images_indices[0] : edit_images_indices[1]
            ]
            step_data["x_t_hat_0"] = x_t_hat_0
            step_data["x_t_c"] = x_t_c[0].expand_as(x_t_hat_0)
            step_data["x_t_0"] = x_t_0.expand_as(x_t_hat_0)
            step_data["x_t_hat_c_hat"] = x_t_hat_c_hat[
                edit_images_indices[0] : edit_images_indices[1]
            ]
            step_data["x_t_minus_1_noisy"] = x_t_minus_1_exact[0].expand_as(x_t_hat_0)
            step_data["x_t_minus_1_clean"] = self.x_0s[next_timestep_index].expand_as(
                x_t_hat_0
            )
    else:  # no breakdown
        v1 = x_t_hat_c_hat - normalize_coefficient * x_t_c
        v2 = 0

    if self.save_intermediate_results and not self.p_to_p:
        x_0_next = self.x_0s[next_timestep_index]
        delta = v1 + v2
        v1_plus_x0 = x_0_next + v1
        v2_plus_x0 = x_0_next + v2
        delta_plus_x0 = x_0_next + delta
        has_v2 = breakdown != "no_breakdown"
        self.v1s_images.append(decode_latents(v1, self.pipe))
        self.v2s_images.append(
            decode_latents(v2, self.pipe) if has_v2 else [PIL.Image.new("RGB", (1, 1))]
        )
        self.deltas_images.append(decode_latents(delta, self.pipe))
        self.v1_x0s.append(decode_latents(v1_plus_x0, self.pipe))
        self.v2_x0s.append(
            decode_latents(v2_plus_x0, self.pipe)
            if has_v2
            else [PIL.Image.new("RGB", (1, 1))]
        )
        self.deltas_x0s.append(decode_latents(delta_plus_x0, self.pipe))

    x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2

    if splits_prediction:
        # Keep the first block of rows in sync with the edited rows.
        x_t_minus_1[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1[
            edit_images_indices[0] : edit_images_indices[1]
        ]
        if records_vectors:
            x_t_minus_1[empty_prompt_indices[0] : empty_prompt_indices[1]] = (
                x_t_minus_1[edit_images_indices[0] : edit_images_indices[1]]
            )
            step_data["x_t_minus_1_edited"] = x_t_minus_1[
                edit_images_indices[0] : edit_images_indices[1]
            ]
            if timestep == self._timesteps[-1]:
                torch.save(
                    self.vector_data,
                    os.path.join(path, f"{VECTOR_DATA_DICT}.pt"),
                )

    # p_to_p_force_perfect_reconstruction: row 0 follows the exact trajectory.
    if not self.time_measure_n:
        x_t_minus_1[0] = x_t_minus_1_exact[0]

    if not return_dict:
        return (x_t_minus_1,)
    return DDIMSchedulerOutput(
        prev_sample=x_t_minus_1,
        pred_original_sample=None,
    )
def create_xts(
    noise_shift_delta,
    noise_timesteps,
    clean_step_timestep,
    generator,
    scheduler,
    timesteps,
    x_0,
    no_add_noise=False,
):
    """Build the list of (optionally noised) latents that serve as per-step targets.

    Args:
        noise_shift_delta: used only when ``noise_timesteps`` is None. Every
            timestep is then shifted down by
            ``int(noise_shift_delta * (timesteps[0] - timesteps[1]))``.
        noise_timesteps: explicit noise levels, or None. Entries from the
            first non-positive one onward are dropped.
        clean_step_timestep: if > 0, one extra copy of ``x_0`` is appended.
        generator: torch generator for the noise. Noise is sampled on CPU and
            then moved to ``x_0.device``.
        scheduler: object providing ``add_noise(original, noise, timesteps)``.
        timesteps: schedule timesteps; their count sets how many ``x_0``
            copies pad the list.
        x_0: clean latent, expanded along dim 0. Presumably of shape
            ``(1, C, H, W)``; confirm with callers.
        no_add_noise: if true, zero noise is used instead of Gaussian noise.

    Returns:
        A list of tensors, in order:

        * one noised latent (leading dim 1) per kept noise timestep;
        * ``x_0`` repeated ``len(timesteps) - kept`` times;
        * one more ``x_0``;
        * an extra ``x_0`` when ``clean_step_timestep > 0``.
    """
    if noise_timesteps is None:
        noising_delta = noise_shift_delta * (timesteps[0] - timesteps[1])
        noise_timesteps = [timestep - int(noising_delta) for timestep in timesteps]

    # From the first non-positive noise level on, the clean latent is used.
    first_x_0_idx = next(
        (i for i, t in enumerate(noise_timesteps) if t <= 0),
        len(noise_timesteps),
    )
    noise_timesteps = noise_timesteps[:first_x_0_idx]

    x_0_expanded = x_0.expand(len(noise_timesteps), -1, -1, -1)
    if no_add_noise:
        noise = torch.zeros_like(x_0_expanded)
    else:
        # Sample on CPU so a CPU generator gives reproducible noise on any device.
        noise = torch.randn(x_0_expanded.size(), generator=generator, device="cpu").to(
            x_0.device
        )
    noised = scheduler.add_noise(
        x_0_expanded,
        noise,
        torch.IntTensor(noise_timesteps),
    )

    x_ts = [row.unsqueeze(dim=0) for row in noised]
    x_ts += [x_0] * (len(timesteps) - first_x_0_idx)
    x_ts += [x_0]
    if clean_step_timestep > 0:
        x_ts += [x_0]
    return x_ts
def normalize(
    z_t,
    i,
    max_norm_zs,
):
    """Clip the L2 norm of ``z_t`` to ``max_norm_zs[i]``.

    Args:
        z_t: tensor to clip.
        i: index into ``max_norm_zs``.
        max_norm_zs: per-step norm limits; a negative entry disables clipping.

    Returns:
        A pair ``(z, coeff)``:

        * ``z_t`` unchanged and ``coeff == 1`` when clipping is disabled or
          the norm is below the limit;
        * otherwise ``z_t * coeff`` with ``coeff = max_norm / ||z_t||``.

        A zero limit always yields ``coeff == 0``, including for an all-zero
        ``z_t``, which previously produced NaN via ``0 / 0``.
    """
    max_norm = max_norm_zs[i]
    if max_norm < 0:
        return z_t, 1
    norm = torch.norm(z_t)
    if norm < max_norm:
        return z_t, 1
    if norm == 0:
        # Only reachable when max_norm == 0: avoid 0/0 = NaN and report the
        # same "fully suppressed" coefficient a non-zero z_t would get.
        return z_t, 0
    coeff = max_norm / norm
    z_t = z_t * coeff
    return z_t, coeff
def decode_latents(latent, pipe):
    """Decode VAE latents into PIL images with ``pipe``'s VAE and image processor.

    The latent is divided by the VAE's ``scaling_factor`` before decoding.
    The first element of the VAE's tuple output is post-processed with
    ``output_type="pil"``.
    """
    vae = pipe.vae
    unscaled = latent / vae.config.scaling_factor
    decoded = vae.decode(unscaled, return_dict=False)
    image_tensor = decoded[0]
    return pipe.image_processor.postprocess(image_tensor, output_type="pil")
def deterministic_ddim_step(
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    scheduler=None,
):
    """DDIM update (formula (12), arXiv:2010.02502) without the random-noise term.

    ``eta`` still shrinks the direction term through ``sigma_t``, but no noise
    is ever added. ``generator``, ``variance_noise`` and ``return_dict`` are
    accepted for signature compatibility and ignored. The previous-sample
    tensor is returned directly.

    Raises:
        ValueError: if ``scheduler.num_inference_steps`` is None, or the
            prediction type is not ``epsilon``/``sample``/``v_prediction``.
    """
    if scheduler.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )
    stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    prev_timestep = timestep - stride

    # alphas / betas for the current and previous timestep
    a_t = scheduler.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        a_prev = scheduler.alphas_cumprod[prev_timestep]
    else:
        a_prev = scheduler.final_alpha_cumprod
    b_t = 1 - a_t

    kind = scheduler.config.prediction_type
    if kind == "epsilon":
        x0_hat = (sample - b_t ** (0.5) * model_output) / a_t ** (0.5)
        eps_hat = model_output
    elif kind == "sample":
        x0_hat = model_output
        eps_hat = (sample - a_t ** (0.5) * x0_hat) / b_t ** (0.5)
    elif kind == "v_prediction":
        x0_hat = (a_t**0.5) * sample - (b_t**0.5) * model_output
        eps_hat = (a_t**0.5) * model_output + (b_t**0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # Clip or threshold the predicted x_0.
    if scheduler.config.thresholding:
        x0_hat = scheduler._threshold_sample(x0_hat)
    elif scheduler.config.clip_sample:
        limit = scheduler.config.clip_sample_range
        x0_hat = x0_hat.clamp(-limit, limit)

    # sigma_t(eta), formula (16)
    variance = scheduler._get_variance(timestep, prev_timestep)
    sigma_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # Re-derive epsilon from the (possibly clipped) x_0, as in Glide.
        eps_hat = (sample - a_t ** (0.5) * x0_hat) / b_t ** (0.5)

    direction = (1 - a_prev - sigma_t**2) ** (0.5) * eps_hat
    return a_prev ** (0.5) * x0_hat + direction
def deterministic_euler_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta,
    use_clipped_model_output,
    generator,
    variance_noise,
    return_dict,
    scheduler,
):
    """Euler-ancestral update with the ancestral noise term removed.

    The sample moves along the ODE derivative from ``sigma`` down to the
    ancestral ``sigma_down``; ``sigma_up`` noise is never added. The
    computation runs in float32 and is cast back to ``model_output.dtype``.
    ``scheduler._step_index`` is incremented afterwards.

    ``eta``, ``use_clipped_model_output``, ``generator``, ``variance_noise``
    and ``return_dict`` are accepted for signature compatibility and ignored;
    the previous-sample tensor is returned directly.

    Raises:
        ValueError: if ``timestep`` is an int / IntTensor / LongTensor, or the
            prediction type is unknown.
        NotImplementedError: for ``prediction_type == "sample"``.
    """
    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
            " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
            " one of the `scheduler.timesteps` as a timestep."
        )

    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)
    idx = scheduler.step_index
    sigma = scheduler.sigmas[idx]

    # Work in float32 to avoid precision loss.
    sample = sample.to(torch.float32)

    kind = scheduler.config.prediction_type
    if kind == "epsilon":
        x0_hat = sample - sigma * model_output
    elif kind == "v_prediction":
        # model_output * c_out + input * c_skip
        x0_hat = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
            sample / (sigma**2 + 1)
        )
    elif kind == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    sigma_next = scheduler.sigmas[idx + 1]
    sigma_up = (sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2) ** 0.5
    sigma_down = (sigma_next**2 - sigma_up**2) ** 0.5

    slope = (sample - x0_hat) / sigma
    stepped = sample + slope * (sigma_down - sigma)
    stepped = stepped.to(model_output.dtype)

    scheduler._step_index += 1
    return stepped
def deterministic_non_ancestral_euler_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    generator: Optional[torch.Generator] = None,
    variance_noise: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    scheduler=None,
):
    """Plain Euler update with no stochastic churn noise.

    ``s_churn`` / ``s_tmin`` / ``s_tmax`` only inflate ``sigma`` to
    ``sigma_hat``; no noise is injected, so ``s_noise``, ``eta``,
    ``generator``, ``variance_noise`` and ``return_dict`` are unused. The
    update runs in float32 and is cast back to ``model_output.dtype``.
    ``scheduler._step_index`` is incremented afterwards. The previous-sample
    tensor is returned directly.

    Raises:
        ValueError: if ``timestep`` is an int / IntTensor / LongTensor, or the
            prediction type is unknown.
    """
    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
            " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
            " one of the `scheduler.timesteps` as a timestep."
        )

    if not scheduler.is_scale_input_called:
        logger.warning(
            "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
            "See `StableDiffusionPipeline` for a usage example."
        )

    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)

    # Work in float32 to avoid precision loss.
    sample = sample.to(torch.float32)

    idx = scheduler.step_index
    sigma = scheduler.sigmas[idx]
    if s_tmin <= sigma <= s_tmax:
        gamma = min(s_churn / (len(scheduler.sigmas) - 1), 2**0.5 - 1)
    else:
        gamma = 0.0
    sigma_hat = sigma * (gamma + 1)

    kind = scheduler.config.prediction_type
    # "original_sample" is kept only for backwards compatibility.
    if kind in ("original_sample", "sample"):
        x0_hat = model_output
    elif kind == "epsilon":
        x0_hat = sample - sigma_hat * model_output
    elif kind == "v_prediction":
        # model_output * c_out + input * c_skip
        x0_hat = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
            sample / (sigma**2 + 1)
        )
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    slope = (sample - x0_hat) / sigma_hat
    dt = scheduler.sigmas[idx + 1] - sigma_hat
    stepped = (sample + slope * dt).to(model_output.dtype)

    scheduler._step_index += 1
    return stepped
def deterministic_ddpm_step(
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    eta,
    use_clipped_model_output,
    generator,
    variance_noise,
    return_dict,
    scheduler,
):
    """Posterior-mean DDPM update (formula (7), arXiv:2006.11239) with no variance noise.

    For learned-variance models (``model_output`` having twice the channels
    of ``sample`` and ``variance_type`` ``learned``/``learned_range``), the
    variance channels are split off and discarded. ``eta``,
    ``use_clipped_model_output``, ``generator``, ``variance_noise`` and
    ``return_dict`` are unused. The predicted previous-sample mean is
    returned directly.

    Raises:
        ValueError: if the prediction type is not
            ``epsilon``/``sample``/``v_prediction``.
    """
    t = timestep
    prev_t = scheduler.previous_timestep(t)

    if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in (
        "learned",
        "learned_range",
    ):
        # Drop the predicted-variance channels; only the mean is used.
        model_output, _ = torch.split(model_output, sample.shape[1], dim=1)

    a_t = scheduler.alphas_cumprod[t]
    a_prev = scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
    b_t = 1 - a_t
    b_prev = 1 - a_prev
    cur_a = a_t / a_prev
    cur_b = 1 - cur_a

    # Predicted x_0, formula (15).
    kind = scheduler.config.prediction_type
    if kind == "epsilon":
        x0_hat = (sample - b_t ** (0.5) * model_output) / a_t ** (0.5)
    elif kind == "sample":
        x0_hat = model_output
    elif kind == "v_prediction":
        x0_hat = (a_t**0.5) * sample - (b_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )

    if scheduler.config.thresholding:
        x0_hat = scheduler._threshold_sample(x0_hat)
    elif scheduler.config.clip_sample:
        limit = scheduler.config.clip_sample_range
        x0_hat = x0_hat.clamp(-limit, limit)

    # Posterior-mean coefficients, formula (7).
    coef_x0 = (a_prev ** (0.5) * cur_b) / b_t
    coef_xt = cur_a ** (0.5) * b_prev / b_t
    return coef_x0 * x0_hat + coef_xt * sample
|