import math
import torch
from torch.utils.data import Sampler
import torch.distributed as dist


class DistributedEvalSampler(Sampler):
    r"""Sampler that gives each process a disjoint, unpadded slice of a dataset.

    DistributedEvalSampler is different from DistributedSampler: it does NOT
    add extra samples to make the dataset evenly divisible across replicas.
    Process ``rank`` receives the indices at positions ``rank``,
    ``rank + num_replicas``, ``rank + 2 * num_replicas``, ... so different
    ranks may receive a different number of samples.

    DistributedEvalSampler should NOT be used for training: because ranks can
    see different numbers of batches, the distributed processes could hang
    forever. See this issue for details:
    https://github.com/pytorch/pytorch/issues/22584

    It is meant for evaluation, where synchronization does not happen every
    iteration. Synchronization should be done outside the dataloader loop.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling. Only ``len(dataset)`` is used.
        num_replicas (int, optional): Number of processes participating in
            distributed evaluation. By default, the world size is retrieved
            from the current distributed group.
        rank (int, optional): Rank of the current process within
            :attr:`num_replicas`. By default, :attr:`rank` is retrieved from
            the current distributed group.
        shuffle (bool, optional): If ``True``, the sampler shuffles the
            indices before splitting them across replicas.
            Default: ``False``.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.

    Raises:
        RuntimeError: If :attr:`num_replicas` or :attr:`rank` is omitted and
            the distributed package is not available.

    .. warning::
        When :attr:`shuffle=True`, call :meth:`set_epoch` before creating the
        :class:`~torch.utils.data.DataLoader` iterator if a different
        ordering is wanted for each pass. Otherwise the same ordering is
        always used.

    Example::

        >>> sampler = DistributedEvalSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=False, sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     evaluate(loader)
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, seed=0):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle
        self.seed = seed

        # True dataset size: no padding is added to make it divisible.
        self.total_size = len(self.dataset)
        # Count this rank's share by slicing a ``range`` rather than a list:
        # same slice semantics, but nothing is materialized.
        self.num_samples = len(self._subsample(range(self.total_size)))

    def _subsample(self, indices):
        """Return the strided part of *indices* that belongs to this rank."""
        return indices[self.rank:self.total_size:self.num_replicas]

    def __iter__(self):
        """Return an iterator over the dataset indices assigned to this rank."""
        if self.shuffle:
            # Deterministic shuffle based on epoch and seed, so that every
            # rank derives the same permutation and the slices stay disjoint.
            generator = torch.Generator()
            generator.manual_seed(self.seed + self.epoch)
            permutation = torch.randperm(len(self.dataset), generator=generator)
            indices = self._subsample(permutation.tolist())
        else:
            indices = list(self._subsample(range(len(self.dataset))))

        # Internal invariant: fails if the dataset size changed after
        # construction (the dataset is assumed to be of constant size).
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        """Return the number of samples assigned to this rank."""
        return self.num_samples

    def set_epoch(self, epoch):
        r"""Set the epoch for this sampler.

        When :attr:`shuffle=True`, the epoch is added to the seed, so each
        epoch produces a different random ordering. Otherwise, the next
        iteration of this sampler will yield the same ordering.

        Arguments:
            epoch (int): Epoch number.
        """
        self.epoch = epoch