Spaces:
Runtime error
Runtime error
cocktailpeanut
commited on
Commit
·
9f22dc9
1
Parent(s):
2fa2ca4
update
Browse files
app.py
CHANGED
@@ -30,7 +30,7 @@ elif torch.backends.mps.is_available():
|
|
30 |
else:
|
31 |
device = torch.device("cpu")
|
32 |
#device = "cuda" if torch.cuda.is_available() else "cpu"
|
33 |
-
dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
|
34 |
|
35 |
# initialization
|
36 |
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
@@ -38,13 +38,13 @@ image_encoder_path = "sdxl_models/image_encoder"
|
|
38 |
ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
|
39 |
|
40 |
controlnet_path = "diffusers/controlnet-canny-sdxl-1.0"
|
41 |
-
controlnet = ControlNetModel.from_pretrained(controlnet_path, use_safetensors=False, torch_dtype=
|
42 |
|
43 |
# load SDXL pipeline
|
44 |
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
45 |
base_model_path,
|
46 |
controlnet=controlnet,
|
47 |
-
torch_dtype=
|
48 |
add_watermarker=False,
|
49 |
)
|
50 |
|
|
|
30 |
else:
|
31 |
device = torch.device("cpu")
|
32 |
#device = "cuda" if torch.cuda.is_available() else "cpu"
|
33 |
+
dtype = torch.float16 if str(device).__contains__("cuda") or str(device).__contains__("mps") else torch.float32
|
34 |
|
35 |
# initialization
|
36 |
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
|
|
38 |
ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
|
39 |
|
40 |
controlnet_path = "diffusers/controlnet-canny-sdxl-1.0"
|
41 |
+
controlnet = ControlNetModel.from_pretrained(controlnet_path, use_safetensors=False, torch_dtype=dtype).to(device)
|
42 |
|
43 |
# load SDXL pipeline
|
44 |
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
45 |
base_model_path,
|
46 |
controlnet=controlnet,
|
47 |
+
torch_dtype=dtype,
|
48 |
add_watermarker=False,
|
49 |
)
|
50 |
|