File size: 2,790 Bytes
aaba5b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import math
from typing import List, Optional, Tuple, Union
import torch
from diffusers import DiffusionPipeline, ImagePipelineOutput, UNet2DModel
from diffusers.utils import randn_tensor
class ConsistencyPipeline(DiffusionPipeline):
    """Image-generation pipeline that samples from a consistency model.

    The pipeline holds a single ``unet``. ``__call__`` starts from Gaussian
    noise scaled by ``T``, evaluates the network, and blends the network
    output with its input using skip/output coefficients. With more than one
    inference step the result is re-noised at a smaller time and evaluated
    again.
    """

    unet: UNet2DModel

    def __init__(
        self,
        unet: UNet2DModel,
    ) -> None:
        """Register ``unet`` as the pipeline's only module."""
        super().__init__()
        self.register_modules(unet=unet)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_class: Optional[int] = None,
        label_index: Optional[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        eps: float = 0.002,
        T: float = 80.0,
        data_std: float = 0.5,
        num_inference_steps: int = 1,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        """Generate a batch of images.

        Args:
            batch_size: Number of images to generate.
            num_class: Number of classes. When given without ``label_index``,
                one class is drawn at random and used for the whole batch.
                When given together with ``label_index``, it is the exclusive
                upper bound used to validate the label.
            label_index: Class label applied to every image in the batch.
                When both ``label_index`` and ``num_class`` are ``None`` the
                model is called with ``None`` as its label argument.
            generator: Generator(s) forwarded to ``randn_tensor`` for the
                initial noise and for every re-noising step. It is not used
                for the random class draw.
            eps: Smallest time value; used in the re-noising scale, in the
                skip/output coefficients and in the time schedule.
            T: Initial time value; the initial noise is scaled by it.
            data_std: Data standard deviation used in the skip/output
                coefficients.
            num_inference_steps: Number of network evaluations.
            output_type: ``"pil"`` converts the result with
                ``numpy_to_pil``; any other value returns the NumPy array.
            return_dict: When ``False`` a one-element tuple is returned
                instead of an ``ImagePipelineOutput``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``ImagePipelineOutput(images=...)`` or ``(images,)``, where the
            images were rescaled from ``[-1, 1]`` to ``[0, 1]`` and laid out
            channels-last before the optional PIL conversion.

        Raises:
            ValueError: If ``label_index`` is negative, or is not smaller
                than ``num_class`` when ``num_class`` is given.
        """
        model = self.unet
        device = model.device
        # Draw noise in the model's dtype so half-precision models receive
        # inputs that match their weights.
        dtype = model.dtype

        image_labels = None
        if label_index is not None:
            # Explicit validation instead of `assert`, which is removed under
            # `python -O`. The upper bound is only known when `num_class` is
            # supplied.
            out_of_range = label_index < 0 or (
                num_class is not None and label_index >= num_class
            )
            if out_of_range:
                raise ValueError(
                    "label_index must satisfy 0 <= label_index < num_class, "
                    f"got label_index={label_index}, num_class={num_class}"
                )
            image_labels = torch.full(
                (batch_size,), label_index, dtype=torch.long, device=device
            )
        elif num_class is not None:
            # A single random class shared by the whole batch. `generator` is
            # deliberately not consumed here so the noise stream of seeded
            # callers stays unchanged.
            image_labels = torch.randint(low=0, high=num_class, size=[1])
            image_labels = image_labels.repeat(batch_size).to(device)

        # Take the input geometry from the model config rather than assuming
        # square RGB images.
        sample_size = model.config.sample_size
        if isinstance(sample_size, int):
            height = width = sample_size
        else:
            height, width = sample_size
        channels = getattr(model.config, "in_channels", 3)
        shape = (batch_size, channels, height, width)

        time: float = T
        sample = (
            randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            * time
        )

        for step in self.progress_bar(range(num_inference_steps)):
            if step > 0:
                # Use the caller's eps/T so the schedule agrees with the
                # coefficients computed below.
                time = self.search_previous_time(time, eps=eps, T=T)
                # The small constant keeps the radicand positive.
                sigma = math.sqrt(time ** 2 - eps ** 2 + 1e-6)
                sample = sample + sigma * randn_tensor(
                    sample.shape,
                    device=sample.device,
                    dtype=sample.dtype,
                    generator=generator,
                )

            out = model(
                sample, torch.tensor([time], device=sample.device), image_labels
            ).sample

            # skip_coef -> 1 and out_coef -> 0 as time -> eps, so the blend
            # returns its input unchanged at the smallest time.
            skip_coef = data_std ** 2 / ((time - eps) ** 2 + data_std ** 2)
            out_coef = data_std * (time - eps) / (time ** 2 + data_std ** 2) ** (0.5)

            sample = (sample * skip_coef + out * out_coef).clamp(-1.0, 1.0)

        # [-1, 1] -> [0, 1], then NCHW -> NHWC. The float32 cast is a no-op for
        # float32 models and is needed because NumPy has no bfloat16.
        sample = (sample / 2 + 0.5).clamp(0, 1)
        image = sample.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

    # TODO: Implement greedy search on FID
    def search_previous_time(
        self, time, eps: float = 0.002, T: float = 80.0
    ):
        """Return the next, smaller time of the sampling schedule.

        The result is ``(2 * time + eps) / 3``, i.e. the point one third of
        the way from ``time`` towards ``eps``.

        Args:
            time: Current time value.
            eps: Smallest time value the schedule moves towards.
            T: Largest time value; accepted but not used by this rule.
        """
        return (2 * time + eps) / 3
|