# SEED-Story/src/train/dist_utils.py
import torch
import torch.distributed as dist


def all_gather(tensor):
    """Gather `tensor` from every rank and return the results as a list.

    Assumes torch.distributed has already been initialized; entry i of the
    returned list holds rank i's copy of `tensor`.
    """
    world_size = dist.get_world_size()
    tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor)
    return tensor_list
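
# Worked example (illustrative, not part of the original file): with
# world_size == 2, rank 0 holding tensor([0.]) and rank 1 holding
# tensor([1.]), all_gather(t) returns [tensor([0.]), tensor([1.])] on
# both ranks.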


def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is both available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


@torch.no_grad()
def concat_all_gather(tensor):
    """Gather `tensor` from every rank and concatenate the results along dim 0.

    *** Warning ***: torch.distributed.all_gather has no gradient, so the
    returned tensor is detached from the autograd graph.
    """
    # If not running distributed training, just return the input tensor.
    if not is_dist_avail_and_initialized():
        return tensor
    tensors_gather = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output
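

# The warning above notes that dist.all_gather does not propagate gradients.
# Below is a minimal sketch (not part of the original file) of a
# gradient-preserving alternative in the style used by several contrastive
# learning codebases: all-gather in forward, then sum the incoming gradients
# across ranks in backward and keep this rank's slice. The name `GatherLayer`
# is an illustrative choice, not an identifier from this repository.
class GatherLayer(torch.autograd.Function):
    """All-gather with a backward pass, so gradients flow back to every rank."""

    @staticmethod
    def forward(ctx, tensor):
        # Same collective as concat_all_gather, but wrapped in an
        # autograd.Function so backward can be defined.
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        # One gradient arrives per gathered output; stack them into a single
        # (world_size, *tensor.shape) tensor, sum contributions from all
        # ranks, and return the slice that corresponds to this rank's input.
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]

# Usage sketch: gathered = torch.cat(GatherLayer.apply(features), dim=0)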