charleselena commited on
Commit
0dda0de
1 Parent(s): cc2c67b

add safety checker with type

Browse files
Files changed (1) hide show
  1. handler.py +3 -6
handler.py CHANGED
@@ -85,14 +85,11 @@ class EndpointHandler():
85
  # safety_checker = SafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
86
  # ).to(device)
87
 
88
- self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
 
89
  controlnet=self.controlnet,
90
  torch_dtype=dtype,
91
- safety_checker=Node).to(device)
92
-
93
- #StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
94
-
95
-
96
  # Define Generator with seed
97
  self.generator = torch.Generator(device=device.type).manual_seed(3)
98
 
 
85
  # safety_checker = SafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
86
  # ).to(device)
87
 
88
+
89
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
90
  controlnet=self.controlnet,
91
  torch_dtype=dtype,
92
+ safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16)).to("cuda")
 
 
 
 
93
  # Define Generator with seed
94
  self.generator = torch.Generator(device=device.type).manual_seed(3)
95