YouLiXiya commited on
Commit
aaba5b3
·
1 Parent(s): 5ce2c9c

Upload pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +83 -0
pipeline.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from diffusers import DiffusionPipeline, ImagePipelineOutput, UNet2DModel
6
+ from diffusers.utils import randn_tensor
7
+
8
+
9
+ class ConsistencyPipeline(DiffusionPipeline):
10
+ unet: UNet2DModel
11
+
12
+ def __init__(
13
+ self,
14
+ unet: UNet2DModel,
15
+ ) -> None:
16
+ super().__init__()
17
+ self.register_modules(unet=unet)
18
+
19
+ @torch.no_grad()
20
+ def __call__(
21
+ self,
22
+ batch_size : int = 1,
23
+ num_class: Optional[int] = None,
24
+ label_index: Optional[int] = None,
25
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
26
+ eps: float = 0.002,
27
+ T: float = 80.0,
28
+ data_std: float = 0.5,
29
+ num_inference_steps: int = 1,
30
+ output_type: Optional[str] = "pil",
31
+ return_dict: bool = True,
32
+ **kwargs,
33
+ ) -> Union[Tuple, ImagePipelineOutput]:
34
+ model = self.unet
35
+ device = model.device
36
+ image_labels = None
37
+ if label_index is not None:
38
+ assert label_index + 1 <= num_class, 'label_index must <= num_class!'
39
+ image_labels = torch.LongTensor([label_index]).repeat(batch_size).to(device)
40
+ else:
41
+ if num_class is not None:
42
+ image_labels = torch.randint(low=0, high=num_class, size=[1])
43
+ image_labels = image_labels.repeat(batch_size).to(device)
44
+ img_size = self.unet.config.sample_size
45
+ shape = (batch_size, 3, img_size, img_size)
46
+
47
+ time: float = T
48
+
49
+ sample = randn_tensor(shape, generator=generator, device=device) * time
50
+
51
+ for step in self.progress_bar(range(num_inference_steps)):
52
+ if step > 0:
53
+ time = self.search_previous_time(time)
54
+ sigma = math.sqrt(time ** 2 - eps ** 2 + 1e-6)
55
+ sample = sample + sigma * randn_tensor(
56
+ sample.shape, device=sample.device, generator=generator
57
+ )
58
+
59
+ out = model(sample, torch.tensor([time], device=sample.device), image_labels).sample
60
+
61
+ skip_coef = data_std ** 2 / ((time - eps) ** 2 + data_std ** 2)
62
+ out_coef = data_std * (time - eps) / (time ** 2 + data_std ** 2) ** (0.5)
63
+
64
+ sample = (sample * skip_coef + out * out_coef).clamp(-1.0, 1.0)
65
+
66
+ sample = (sample / 2 + 0.5).clamp(0, 1)
67
+ image = sample.cpu().permute(0, 2, 3, 1).numpy()
68
+
69
+ if output_type == "pil":
70
+ image = self.numpy_to_pil(image)
71
+
72
+ if not return_dict:
73
+ return (image,)
74
+
75
+ return ImagePipelineOutput(images=image)
76
+
77
+ # TODO: Implement greedy search on FID
78
+
79
+ def search_previous_time(
80
+ self, time, eps: float = 0.002, T: float = 80.0
81
+ ):
82
+ return (2 * time + eps) / 3
83
+