junhsss commited on
Commit
8dc6ca5
·
1 Parent(s): 8784a1d
Files changed (1) hide show
  1. pipeline.py +9 -17
pipeline.py CHANGED
@@ -20,9 +20,7 @@ class ConsistencyPipeline(DiffusionPipeline):
20
  def __call__(
21
  self,
22
  steps: int = 1,
23
- generator: Optional[
24
- Union[torch.Generator, List[torch.Generator]]
25
- ] = None,
26
  time_min: float = 0.002,
27
  time_max: float = 80.0,
28
  data_std: float = 0.5,
@@ -37,23 +35,17 @@ class ConsistencyPipeline(DiffusionPipeline):
37
 
38
  time: float = time_max
39
 
40
- sample = randn_tensor(shape, generator=generator) * time
41
 
42
  for step in self.progress_bar(range(steps)):
43
  if step > 0:
44
  time = self.search_previous_time(time)
45
  sigma = math.sqrt(time**2 - time_min**2 + 1e-6)
46
- sample = sample + sigma * randn_tensor(
47
- sample.shape, device=sample.device, generator=generator
48
- )
49
 
50
- out = model(
51
- sample, torch.tensor([time], device=sample.device)
52
- ).sample
53
 
54
- skip_coef = data_std**2 / (
55
- (time - time_min) ** 2 + data_std**2
56
- )
57
  out_coef = data_std * time / (time**2 + data_std**2) ** (0.5)
58
 
59
  sample = (sample * skip_coef + out * out_coef).clamp(-1.0, 1.0)
@@ -69,8 +61,8 @@ class ConsistencyPipeline(DiffusionPipeline):
69
 
70
  return ImagePipelineOutput(images=image)
71
 
72
- # TODO: Implement greedy search on FID
73
- def search_previous_time(
74
- self, time, time_min: float = 0.002, time_max: float = 80.0
75
- ):
76
  return (2 * time + time_min) / 3
 
 
 
 
20
  def __call__(
21
  self,
22
  steps: int = 1,
23
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
 
 
24
  time_min: float = 0.002,
25
  time_max: float = 80.0,
26
  data_std: float = 0.5,
 
35
 
36
  time: float = time_max
37
 
38
+ sample = randn_tensor(shape, generator=generator, device=self.device) * time
39
 
40
  for step in self.progress_bar(range(steps)):
41
  if step > 0:
42
  time = self.search_previous_time(time)
43
  sigma = math.sqrt(time**2 - time_min**2 + 1e-6)
44
+ sample = sample + sigma * randn_tensor(sample.shape, device=self.device, generator=generator)
 
 
45
 
46
+ out = model(sample, torch.tensor([time], device=self.device)).sample
 
 
47
 
48
+ skip_coef = data_std**2 / ((time - time_min) ** 2 + data_std**2)
 
 
49
  out_coef = data_std * time / (time**2 + data_std**2) ** (0.5)
50
 
51
  sample = (sample * skip_coef + out * out_coef).clamp(-1.0, 1.0)
 
61
 
62
  return ImagePipelineOutput(images=image)
63
 
64
+ def search_previous_time(self, time, time_min: float = 0.002, time_max: float = 80.0):
 
 
 
65
  return (2 * time + time_min) / 3
66
+
67
+ def cuda(self):
68
+ self.to("cuda")