File size: 4,448 Bytes
02e90e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from functools import lru_cache
import sys
import torch
from modules import config
import logging
logger = logging.getLogger(__name__)
if sys.platform == "darwin":
from modules.devices import mac_devices
def has_mps() -> bool:
if sys.platform != "darwin":
return False
else:
return mac_devices.has_mps
def get_cuda_device_id():
return (
int(config.runtime_env_vars.device_id)
if config.runtime_env_vars.device_id is not None
and config.runtime_env_vars.device_id.isdigit()
else 0
) or torch.cuda.current_device()
def get_cuda_device_string():
if config.runtime_env_vars.device_id is not None:
return f"cuda:{config.runtime_env_vars.device_id}"
return "cuda"
def get_available_gpus() -> list[tuple[int, int]]:
"""
Get the list of available GPUs and their free memory.
:return: A list of tuples where each tuple contains (GPU index, free memory in bytes).
"""
available_gpus = []
for i in range(torch.cuda.device_count()):
props = torch.cuda.get_device_properties(i)
free_memory = props.total_memory - torch.cuda.memory_reserved(i)
available_gpus.append((i, free_memory))
return available_gpus
def get_memory_available_gpus(min_memory=2048):
available_gpus = get_available_gpus()
memory_available_gpus = [
gpu for gpu, free_memory in available_gpus if free_memory > min_memory
]
return memory_available_gpus
def get_target_device_id_or_memory_available_gpu():
memory_available_gpus = get_memory_available_gpus()
device_id = get_cuda_device_id()
if device_id not in memory_available_gpus:
if len(memory_available_gpus) != 0:
logger.warning(
f"Device {device_id} is not available or does not have enough memory. will try to use {memory_available_gpus}"
)
config.runtime_env_vars.device_id = str(memory_available_gpus[0])
else:
logger.warning(
f"Device {device_id} is not available or does not have enough memory. Using CPU instead."
)
return "cpu"
return get_cuda_device_string()
def get_optimal_device_name():
if config.runtime_env_vars.use_cpu:
return "cpu"
if torch.cuda.is_available():
return get_target_device_id_or_memory_available_gpu()
if has_mps():
return "mps"
return "cpu"
def get_optimal_device():
return torch.device(get_optimal_device_name())
def get_device_for(task):
if task in config.cmd_opts.use_cpu or "all" in config.cmd_opts.use_cpu:
return cpu
return get_optimal_device()
def torch_gc():
try:
if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
if has_mps():
mac_devices.torch_mps_gc()
except Exception as e:
logger.error(f"Error in torch_gc", exc_info=True)
cpu: torch.device = torch.device("cpu")
device: torch.device = get_optimal_device()
dtype: torch.dtype = torch.float32
dtype_dvae: torch.dtype = torch.float32
dtype_vocos: torch.dtype = torch.float32
dtype_gpt: torch.dtype = torch.float32
dtype_decoder: torch.dtype = torch.float32
def reset_device():
if config.runtime_env_vars.half:
global dtype
global dtype_dvae
global dtype_vocos
global dtype_gpt
global dtype_decoder
dtype = torch.float16
dtype_dvae = torch.float16
dtype_vocos = torch.float16
dtype_gpt = torch.float16
dtype_decoder = torch.float16
logger.info("Using half precision: torch.float16")
if (
config.runtime_env_vars.device_id is not None
or config.runtime_env_vars.use_cpu is not None
):
global device
device = get_optimal_device()
logger.info(f"Using device: {device}")
@lru_cache
def first_time_calculation():
"""
just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
spends about 2.7 seconds doing that, at least wih NVidia.
"""
x = torch.zeros((1, 1)).to(device, dtype)
linear = torch.nn.Linear(1, 1).to(device, dtype)
linear(x)
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
|