# pipeline/pipeline.py
# Author: junhsss
# Version: v0.3.2 (commit 8dc6ca5)
import math
from typing import List, Optional, Tuple, Union
import torch
from diffusers import DiffusionPipeline, ImagePipelineOutput, UNet2DModel
from diffusers.utils import randn_tensor
class ConsistencyPipeline(DiffusionPipeline):
    """Image-generation pipeline for a consistency model wrapping a ``UNet2DModel``.

    Gaussian noise scaled by ``time_max`` is mapped to an image with a single
    consistency-function evaluation. With ``steps > 1`` the result is re-noised
    at a smaller time and evaluated again (multistep sampling).
    """

    unet: UNet2DModel

    def __init__(
        self,
        unet: UNet2DModel,
    ) -> None:
        """Register ``unet`` as this pipeline's component via ``register_modules``."""
        super().__init__()
        self.register_modules(unet=unet)

    @torch.no_grad()
    def __call__(
        self,
        steps: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        time_min: float = 0.002,
        time_max: float = 80.0,
        data_std: float = 0.5,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        batch_size: int = 1,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        """Sample images from the consistency model.

        Args:
            steps: Number of consistency-function evaluations (must be >= 1).
            generator: Torch generator(s) forwarded to ``randn_tensor`` for the
                initial noise and for every re-noising step.
            time_min: Smallest time (epsilon). Used by the skip coefficient,
                the re-noising std and the step schedule.
            time_max: Starting time; the initial noise is multiplied by it.
            data_std: Data standard deviation used in the skip/output coefficients.
            output_type: ``"pil"`` to get PIL images; any other value returns
                the channels-last numpy array with values in [0, 1].
            return_dict: If False, return a 1-tuple ``(images,)`` instead of
                an ``ImagePipelineOutput``.
            batch_size: Number of images to sample (default 1).
            **kwargs: Accepted for compatibility and ignored.

        Returns:
            ``ImagePipelineOutput(images=...)``, or ``(images,)`` when
            ``return_dict`` is False.

        Raises:
            ValueError: If ``steps`` or ``batch_size`` is smaller than 1.
        """
        if steps < 1:
            raise ValueError(f"`steps` must be at least 1, got {steps}.")
        if batch_size < 1:
            raise ValueError(f"`batch_size` must be at least 1, got {batch_size}.")

        model = self.unet
        device = self.device

        # `sample_size` may be a single int (square images) or a (height, width) pair.
        sample_size = model.config.sample_size
        if isinstance(sample_size, int):
            height = width = sample_size
        else:
            height, width = sample_size
        shape = (batch_size, model.config.in_channels, height, width)

        time: float = time_max
        # Initial noise has std `time_max`; created in the UNet's dtype so
        # half-precision models receive matching inputs.
        sample = randn_tensor(shape, generator=generator, device=device, dtype=model.dtype) * time

        for step in self.progress_bar(range(steps)):
            if step > 0:
                # Multistep sampling: move to a smaller time (using the caller's
                # time_min) and re-inject noise with std sqrt(time^2 - time_min^2);
                # the 1e-6 keeps the sqrt argument strictly positive.
                time = self.search_previous_time(time, time_min=time_min, time_max=time_max)
                sigma = math.sqrt(time**2 - time_min**2 + 1e-6)
                sample = sample + sigma * randn_tensor(
                    sample.shape, generator=generator, device=device, dtype=sample.dtype
                )

            out = model(sample, torch.tensor([time], device=device)).sample

            skip_coef = data_std**2 / ((time - time_min) ** 2 + data_std**2)
            # NOTE(review): this c_out uses `time`, whereas Song et al. (2023) use
            # `time - time_min`; presumably kept to match the parameterization the
            # checkpoints were trained with - confirm before changing.
            out_coef = data_std * time / (time**2 + data_std**2) ** 0.5
            sample = (sample * skip_coef + out * out_coef).clamp(-1.0, 1.0)

        # Map [-1, 1] -> [0, 1], then to a channels-last float numpy array.
        sample = (sample / 2 + 0.5).clamp(0, 1)
        image = sample.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)

    def search_previous_time(self, time, time_min: float = 0.002, time_max: float = 80.0):
        """Return the next, smaller sampling time after ``time``.

        Moves one third of the way from ``time`` toward ``time_min``, so the
        result stays above ``time_min`` whenever ``time`` does. ``time_max``
        is accepted for signature compatibility but is not used.
        """
        return (2 * time + time_min) / 3

    def cuda(self) -> "ConsistencyPipeline":
        """Move the pipeline to ``"cuda"`` and return what ``self.to`` returns (for chaining)."""
        return self.to("cuda")