Update comfy/model_management.py
Browse files
comfy/model_management.py
CHANGED
@@ -25,6 +25,8 @@ import sys
|
|
25 |
import platform
|
26 |
import weakref
|
27 |
import gc
|
|
|
|
|
28 |
|
29 |
class VRAMState(Enum):
|
30 |
DISABLED = 0 #No vram present: no need to move models to vram
|
@@ -117,16 +119,16 @@ def get_torch_device():
|
|
117 |
global directml_device
|
118 |
return directml_device
|
119 |
if cpu_state == CPUState.MPS:
|
120 |
- return torch.device("mps")
|
121 |
if cpu_state == CPUState.CPU:
|
122 |
return torch.device("cpu")
|
123 |
else:
|
124 |
if is_intel_xpu():
|
125 |
- return torch.device("xpu", torch.xpu.current_device())
|
126 |
elif is_ascend_npu():
|
127 |
- return torch.device("npu", torch.npu.current_device())
|
128 |
else:
|
129 |
- return torch.device(torch.cuda.current_device())
|
130 |
|
131 |
def get_total_memory(dev=None, torch_total_too=False):
|
132 |
global directml_enabled
|
@@ -790,7 +792,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
|
790 |
def get_autocast_device(dev):
|
791 |
if hasattr(dev, 'type'):
|
792 |
return dev.type
|
793 |
- return "cuda"
|
794 |
|
795 |
def supports_dtype(device, dtype): #TODO
|
796 |
if dtype == torch.float32:
|
|
|
25 |
import platform
|
26 |
import weakref
|
27 |
import gc
|
28 |
+
import os
|
29 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
30 |
|
31 |
class VRAMState(Enum):
|
32 |
DISABLED = 0 #No vram present: no need to move models to vram
|
|
|
119 |
global directml_device
|
120 |
return directml_device
|
121 |
if cpu_state == CPUState.MPS:
|
122 |
+
return torch.device("cpu")
|
123 |
if cpu_state == CPUState.CPU:
|
124 |
return torch.device("cpu")
|
125 |
else:
|
126 |
if is_intel_xpu():
|
127 |
+
return torch.device("cpu")
|
128 |
elif is_ascend_npu():
|
129 |
+
return torch.device("cpu")
|
130 |
else:
|
131 |
+
return torch.device("cpu")
|
132 |
|
133 |
def get_total_memory(dev=None, torch_total_too=False):
|
134 |
global directml_enabled
|
|
|
792 |
def get_autocast_device(dev):
|
793 |
if hasattr(dev, 'type'):
|
794 |
return dev.type
|
795 |
+
return "cpu"
|
796 |
|
797 |
def supports_dtype(device, dtype): #TODO
|
798 |
if dtype == torch.float32:
|