junhsss commited on
Commit
8784a1d
·
1 Parent(s): 6d81ae3
Files changed (1) hide show
  1. pipeline.py +76 -0
pipeline.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from diffusers import DiffusionPipeline, ImagePipelineOutput, UNet2DModel
6
+ from diffusers.utils import randn_tensor
7
+
8
+
9
+ class ConsistencyPipeline(DiffusionPipeline):
10
+ unet: UNet2DModel
11
+
12
+ def __init__(
13
+ self,
14
+ unet: UNet2DModel,
15
+ ) -> None:
16
+ super().__init__()
17
+ self.register_modules(unet=unet)
18
+
19
+ @torch.no_grad()
20
+ def __call__(
21
+ self,
22
+ steps: int = 1,
23
+ generator: Optional[
24
+ Union[torch.Generator, List[torch.Generator]]
25
+ ] = None,
26
+ time_min: float = 0.002,
27
+ time_max: float = 80.0,
28
+ data_std: float = 0.5,
29
+ output_type: Optional[str] = "pil",
30
+ return_dict: bool = True,
31
+ **kwargs,
32
+ ) -> Union[Tuple, ImagePipelineOutput]:
33
+ img_size = self.unet.config.sample_size
34
+ shape = (1, 3, img_size, img_size)
35
+
36
+ model = self.unet
37
+
38
+ time: float = time_max
39
+
40
+ sample = randn_tensor(shape, generator=generator) * time
41
+
42
+ for step in self.progress_bar(range(steps)):
43
+ if step > 0:
44
+ time = self.search_previous_time(time)
45
+ sigma = math.sqrt(time**2 - time_min**2 + 1e-6)
46
+ sample = sample + sigma * randn_tensor(
47
+ sample.shape, device=sample.device, generator=generator
48
+ )
49
+
50
+ out = model(
51
+ sample, torch.tensor([time], device=sample.device)
52
+ ).sample
53
+
54
+ skip_coef = data_std**2 / (
55
+ (time - time_min) ** 2 + data_std**2
56
+ )
57
+ out_coef = data_std * time / (time**2 + data_std**2) ** (0.5)
58
+
59
+ sample = (sample * skip_coef + out * out_coef).clamp(-1.0, 1.0)
60
+
61
+ sample = (sample / 2 + 0.5).clamp(0, 1)
62
+ image = sample.cpu().permute(0, 2, 3, 1).numpy()
63
+
64
+ if output_type == "pil":
65
+ image = self.numpy_to_pil(image)
66
+
67
+ if not return_dict:
68
+ return (image,)
69
+
70
+ return ImagePipelineOutput(images=image)
71
+
72
+ # TODO: Implement greedy search on FID
73
+ def search_previous_time(
74
+ self, time, time_min: float = 0.002, time_max: float = 80.0
75
+ ):
76
+ return (2 * time + time_min) / 3