Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
""" | |
[Copied from detectron2] | |
This file contains primitives for multi-gpu communication. | |
This is useful when doing distributed training. | |
""" | |
import functools | |
import logging | |
import numpy as np | |
import pickle | |
import torch | |
import torch.distributed as dist | |
_LOCAL_PROCESS_GROUP = None | |
""" | |
A torch process group which only includes processes that on the same machine as the current process. | |
This variable is set when processes are spawned by `launch()` in "engine/launch.py". | |
""" | |
def get_world_size() -> int:
    """Return the world size of the default torch.distributed process group.

    When torch.distributed is not available in this build, or has not been
    initialized, this returns 1 so that single-process code can call it
    unconditionally.
    """
    # Short-circuit keeps is_initialized() from being called when the
    # distributed package is unavailable, exactly as the original guard did.
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if distributed_ready else 1
def get_rank() -> int:
    """Return this process's rank in the default torch.distributed process group.

    When torch.distributed is not available in this build, or has not been
    initialized, this returns 0, so a plain single-process run behaves as rank 0.
    """
    # is_initialized() is only consulted once is_available() is known to be true.
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if distributed_ready else 0
def is_main_process() -> bool:
    """Report whether the current process has rank 0.

    The rank comes from ``get_rank``, which yields 0 outside of an initialized
    distributed run, so a non-distributed process counts as the main one.
    """
    current_rank = get_rank()
    return current_rank == 0
def synchronize():
    """
    Block until every process in the default process group reaches this call.

    This does nothing when torch.distributed is unavailable, has not been
    initialized, or the world size is 1. It is therefore safe to call from
    non-distributed code.
    """
    # Guard order matters: is_initialized() is only called once the
    # distributed package is known to be available.
    if not (dist.is_available() and dist.is_initialized()):
        return
    # A barrier is pointless with a single participant.
    if dist.get_world_size() != 1:
        dist.barrier()