File size: 7,550 Bytes
39455bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
from typing import Callable, Dict, List, Optional, Self, Tuple, Union
import torch
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from einops import rearrange
from .scheduling_ncsn import (
AnnealedLangevinDynamicOutput,
AnnealedLangevinDynamicScheduler,
)
from .unet_2d_ncsn import UNet2DModelForNCSN
def normalize_images(image: torch.Tensor) -> torch.Tensor:
    """Min-max normalize each image in a batch to the range [0, 1].

    Every image is scaled independently using its own minimum and maximum taken
    over all non-batch dimensions. An image whose values are all equal (zero
    range) maps to all zeros instead of producing NaN from a division by zero.

    Args:
        image (torch.Tensor): A 4-D batch of images; the first dimension indexes the batch.

    Returns:
        torch.Tensor: A new tensor of normalized images with the same shape as ``image``.
            The input tensor is not modified.

    Raises:
        ValueError: If ``image`` is not 4-dimensional.
    """
    if image.ndim != 4:
        raise ValueError(f"Expected a 4-D batch of images, but got ndim={image.ndim}")

    # Reduce over every non-batch dimension so each image gets its own min/max;
    # keepdim=True lets the statistics broadcast back over each image.
    reduce_dims = tuple(range(1, image.ndim))
    img_min = image.amin(dim=reduce_dims, keepdim=True)
    img_max = image.amax(dim=reduce_dims, keepdim=True)

    # Guard against constant images: substitute 1 for a zero range so the
    # numerator (which is 0 there) yields 0 rather than NaN.
    value_range = img_max - img_min
    value_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))

    return (image - img_min) / value_range
class NCSNPipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional image generation using Noise Conditional Score Network (NCSN).

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModelForNCSN`]):
            A `UNet2DModelForNCSN` to estimate the score of the image.
        scheduler ([`AnnealedLangevinDynamicScheduler`]):
            An `AnnealedLangevinDynamicScheduler` to be used in combination with `unet` to estimate the score of
            the image.
    """

    unet: UNet2DModelForNCSN
    scheduler: AnnealedLangevinDynamicScheduler

    # Names of tensors a `callback_on_step_end` may request. Each name is looked up
    # among the local variables of `__call__`, so every entry must match one there.
    _callback_tensor_inputs: List[str] = ["samples"]

    def __init__(
        self, unet: UNet2DModelForNCSN, scheduler: AnnealedLangevinDynamicScheduler
    ) -> None:
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    def decode_samples(self, samples: torch.Tensor) -> torch.Tensor:
        """Convert raw samples into channel-last images with values in [0, 1].

        Args:
            samples (torch.Tensor): Batch of generated samples laid out as (batch, channel, height, width).

        Returns:
            torch.Tensor: Per-image min-max normalized samples laid out as (batch, height, width, channel).
        """
        # Normalize each generated image independently to [0, 1]
        samples = normalize_images(samples)
        # Move the channel axis last so each sample is an (H, W, C) image
        samples = rearrange(samples, "b c h w -> b h w c")
        return samples

    def _check_callback_inputs(self, callback_on_step_end_tensor_inputs: List[str]) -> None:
        """Ensure every requested callback tensor is one this pipeline exposes.

        Args:
            callback_on_step_end_tensor_inputs (`List[str]`): The tensor names requested by the callback.

        Raises:
            ValueError: If any entry is not listed in `self._callback_tensor_inputs`.
        """
        invalid = [
            k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs
        ]
        if invalid:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, "
                f"but found {invalid}"
            )

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 10,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[
                Callable[[Self, int, int, Dict], Dict],
                PipelineCallback,
                MultiPipelineCallbacks,
            ]
        ] = None,
        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            num_inference_steps (`int`, *optional*, defaults to 10):
                The number of inference steps (noise levels). Each step runs `scheduler.num_annealed_steps`
                annealed Langevin updates.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated images. `"pil"` returns a list of `PIL.Image.Image`, `"np"`
                returns a `np.ndarray` of shape (batch, height, width, channel) with values in [0, 1]. Any other
                value returns the same data as a `torch.Tensor` on the pipeline's device.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an [`ImagePipelineOutput`] instead of a plain tuple.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks`. It is called after
                each annealed Langevin update with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
                - `step` is the index of the annealed update within the current timestep.
                - `callback_kwargs` holds the tensors named in `callback_on_step_end_tensor_inputs`.
                - If the returned dict contains `"samples"`, that value replaces the current samples.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as the `callback_kwargs` argument. Only names listed in the
                `._callback_tensor_inputs` attribute of the pipeline class are allowed. When
                `callback_on_step_end` is a `PipelineCallback`/`MultiPipelineCallbacks`, its own `tensor_inputs`
                are used instead.
            **kwargs:
                Accepted for interface compatibility and ignored.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned whose first element holds the generated images in the format selected by `output_type`.

        Raises:
            ValueError: If a callback is given and it requests a tensor name not listed in
                `._callback_tensor_inputs`.
        """
        callback_on_step_end_tensor_inputs = (
            callback_on_step_end_tensor_inputs or self._callback_tensor_inputs
        )
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
        if callback_on_step_end is not None:
            # Fail fast with a clear message instead of a KeyError from the
            # `locals()` lookup in the middle of sampling.
            self._check_callback_inputs(callback_on_step_end_tensor_inputs)

        samples_shape = (
            batch_size,
            self.unet.config.in_channels,  # type: ignore
            self.unet.config.sample_size,  # type: ignore
            self.unet.config.sample_size,  # type: ignore
        )

        # Generate a random sample
        # NOTE: The behavior of random number generation is different between CPU and GPU,
        # so first generate random numbers on CPU and then move them to GPU (if available).
        # NOTE(review): `generator` presumably needs to be a CPU generator here, since the
        # tensor is created on the default device — TODO confirm.
        samples = torch.rand(samples_shape, generator=generator)
        samples = samples.to(self.device)

        # Set the number of inference steps for the scheduler
        self.scheduler.set_timesteps(num_inference_steps)

        # Perform the reverse diffusion process
        for t in self.progress_bar(self.scheduler.timesteps):
            # Perform `num_annealed_steps` annealing steps at this noise level
            for i in range(self.scheduler.num_annealed_steps):
                # Predict the score using the model
                model_output = self.unet(samples, t).sample  # type: ignore

                # Perform the annealed Langevin dynamics update
                output = self.scheduler.step(
                    model_output=model_output,
                    timestep=t,
                    samples=samples,
                    generator=generator,
                    return_dict=return_dict,
                )
                samples = (
                    output.prev_sample
                    if isinstance(output, AnnealedLangevinDynamicOutput)
                    else output[0]
                )

                # Perform the callback on step end if provided
                if callback_on_step_end is not None:
                    # An explicit loop (not a comprehension) keeps `locals()` bound to this
                    # function's scope on Python versions where comprehensions have their own scope.
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    samples = callback_outputs.pop("samples", samples)

        samples = self.decode_samples(samples)

        if output_type == "pil":
            images = self.numpy_to_pil(samples.cpu().numpy())
        elif output_type == "np":
            # Honor the documented contract: "np" yields a NumPy array, not a tensor.
            images = samples.cpu().numpy()
        else:
            images = samples

        if return_dict:
            return ImagePipelineOutput(images=images)  # type: ignore
        return (images,)
|