Spaces:
Runtime error
Runtime error
import torch | |
import torch.distributed as dist | |
def all_gather(tensor):
    """Collect ``tensor`` from every rank into a Python list.

    Args:
        tensor: The local tensor to share with the other ranks. Presumably
            every rank passes a tensor of the same shape and dtype, as
            ``torch.distributed.all_gather`` expects. TODO: confirm with callers.

    Returns:
        A list with one tensor per rank, filled by
        ``torch.distributed.all_gather``. When torch.distributed is
        unavailable or no process group has been initialized, returns
        ``[tensor]``. This lets single-process runs work, mirroring the
        fallback in ``concat_all_gather``.

    Note:
        ``torch.distributed.all_gather`` does not propagate gradients. The
        returned tensors are freshly allocated buffers, not connected to the
        autograd graph of ``tensor``.
    """
    # Single-process fallback: without an initialized process group there is
    # nothing to gather, and dist.get_world_size() would raise.
    if not (dist.is_available() and dist.is_initialized()):
        return [tensor]

    # Collective backends reject non-contiguous inputs. .contiguous() is a
    # no-op for tensors that are already contiguous.
    tensor = tensor.contiguous()
    world_size = dist.get_world_size()
    # Receive buffers; their contents are overwritten in place by all_gather.
    tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor)
    return tensor_list
def is_dist_avail_and_initialized():
    """Report whether torch.distributed can be used right now.

    Returns:
        ``True`` only when the torch.distributed package is available in this
        build and a default process group has been initialized; ``False``
        otherwise.
    """
    # Short-circuit: is_initialized() is only consulted when the package is
    # available.
    return dist.is_available() and dist.is_initialized()
def concat_all_gather(tensor):
    """Gather ``tensor`` from every rank and concatenate the pieces along dim 0.

    Args:
        tensor: The local tensor contributed by this rank.

    Returns:
        The concatenation (``torch.cat`` along dimension 0) of the tensors
        collected from all ranks. If torch.distributed is unavailable or no
        process group is initialized, the input tensor object itself is
        returned unchanged.

    Warning:
        ``torch.distributed.all_gather`` has no gradient. Gradients do not flow
        back through the gathered result.
    """
    # Nothing to gather outside of an initialized distributed run.
    if not is_dist_avail_and_initialized():
        return tensor

    n_ranks = torch.distributed.get_world_size()
    # Placeholder buffers; all_gather overwrites each one with a rank's data.
    gathered = [torch.ones_like(tensor) for _ in range(n_ranks)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)

    stacked = torch.cat(gathered, dim=0)
    return stacked