File size: 1,150 Bytes
c39b2dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import gc
import logging
from typing import List, TypeVar
import torch
from torch.utils.data import Dataset
# Module-level logger, following the standard getLogger(__name__) pattern.
logger = logging.getLogger(__name__)
# Element type held by ListDataset below.
T = TypeVar("T")
def get_torch_device(device: str = "auto") -> str:
    """
    Resolve the device string PyTorch should use.

    Passing "auto" (the default) selects the best available backend:
    "cuda:0" when CUDA is usable, else "mps" on Apple Silicon, else
    "cpu". Any other value is passed through unchanged. The final
    choice is logged at INFO level.
    """
    if device == "auto":
        # Ordered preference list; each probe is only called if the
        # previous one reported unavailable (generator short-circuits).
        candidates = (
            ("cuda:0", torch.cuda.is_available),
            ("mps", torch.backends.mps.is_available),  # Apple Silicon
        )
        device = next(
            (name for name, available in candidates if available()),
            "cpu",
        )
    logger.info(f"Using device: {device}")
    return device
def tear_down_torch():
    """
    Teardown for PyTorch.

    Runs a garbage-collection pass, then releases cached GPU memory.
    Each backend's cache is cleared only when that backend is actually
    available, so the call is safe on any host: the unguarded
    `torch.mps.empty_cache()` can raise a RuntimeError on torch builds
    without MPS support (e.g. CPU-only Linux wheels).
    """
    # Collect Python-level garbage first so freed tensors can actually
    # release their backing GPU allocations below.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():  # Apple Silicon only
        torch.mps.empty_cache()
class ListDataset(Dataset[T]):
    """A map-style torch Dataset backed by an in-memory list."""

    def __init__(self, elements: List[T]):
        """Wrap `elements`; the list is stored by reference, not copied."""
        self.elements = elements

    def __len__(self) -> int:
        """Number of items in the backing list."""
        return len(self.elements)

    def __getitem__(self, idx: int) -> T:
        """Return the item at position `idx` (supports negative indices)."""
        return self.elements[idx]
|