Diffusers
Safetensors
File size: 7,550 Bytes
39455bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from typing import Callable, Dict, List, Optional, Self, Tuple, Union

import torch
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from einops import rearrange

from .scheduling_ncsn import (
    AnnealedLangevinDynamicOutput,
    AnnealedLangevinDynamicScheduler,
)
from .unet_2d_ncsn import UNet2DModelForNCSN


def normalize_images(image: torch.Tensor) -> torch.Tensor:
    """Min-max normalize every image of a batch to the range [0, 1].

    Each image (index along the first axis) is rescaled independently with
    its own minimum and maximum. The result is written back into ``image``
    and the same tensor object is returned, so the input is modified in place.

    An image whose values are all identical has a zero value range; it is
    mapped to all zeros instead of producing ``0 / 0 = NaN``.

    Args:
        image (torch.Tensor): The batch of images to normalize. Must have
            exactly 4 dimensions, with the batch on the first axis.

    Returns:
        torch.Tensor: ``image`` itself, now holding the normalized values.
    """
    assert image.ndim == 4, image.ndim

    # Reduce over every axis except the batch axis so each image gets its
    # own (min, max) pair; keepdim lets the results broadcast back.
    per_image_dims = (1, 2, 3)
    lows = image.amin(dim=per_image_dims, keepdim=True)
    highs = image.amax(dim=per_image_dims, keepdim=True)
    spans = highs - lows

    # A constant image has span 0. Dividing by 1 instead yields
    # (x - min) / 1 == 0 for every pixel rather than NaN.
    safe_spans = torch.where(spans == 0, torch.ones_like(spans), spans)

    # copy_ writes into the caller's tensor, keeping the in-place contract.
    image.copy_((image - lows) / safe_spans)
    return image


class NCSNPipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional image generation using Noise Conditional Score Network (NCSN).

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Generation starts from uniform random noise and, for every scheduler timestep, applies
    `scheduler.num_annealed_steps` updates in which the `unet` output is passed to `scheduler.step`.

    Parameters:
        unet ([`UNet2DModelForNCSN`]):
            A `UNet2DModelForNCSN` to estimate the score of the image.
        scheduler ([`AnnealedLangevinDynamicScheduler`]):
            An `AnnealedLangevinDynamicScheduler` to be used in combination with `unet` to estimate the score of the image.
    """

    unet: UNet2DModelForNCSN
    scheduler: AnnealedLangevinDynamicScheduler

    # Default names of the local tensors in `__call__` that are handed to
    # `callback_on_step_end` when the caller does not specify any.
    _callback_tensor_inputs: List[str] = ["samples"]

    def __init__(
        self, unet: UNet2DModelForNCSN, scheduler: AnnealedLangevinDynamicScheduler
    ) -> None:
        """Register `unet` and `scheduler` as the modules of this pipeline."""
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    def decode_samples(self, samples: torch.Tensor) -> torch.Tensor:
        """Turn raw generated samples into channel-last images.

        Args:
            samples (torch.Tensor): 4-dimensional batch with the channel on
                the second axis, as produced in `__call__`.

        Returns:
            torch.Tensor: The batch after per-image min-max normalization via
            `normalize_images`, with the channel axis moved to the last
            position.
        """
        # Normalize the generated image
        samples = normalize_images(samples)
        # Rearrange the generated image to the correct format
        # (the second axis is moved last; the `w`/`h` labels in the pattern
        # are just names for the two middle axes, whose order is kept).
        samples = rearrange(samples, "b c w h -> b w h c")
        return samples

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 10,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[
                Callable[[Self, int, int, Dict], Dict],
                PipelineCallback,
                MultiPipelineCallbacks,
            ]
        ] = None,
        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            num_inference_steps (`int`, *optional*, defaults to 10):
                The number of inference steps, passed to `scheduler.set_timesteps`. For every resulting timestep,
                `scheduler.num_annealed_steps` model evaluations and scheduler updates are performed.
            generator (`torch.Generator`, `optional`):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic. It is used to draw the initial sample (before the sample is moved to
                the pipeline device) and is also passed to every `scheduler.step` call.
            output_type (`str`, `optional`, defaults to `"pil"`):
                The output format of the generated image. With `"pil"` the decoded samples are converted with
                `numpy_to_pil`. NOTE: any other value currently returns the decoded `torch.Tensor` (channel-last)
                as is; no conversion to `np.array` is performed here.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. The same flag is
                forwarded to `scheduler.step`.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
                In this pipeline the callback runs after every inner annealing step, and `step` is the index of that
                inner step within the current timestep. If the returned dict contains `"samples"`, that value
                replaces the current samples.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
                Defaults to `._callback_tensor_inputs`; when the callback is a `PipelineCallback` or
                `MultiPipelineCallbacks`, its own `tensor_inputs` is used instead.
            **kwargs:
                Accepted but not used by this method.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """
        # Fall back to the class-level default list when none (or an empty one) is given.
        callback_on_step_end_tensor_inputs = (
            callback_on_step_end_tensor_inputs or self._callback_tensor_inputs
        )
        # Callback objects declare the tensors they need themselves; that
        # declaration takes precedence over the argument above.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # (batch, channels, size, size), with channels and size read from the UNet config.
        samples_shape = (
            batch_size,
            self.unet.config.in_channels,  # type: ignore
            self.unet.config.sample_size,  # type: ignore
            self.unet.config.sample_size,  # type: ignore
        )

        # Generate a random sample
        # NOTE: The behavior of random number generation is different between CPU and GPU,
        # so first generate random numbers on CPU and then move them to GPU (if available).
        # `torch.rand` draws from the uniform distribution on [0, 1).
        samples = torch.rand(samples_shape, generator=generator)
        samples = samples.to(self.device)

        # Set the number of inference steps for the scheduler
        self.scheduler.set_timesteps(num_inference_steps)

        # Perform the reverse diffusion process
        for t in self.progress_bar(self.scheduler.timesteps):
            # Perform `num_annealed_steps` annealing steps
            for i in range(self.scheduler.num_annealed_steps):
                # Predict the score using the model
                model_output = self.unet(samples, t).sample  # type: ignore

                # Perform the annealed langevin dynamics
                output = self.scheduler.step(
                    model_output=model_output,
                    timestep=t,
                    samples=samples,
                    generator=generator,
                    return_dict=return_dict,
                )
                # `return_dict` was forwarded above, so `output` may be either
                # the output dataclass or a plain tuple; handle both.
                samples = (
                    output.prev_sample
                    if isinstance(output, AnnealedLangevinDynamicOutput)
                    else output[0]
                )

                # Perform the callback on step end if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # Each requested tensor is looked up by the name of a local
                    # variable of this method, so those local names are part of
                    # the callback contract and must not be renamed.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]

                    # The inner annealing index `i` is passed as the step argument.
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    # Let the callback replace the samples; keep the current ones otherwise.
                    samples = callback_outputs.pop("samples", samples)

        # Normalize per image and switch to channel-last layout.
        samples = self.decode_samples(samples)

        # Only "pil" triggers a conversion; for any other `output_type` the
        # decoded tensor is returned unchanged.
        if output_type == "pil":
            samples = self.numpy_to_pil(samples.cpu().numpy())

        if return_dict:
            return ImagePipelineOutput(images=samples)  # type: ignore
        else:
            return (samples,)