Spaces:
Runtime error
Runtime error
File size: 3,313 Bytes
12bfd03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
# ------------------------------------------
# Diffsound
# Code based on https://github.com/cientgu/VQ-Diffusion
# ------------------------------------------
import pickle
import torch
from torch import distributed as dist
from torch.utils import data
# Process group passed to dist.get_rank(group=...) by get_local_rank().
# It starts as None; get_local_rank() raises ValueError if it is still None
# once torch.distributed has been initialized, so it must be assigned by
# setup code (not visible in this part of the file) before that call.
LOCAL_PROCESS_GROUP = None
def is_primary():
    """Return True when the current process has global rank 0."""
    rank = get_rank()
    return rank == 0
def get_rank():
    """Return the global rank of this process.

    Falls back to 0 when ``torch.distributed`` is unavailable or has not
    been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
def get_local_rank():
    """Return this process's rank within ``LOCAL_PROCESS_GROUP``.

    Returns:
        0 when ``torch.distributed`` is unavailable or uninitialized;
        otherwise the rank reported for ``LOCAL_PROCESS_GROUP``.

    Raises:
        ValueError: if distributed mode is active but
            ``LOCAL_PROCESS_GROUP`` has not been set.
    """
    distributed_active = dist.is_available() and dist.is_initialized()
    if not distributed_active:
        return 0

    if LOCAL_PROCESS_GROUP is not None:
        return dist.get_rank(group=LOCAL_PROCESS_GROUP)

    raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is None")
def synchronize():
    """Block on a barrier across all processes, if there is more than one.

    Does nothing when ``torch.distributed`` is unavailable, uninitialized,
    or the world size is 1.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return

    # A barrier is only issued when the world holds more than one process.
    if dist.get_world_size() != 1:
        dist.barrier()
def get_world_size():
    """Return the number of participating processes.

    Falls back to 1 when ``torch.distributed`` is unavailable or has not
    been initialized.
    """
    ready = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if ready else 1
def is_distributed():
    """Return True when more than one process takes part in the run.

    The availability and initialization checks are performed here directly
    so the function never touches an uninitialized process group.

    Returns:
        bool: ``True`` if ``torch.distributed`` is available, initialized,
        and reports a world size greater than 1; ``False`` otherwise.
    """
    # The previous body raised RuntimeError unconditionally (a leftover debug
    # placeholder), which made the real return statement unreachable.
    return (
        dist.is_available()
        and dist.is_initialized()
        and dist.get_world_size() > 1
    )
def all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False):
    """Reduce ``tensor`` across all processes and return it.

    Args:
        tensor: tensor handed to ``dist.all_reduce``.
        op: reduction operation forwarded to ``dist.all_reduce``.
        async_op: forwarded to ``dist.all_reduce``.

    Returns:
        The same ``tensor`` object that was passed in. With a world size of
        1 no collective is issued.

    NOTE(review): the return value of ``dist.all_reduce`` is not captured,
    so with ``async_op=True`` the caller gets no handle to wait on --
    confirm whether any caller relies on asynchronous mode.
    """
    if get_world_size() != 1:
        dist.all_reduce(tensor, op=op, async_op=async_op)
    return tensor
def all_gather(data):
    """Gather one picklable object from every process.

    ``data`` is pickled, moved to the ``"cuda"`` device as a byte tensor,
    padded to the longest payload among the ranks, exchanged with
    ``dist.all_gather``, and unpickled again.

    Args:
        data: any object accepted by ``pickle.dumps``.

    Returns:
        list: ``[data]`` when the world size is 1; otherwise one unpickled
        object per entry gathered by ``dist.all_gather``.
    """
    num_ranks = get_world_size()
    if num_ranks == 1:
        return [data]

    # Serialize the local object into a byte tensor on the GPU.
    payload = pickle.dumps(data)
    tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(payload)).to("cuda")

    # Exchange payload lengths first so every rank knows how much to pad.
    local_size = torch.IntTensor([tensor.numel()]).to("cuda")
    size_list = [torch.IntTensor([1]).to("cuda") for _ in range(num_ranks)]
    dist.all_gather(size_list, local_size)
    size_list = [int(entry.item()) for entry in size_list]
    max_size = max(size_list)

    # The receive buffers and the local tensor are all sized to max_size.
    receive_buffers = [
        torch.ByteTensor(size=(max_size, )).to("cuda") for _ in size_list
    ]
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size, )).to("cuda")
        tensor = torch.cat((tensor, padding), 0)
    dist.all_gather(receive_buffers, tensor)

    # Drop the padding bytes before unpickling each payload.
    return [
        pickle.loads(received.cpu().numpy().tobytes()[:length])
        for length, received in zip(size_list, receive_buffers)
    ]
def reduce_dict(input_dict, average=True):
    """Reduce the values of ``input_dict`` onto rank 0.

    Keys are processed in sorted order. The values are combined with
    ``torch.stack`` (so presumably they are same-shaped tensors -- confirm
    against callers) and reduced with ``dist.reduce(..., dst=0)``.

    Args:
        input_dict: mapping whose values are stacked and reduced.
        average: when true, rank 0 divides the reduced values by the world
            size.

    Returns:
        ``input_dict`` itself when the world size is below 2; otherwise a
        new dict mapping each key to its slice of the reduced stack. The
        division by world size is applied on rank 0 only.
    """
    num_ranks = get_world_size()
    if num_ranks < 2:
        return input_dict

    with torch.no_grad():
        ordered_keys = sorted(input_dict.keys())
        stacked = torch.stack([input_dict[name] for name in ordered_keys], 0)
        dist.reduce(stacked, dst=0)

        if dist.get_rank() == 0 and average:
            stacked /= num_ranks

        return dict(zip(ordered_keys, stacked))
def data_sampler(dataset, shuffle, distributed):
    """Build the sampler matching the requested mode for ``dataset``.

    Args:
        dataset: dataset handed to the sampler constructor.
        shuffle: whether samples should be drawn in random order.
        distributed: when true, a ``DistributedSampler`` is returned and
            ``shuffle`` is forwarded to it.

    Returns:
        A ``DistributedSampler`` if ``distributed`` is true; otherwise a
        ``RandomSampler`` when ``shuffle`` is true, else a
        ``SequentialSampler``.
    """
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)
|