cocktailpeanut commited on
Commit
9f22dc9
·
1 Parent(s): 2fa2ca4
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -30,7 +30,7 @@ elif torch.backends.mps.is_available():
30
  else:
31
  device = torch.device("cpu")
32
  #device = "cuda" if torch.cuda.is_available() else "cpu"
33
- dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
34
 
35
  # initialization
36
  base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -38,13 +38,13 @@ image_encoder_path = "sdxl_models/image_encoder"
38
  ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
39
 
40
  controlnet_path = "diffusers/controlnet-canny-sdxl-1.0"
41
- controlnet = ControlNetModel.from_pretrained(controlnet_path, use_safetensors=False, torch_dtype=torch.float16).to(device)
42
 
43
  # load SDXL pipeline
44
  pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
45
  base_model_path,
46
  controlnet=controlnet,
47
- torch_dtype=torch.float16,
48
  add_watermarker=False,
49
  )
50
 
 
30
  else:
31
  device = torch.device("cpu")
32
  #device = "cuda" if torch.cuda.is_available() else "cpu"
33
+ dtype = torch.float16 if str(device).__contains__("cuda") or str(device).__contains__("mps") else torch.float32
34
 
35
  # initialization
36
  base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
 
38
  ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
39
 
40
  controlnet_path = "diffusers/controlnet-canny-sdxl-1.0"
41
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, use_safetensors=False, torch_dtype=dtype).to(device)
42
 
43
  # load SDXL pipeline
44
  pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
45
  base_model_path,
46
  controlnet=controlnet,
47
+ torch_dtype=dtype,
48
  add_watermarker=False,
49
  )
50