import logging
import os
from functools import lru_cache
from typing import List, Union

import torch
import torch.distributed as dist

logger = logging.getLogger("distributed")

BACKEND = "nccl"


@lru_cache()
def get_rank() -> int:
    """Global rank of this process. Cached: it never changes after init."""
    return dist.get_rank()


@lru_cache()
def get_world_size() -> int:
    """Total number of processes in the group. Cached for the same reason."""
    return dist.get_world_size()


def visible_devices() -> List[int]:
    """Device ids parsed from CUDA_VISIBLE_DEVICES (assumed to be set)."""
    return [int(d) for d in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]


def set_device():
    """Bind this process to its local GPU based on LOCAL_RANK."""
    logger.info(f"torch.cuda.device_count: {torch.cuda.device_count()}")
    logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")
    logger.info(f"local rank: {int(os.environ['LOCAL_RANK'])}")

    assert torch.cuda.is_available()
    assert len(visible_devices()) == torch.cuda.device_count()

    if torch.cuda.device_count() == 1:
        # gpus-per-task set to 1: only one GPU is visible, so use it directly
        torch.cuda.set_device(0)
        return

    local_rank = int(os.environ["LOCAL_RANK"])
    logger.info(f"Set cuda device to {local_rank}")
    assert 0 <= local_rank < torch.cuda.device_count(), (
        local_rank,
        torch.cuda.device_count(),
    )
    torch.cuda.set_device(local_rank)


def avg_aggregate(metric: Union[float, int]) -> Union[float, int]:
    """Average a scalar metric across all ranks (sum all-reduce, then divide)."""
    buffer = torch.tensor([metric], dtype=torch.float32, device="cuda")
    dist.all_reduce(buffer, op=dist.ReduceOp.SUM)
    return buffer[0].item() / get_world_size()


def is_torchrun() -> bool:
    """True if the process was launched by torchrun (torch elastic)."""
    return "TORCHELASTIC_RESTART_COUNT" in os.environ