Spaces: Running on Zero
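The helper below initializes `torch.distributed` for multi-GPU training. When launched under SLURM it derives the rank, world size, and master address from the SLURM environment (resolving the first node in the allocation with `scontrol`); otherwise it expects `RANK` and `WORLD_SIZE` to already be set by the launcher. It then pins each process to its local GPU and joins the process group.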
```python
import os
import subprocess

import torch
import torch.distributed as dist


def setup_distributed(backend="nccl", port=None):
    """Initialize distributed training from SLURM or launcher-provided environment variables.

    Lifted from https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/utils.py
    Originally licensed MIT, Copyright (c) 2020 Wei Li
    """
    num_gpus = torch.cuda.device_count()

    if "SLURM_JOB_ID" in os.environ:
        # Launched under SLURM: derive rank/world size from SLURM variables
        # and resolve the master address from the first node in the allocation.
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])
        node_list = os.environ["SLURM_NODELIST"]
        addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
        # Specify the master port.
        if port is not None:
            os.environ["MASTER_PORT"] = str(port)
        elif "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = "10685"
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = addr
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_RANK"] = str(rank % num_gpus)
        os.environ["RANK"] = str(rank)
    else:
        # Launched by torchrun / torch.distributed.launch: the launcher
        # already sets RANK and WORLD_SIZE.
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    # Bind this process to its local GPU, then join the process group.
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
    )
    return rank, world_size
```
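As a minimal usage sketch (the toy model, the DDP wrapping, and the launch command are illustrative assumptions, not part of the snippet above): call `setup_distributed()` once per process before building the model, then place the model on the local GPU and wrap it in `DistributedDataParallel`.

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes setup_distributed() from the snippet above is importable/defined
# in this file; the model below is a placeholder for illustration only.
rank, world_size = setup_distributed(backend="nccl")
local_rank = rank % torch.cuda.device_count()  # same convention as setup_distributed

model = nn.Linear(128, 10).cuda(local_rank)
model = DDP(model, device_ids=[local_rank])

if rank == 0:
    print(f"initialized {world_size} processes")

# Example launch for the non-SLURM path (torchrun sets RANK and WORLD_SIZE):
#   torchrun --nproc_per_node=2 train.py
```

Under SLURM, the same script works unchanged because the function fills in `MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` from the SLURM environment before `init_process_group` is called.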