MaykaGR commited on
Commit
da79093
·
verified ·
1 Parent(s): 7840427

Update comfy/model_management.py

Browse files
Files changed (1) hide show
  1. comfy/model_management.py +7 -5
comfy/model_management.py CHANGED
@@ -25,6 +25,8 @@ import sys
25
  import platform
26
  import weakref
27
  import gc
 
 
28
 
29
  class VRAMState(Enum):
30
  DISABLED = 0 #No vram present: no need to move models to vram
@@ -117,16 +119,16 @@ def get_torch_device():
117
  global directml_device
118
  return directml_device
119
  if cpu_state == CPUState.MPS:
120
- return torch.device("mps")
121
  if cpu_state == CPUState.CPU:
122
  return torch.device("cpu")
123
  else:
124
  if is_intel_xpu():
125
- return torch.device("xpu", torch.xpu.current_device())
126
  elif is_ascend_npu():
127
- return torch.device("npu", torch.npu.current_device())
128
  else:
129
- return torch.device(torch.cuda.current_device())
130
 
131
  def get_total_memory(dev=None, torch_total_too=False):
132
  global directml_enabled
@@ -790,7 +792,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
790
  def get_autocast_device(dev):
791
  if hasattr(dev, 'type'):
792
  return dev.type
793
- return "cuda"
794
 
795
  def supports_dtype(device, dtype): #TODO
796
  if dtype == torch.float32:
 
25
  import platform
26
  import weakref
27
  import gc
28
+ import os
29
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
30
 
31
  class VRAMState(Enum):
32
  DISABLED = 0 #No vram present: no need to move models to vram
 
119
  global directml_device
120
  return directml_device
121
  if cpu_state == CPUState.MPS:
122
+ return torch.device("cpu")
123
  if cpu_state == CPUState.CPU:
124
  return torch.device("cpu")
125
  else:
126
  if is_intel_xpu():
127
+ return torch.device("cpu")
128
  elif is_ascend_npu():
129
+ return torch.device("cpu")
130
  else:
131
+ return torch.device("cpu")
132
 
133
  def get_total_memory(dev=None, torch_total_too=False):
134
  global directml_enabled
 
792
  def get_autocast_device(dev):
793
  if hasattr(dev, 'type'):
794
  return dev.type
795
+ return "cpu"
796
 
797
  def supports_dtype(device, dtype): #TODO
798
  if dtype == torch.float32: