import torch
import numpy as np
import numpy.random as npr
import torch.distributed as dist
import math
from ...log_service import print_log
from ... import sync
def singleton(class_):
instances = {}
def getinstance(*args, **kwargs):
if class_ not in instances:
instances[class_] = class_(*args, **kwargs)
return instances[class_]
return getinstance
@singleton
class get_sampler(object):
def __init__(self):
self.sampler = {}
def register(self, sampler):
self.sampler[sampler.__name__] = sampler
def __call__(self, dataset, cfg):
if cfg == 'default_train':
return GlobalDistributedSampler(dataset, shuffle=True, extend=False)
elif cfg == 'default_eval':
return GlobalDistributedSampler(dataset, shuffle=False, extend=True)
else:
t = cfg.type
return self.sampler[t](dataset=dataset, **cfg.args)
def register():
def wrapper(class_):
get_sampler().register(class_)
return class_
return wrapper
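# Usage sketch (not part of the original module): get_sampler is wrapped by
# @singleton, so every call returns the same registry instance. A sampler is
# resolved either via the 'default_train' / 'default_eval' shortcuts or via a
# cfg object exposing .type and .args (an assumption about the config format
# used elsewhere in this codebase).
def _example_resolve_sampler(dataset, cfg=None):
    registry = get_sampler()  # same object on every call
    if cfg is None:
        # shuffles and truncates to a multiple of the world size
        return registry(dataset, 'default_train')
    # a registered sampler class is looked up by cfg.type and built with cfg.args
    return registry(dataset, cfg)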
######################
# DistributedSampler #
######################
@register()
class GlobalDistributedSampler(torch.utils.data.Sampler):
"""
    This is a distributed sampler that synchronizes across GPUs and nodes.
"""
def __init__(self,
dataset,
shuffle=True,
extend=False,):
"""
Arguments:
dataset: Dataset used for sampling.
            shuffle: If true, the sampler will shuffle the indices.
            extend: If true, the sampler will extend the indices so they can be evenly
                distributed across ranks; otherwise it will truncate the indices to make
                the count even.
"""
self.ddp = sync.is_ddp()
self.rank = sync.get_rank('global')
self.world_size = sync.get_world_size('global')
self.dataset = dataset
self.shuffle = shuffle
self.extend = extend
num_samples = len(dataset) // self.world_size
if extend and (len(dataset)%self.world_size != 0):
num_samples+=1
self.num_samples = num_samples
self.total_size = num_samples * self.world_size
def __iter__(self):
indices = self.get_sync_order()
if self.extend:
# extend using the front indices
indices = indices+indices[0:self.total_size-len(indices)]
else:
# truncate
indices = indices[0:self.total_size]
# subsample
indices = indices[self.rank : len(indices) : self.world_size]
return iter(indices)
def __len__(self):
return self.num_samples
def get_sync_order(self):
if self.shuffle:
indices = torch.randperm(len(self.dataset)).to(self.rank)
if self.ddp:
dist.broadcast(indices, src=0)
indices = indices.to('cpu').tolist()
else:
indices = list(range(len(self.dataset)))
print_log('Sampler : {}'.format(str(indices[0:5])) )
return indices
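# Usage sketch (not part of the original module): pairs the sampler with a
# standard DataLoader. Assumes the distributed environment and the sync helpers
# above are already initialized; batch_size / num_workers are placeholders.
def _example_global_loader(dataset, batch_size=4):
    sampler = GlobalDistributedSampler(dataset, shuffle=True, extend=False)
    # each rank iterates len(dataset) // world_size indices; the permutation is
    # broadcast from global rank 0, so all ranks slice the same order
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=sampler, num_workers=0)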
@register()
class LocalDistributedSampler(GlobalDistributedSampler):
"""
    This is a distributed sampler that synchronizes across GPUs within each node,
    but not across nodes.
"""
def __init__(self,
dataset,
shuffle=True,
extend=False,):
super().__init__(dataset, shuffle, extend)
self.rank = sync.get_rank('local')
self.world_size = sync.get_world_size('local')
def get_sync_order(self):
if self.shuffle:
if self.rank == 0:
indices = list(npr.permutation(len(self.dataset)))
sync.nodewise_sync().broadcast_r0(indices)
else:
indices = sync.nodewise_sync().broadcast_r0(None)
else:
indices = list(range(len(self.dataset)))
print_log('Sampler : {}'.format(str(indices[0:5])) )
return indices
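# Usage sketch (not part of the original module): identical interface to the
# global sampler, but the permutation is broadcast only within each node via
# sync.nodewise_sync(), so different nodes may iterate different orders.
def _example_local_loader(dataset, batch_size=4):
    sampler = LocalDistributedSampler(dataset, shuffle=True, extend=False)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=sampler, num_workers=0)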
############################
# random sample with group #
############################
# Deprecated
@register()
class GroupSampler(torch.utils.data.Sampler):
"""
    This is a DistributedSampler that samples all indices group by group.
    A standalone sketch of the grouping arithmetic (_example_group_partition)
    appears at the end of this file. For example:
if group_size=3, num_replicas=2, train mode:
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
==> (group) [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]
==> (distribute) process0: [3, 4, 5], (leftover [6, 7, 8, 9, 10])
process1: [0, 1, 2]
==> (group leftover) process0: [3, 4, 5], (leftover [6, 7], [8, 9], 10)
process1: [0, 1, 2]
==> (distribute) process0: [3, 4, 5], [6, 7] (remove 10)
process1: [0, 1, 2], [8, 9]
    it will avoid a batch size of 1:
0, 1, 2, 3, 4, 5, 6, 7, 8,
==> (group) [0, 1, 2], [3, 4, 5], [6, 7, 8]
==> (distribute) process0: [3, 4, 5], (leftover [6, 7, 8])
process1: [0, 1, 2]
==> (group leftover) process0: [3, 4, 5], (leftover [6], [7], [8])
process1: [0, 1, 2]
        ==> (distribute) process0: [3, 4, 5], (remove 6, 7, 8, since distributing them would give a batch size of 1)
process1: [0, 1, 2]
if group_size=3, num_replicas=2, eval mode:
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
==> (extend) 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10
==> (group) [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 10]
==> (distribute) process0: [0, 1, 2], [6, 7, 8],
process1: [3, 4, 5], [9, 10, 10]
"""
def __init__(self,
dataset,
group_size,
num_replicas=None,
rank=None,
mode='train',):
        if num_replicas is None:
            if not dist.is_available():
                raise ValueError('torch.distributed is not available; num_replicas must be given')
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise ValueError('torch.distributed is not available; rank must be given')
            rank = dist.get_rank()
self.dataset = dataset
self.len_dataset = len(dataset)
self.group_size = group_size
self.num_replicas = num_replicas
self.rank = rank
self.mode = mode
len_dataset = self.len_dataset
if (len_dataset % num_replicas != 0) and (mode == 'train'):
# drop the non_aligned
aligned_indices = np.arange(len_dataset)[:-(len_dataset % num_replicas)]
aligned_len_dataset = aligned_indices.shape[0]
elif (len_dataset % num_replicas != 0) and (mode == 'eval'):
extend = np.array([len_dataset-1 for _ in range(num_replicas - len_dataset % num_replicas)])
aligned_indices = np.concatenate([range(len_dataset), extend])
aligned_len_dataset = aligned_indices.shape[0]
else:
aligned_indices = np.arange(len_dataset)
aligned_len_dataset = len_dataset
num_even_distributed_groups = aligned_len_dataset // (group_size * num_replicas)
num_even = num_even_distributed_groups * group_size * num_replicas
self.regular_groups = aligned_indices[0:num_even].reshape(-1, group_size)
self.leftover_groups = aligned_indices[num_even:].reshape(num_replicas, -1)
        if self.leftover_groups.size == 0:
            self.leftover_groups = None
        elif (self.leftover_groups.shape[-1] == 1) and (mode == 'train'):
            # avoid a batch size of 1
            self.leftover_groups = None
        # per-rank sample count used by __len__
        num_samples = self.regular_groups.size // num_replicas
        if self.leftover_groups is not None:
            num_samples += self.leftover_groups.shape[-1]
        self.num_samples = num_samples
        # an ugly way to modify dataset.load_info according to the grouping
for groupi in self.regular_groups:
for idx in groupi:
idx_lowerbd = groupi[0]
idx_upperbd = groupi[-1]
idx_reference = (idx_lowerbd+idx_upperbd)//2
dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']
if self.leftover_groups is not None:
for groupi in self.leftover_groups:
for idx in groupi:
idx_lowerbd = groupi[0]
idx_upperbd = groupi[-1]
idx_reference = (idx_lowerbd+idx_upperbd)//2
dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']
def concat(self, nparrays, axis=0):
        # a helper for safe concatenation (skips empty arrays)
nparrays = [i for i in nparrays if i.size > 0]
return np.concatenate(nparrays, axis=axis)
def __iter__(self):
indices = self.get_sync_order()
return iter(indices)
def __len__(self):
return self.num_samples
def get_sync_order(self):
# g = torch.Generator()
# g.manual_seed(self.epoch)
mode = self.mode
rank = self.rank
num_replicas = self.num_replicas
group_size = self.group_size
num_groups = len(self.regular_groups)
if mode == 'train':
g_indices = torch.randperm(num_groups).to(rank)
dist.broadcast(g_indices, src=0)
g_indices = g_indices.to('cpu').tolist()
num_groups_per_rank = num_groups // num_replicas
groups = self.regular_groups[g_indices][num_groups_per_rank*rank : num_groups_per_rank*(rank+1)]
indices = groups.flatten()
if self.leftover_groups is not None:
leftg_indices = torch.randperm(len(self.leftover_groups)).to(rank)
dist.broadcast(leftg_indices, src=0)
leftg_indices = leftg_indices.to('cpu').tolist()
last = self.leftover_groups[leftg_indices][rank]
indices = np.concatenate([indices, last], axis=0)
elif mode == 'eval':
groups = self.regular_groups.reshape(-1, num_replicas, group_size)[:, rank, :]
indices = groups.flatten()
if self.leftover_groups is not None:
last = self.leftover_groups[rank]
indices = np.concatenate([indices, last], axis=0)
else:
raise ValueError
print_log('Sampler RANK {} : {}'.format(rank, str(indices[0:group_size+1])))
return indices
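# Standalone sketch (not part of the original module): reproduces the
# regular/leftover grouping arithmetic described in the GroupSampler docstring
# on a toy index list, without torch.distributed or dataset.load_info. With the
# defaults below it returns regular groups [[0,1,2],[3,4,5]] and leftover
# groups [[6,7],[8,9]], matching the group_size=3, num_replicas=2 walkthrough.
def _example_group_partition(len_dataset=11, group_size=3, num_replicas=2, mode='train'):
    if (len_dataset % num_replicas != 0) and (mode == 'train'):
        # drop the tail so the index count is divisible by num_replicas
        aligned = np.arange(len_dataset)[:-(len_dataset % num_replicas)]
    elif (len_dataset % num_replicas != 0) and (mode == 'eval'):
        # repeat the last index instead of dropping anything
        pad = np.full(num_replicas - len_dataset % num_replicas, len_dataset - 1)
        aligned = np.concatenate([np.arange(len_dataset), pad])
    else:
        aligned = np.arange(len_dataset)
    num_even = (len(aligned) // (group_size * num_replicas)) * group_size * num_replicas
    regular_groups = aligned[:num_even].reshape(-1, group_size)
    leftover_groups = aligned[num_even:].reshape(num_replicas, -1)
    return regular_groups, leftover_groups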