import torch
import torch.distributed as dist


def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is usable in this process.

    Returns:
        bool: True if the distributed package is available in this build
        AND a default process group has been initialized; False otherwise.
    """
    # Short-circuits: is_initialized() is only consulted when the package
    # is available at all.
    return dist.is_available() and dist.is_initialized()


def all_gather(tensor):
    """Gather ``tensor`` from every rank into a list.

    Args:
        tensor: The local tensor to contribute to the gather.
            NOTE(review): the underlying collective presumably requires the
            same shape/dtype on every rank -- confirm against callers.

    Returns:
        list: One tensor per rank (``world_size`` entries). When distributed
        training is not available/initialized, the single-element list
        ``[tensor]`` is returned so single-process runs keep working.
    """
    # Mirror concat_all_gather: degrade gracefully instead of raising when
    # no process group exists.
    if not is_dist_avail_and_initialized():
        return [tensor]

    world_size = dist.get_world_size()
    # Receive buffers, one per rank; dist.all_gather fills them in place.
    tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor)
    return tensor_list


@torch.no_grad()
def concat_all_gather(tensor):
    """Gather ``tensor`` from every rank and concatenate along dim 0.

    *** Warning ***: torch.distributed.all_gather has no gradient; this
    function additionally runs under ``torch.no_grad()``.

    Args:
        tensor: The local tensor to contribute to the gather.

    Returns:
        The per-rank tensors concatenated along dimension 0, or ``tensor``
        itself (unchanged) when distributed training is not in use.
    """
    # Not running distributed: nothing to gather.
    if not is_dist_avail_and_initialized():
        return tensor

    # Receive buffers, one per rank; their initial contents are overwritten
    # by the collective below.
    tensors_gather = [
        torch.zeros_like(tensor) for _ in range(dist.get_world_size())
    ]
    dist.all_gather(tensors_gather, tensor, async_op=False)

    return torch.cat(tensors_gather, dim=0)