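"""Distributed helpers: rank/world-size accessors, CUDA device selection, and metric averaging across ranks."""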
import logging
import os
from functools import lru_cache
from typing import List, Union

import torch
import torch.distributed as dist

logger = logging.getLogger("distributed")

BACKEND = "nccl"


@lru_cache()
def get_rank() -> int:
    """Global rank of this process (cached; requires an initialized process group)."""
    return dist.get_rank()


@lru_cache()
def get_world_size() -> int:
    """Total number of processes (cached; requires an initialized process group)."""
    return dist.get_world_size()


def visible_devices() -> List[int]:
    """GPU indices exposed to this process via CUDA_VISIBLE_DEVICES."""
    return [int(d) for d in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]


def set_device():
    """Bind this process to the CUDA device matching its LOCAL_RANK."""
    logger.info(f"torch.cuda.device_count: {torch.cuda.device_count()}")
    logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")
    logger.info(f"local rank: {int(os.environ['LOCAL_RANK'])}")

    assert torch.cuda.is_available()
    assert len(visible_devices()) == torch.cuda.device_count()

    if torch.cuda.device_count() == 1:
        # Each process sees a single GPU, so it is always device 0 locally.
        torch.cuda.set_device(0)
        return

    local_rank = int(os.environ["LOCAL_RANK"])
    logger.info(f"Set cuda device to {local_rank}")

    assert 0 <= local_rank < torch.cuda.device_count(), (
        local_rank,
        torch.cuda.device_count(),
    )
    torch.cuda.set_device(local_rank)


def avg_aggregate(metric: Union[float, int]) -> float:
    """Average a scalar metric across all ranks via an all-reduce."""
    buffer = torch.tensor([metric], dtype=torch.float32, device="cuda")
    dist.all_reduce(buffer, op=dist.ReduceOp.SUM)
    return buffer[0].item() / get_world_size()


def is_torchrun() -> bool:
    """Whether this process was launched by torchrun (torchelastic)."""
    return "TORCHELASTIC_RESTART_COUNT" in os.environ
|