wangshuai6 commited on
Commit
db75344
·
1 Parent(s): 67d7248
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -80,8 +80,9 @@ class Pipeline:
80
  self.diffusion_sampler.guidance_interval_min = guidance_interval_min
81
  self.diffusion_sampler.guidance_interval_max = guidance_interval_max
82
  self.diffusion_sampler.timeshift = timeshift
83
- generator = torch.Generator(device="cuda").manual_seed(seed)
84
- xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
 
85
  with torch.no_grad():
86
  condition, uncondition = conditioner([self.classlabels2ids[y],]*num_images)
87
  # Sample images:
@@ -123,7 +124,7 @@ if __name__ == "__main__":
123
  w_scheduler=LinearScheduler(),
124
  guidance_fn=simple_guidance_fn,
125
  num_steps=50,
126
- guidance=3.0,
127
  state_refresh_rate=1,
128
  guidance_interval_min=0.3,
129
  guidance_interval_max=1.0,
 
80
  self.diffusion_sampler.guidance_interval_min = guidance_interval_min
81
  self.diffusion_sampler.guidance_interval_max = guidance_interval_max
82
  self.diffusion_sampler.timeshift = timeshift
83
+ generator = torch.Generator(device="cpu").manual_seed(seed)
84
+ xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cpu", dtype=torch.float32, generator=generator)
85
+ xT = xT.to("cuda")
86
  with torch.no_grad():
87
  condition, uncondition = conditioner([self.classlabels2ids[y],]*num_images)
88
  # Sample images:
 
124
  w_scheduler=LinearScheduler(),
125
  guidance_fn=simple_guidance_fn,
126
  num_steps=50,
127
+ guidance=4.0,
128
  state_refresh_rate=1,
129
  guidance_interval_min=0.3,
130
  guidance_interval_max=1.0,