from operator import itemgetter
from typing import Iterator, List, Optional

import torch
from torch.utils.data import Dataset, DistributedSampler, Sampler


class DatasetFromSampler(Dataset):
    """Dataset to create indexes from `Sampler`. From the catalyst library.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Initialisation for DatasetFromSampler."""
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Gets an element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        if self.sampler_list is None:
            # Materialise the sampler's indices once so repeated indexing is cheap.
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)
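

# Usage sketch (illustrative, not part of this module's API): DatasetFromSampler
# materialises the indices a sampler would yield so they can be accessed by position,
# which is what DistributedSamplerWrapper below relies on.
#
#     base = torch.utils.data.RandomSampler(range(10))
#     ds = DatasetFromSampler(base)
#     first = ds[0]  # first index `base` would have yielded
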
class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.
    From https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If ``True`` (default),
                sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self) -> Iterator[int]:
        """Iterate over sampler.

        Returns:
            python iterator
        """
        # Rebuild the index dataset so a stateful wrapped sampler is re-drawn each epoch.
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
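

# Usage sketch (illustrative; `weights`, `train_dataset`, `world_size`, `rank` and
# `num_epochs` are assumed to be provided by the surrounding training script):
#
#     base_sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights))
#     sampler = DistributedSamplerWrapper(base_sampler, num_replicas=world_size, rank=rank)
#     loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, sampler=sampler)
#     for epoch in range(num_epochs):
#         sampler.set_epoch(epoch)  # reshuffles the per-rank split every epoch
#         for batch in loader:
#             ...
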
class UnimaxSampler(Sampler):
    """Samples language indices according to the UniMax allocation: every language
    gets at most `num_epochs` epochs' worth of characters, and the remaining budget
    is spread as evenly as possible over the larger languages.

    Args:
        language_character_counts: number of characters available per language
        total_character_budget: total number of characters to sample
        num_epochs: maximum number of epochs (repeats) for any single language
    """

    def __init__(
        self,
        language_character_counts: List[int],
        total_character_budget: int,
        num_epochs: int,
    ) -> None:
        # Stored as float so the budget arithmetic in `_unimax` is not truncated.
        self.language_character_counts = torch.tensor(language_character_counts, dtype=torch.float)
        self.total_character_budget = total_character_budget
        self.num_epochs = num_epochs

        # Sampling distribution over languages.
        self.p = self._unimax()

    def __iter__(self) -> Iterator[int]:
        # Draw language indices with replacement according to the UniMax distribution.
        return iter(torch.multinomial(self.p, len(self.p), replacement=True).tolist())

    def __len__(self) -> int:
        return len(self.p)

    def _unimax(self) -> torch.Tensor:
        """Computes the UniMax sampling distribution over languages."""
        # Visit languages from the smallest character count to the largest.
        L, indices = torch.sort(self.language_character_counts)

        B = float(self.total_character_budget)
        U = torch.zeros_like(L)

        for i, idx in enumerate(indices):
            # Budget share if the remaining characters were split evenly over the
            # languages not yet visited.
            bl = B / (len(L) - i)
            # Character count of the current (i-th smallest) language.
            cl = L[i]

            if bl > cl * self.num_epochs:
                # The even share exceeds `num_epochs` epochs of this language, so cap
                # its allocation and leave the surplus for the larger languages.
                Ul = cl * self.num_epochs
            else:
                Ul = bl

            # Record the allocation at the language's original position.
            U[idx] = Ul
            B -= Ul

        # Normalise per-language character allocations into a probability distribution.
        p = U / U.sum()
        return p
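

# Worked example (illustrative numbers): with character counts [100, 1000, 10000],
# a total budget of 6000 characters and num_epochs=2, `_unimax` allocates
#     100-char language:   min(6000 / 3, 2 * 100)   = 200  characters
#     1000-char language:  min(5800 / 2, 2 * 1000)  = 2000 characters
#     10000-char language: min(3800 / 1, 2 * 10000) = 3800 characters
# so the sampling probabilities are p ≈ [0.033, 0.333, 0.633]:
#
#     sampler = UnimaxSampler([100, 1000, 10000], total_character_budget=6000, num_epochs=2)
#     language_ids = list(sampler)  # e.g. [2, 1, 2], indices drawn according to p
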
class DistributedUnimaxSampler(UnimaxSampler):
    """UnimaxSampler whose draws are sharded across distributed processes via
    DistributedSamplerWrapper."""

    def __init__(
        self,
        language_character_counts: List[int],
        total_character_budget: int,
        num_epochs: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ) -> None:
        super().__init__(language_character_counts, total_character_budget, num_epochs)
        # Wrap a fresh UnimaxSampler rather than `self`: wrapping `self` would route the
        # wrapper's length/iteration calls back through this class's overrides and into
        # the wrapper again (and `self.distributed_sampler` does not exist yet while the
        # wrapper is being constructed).
        base_sampler = UnimaxSampler(language_character_counts, total_character_budget, num_epochs)
        self.distributed_sampler = DistributedSamplerWrapper(
            base_sampler, num_replicas=num_replicas, rank=rank, shuffle=shuffle
        )

    def __iter__(self) -> Iterator[int]:
        return iter(self.distributed_sampler)

    def __len__(self) -> int:
        return len(self.distributed_sampler)

    def set_epoch(self, epoch: int) -> None:
        # Forwarded so training loops can reseed the per-rank shuffling each epoch.
        self.distributed_sampler.set_epoch(epoch)
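

# Minimal single-process sketch (illustrative toy numbers): passing `num_replicas`
# and `rank` explicitly means no torch.distributed process group is required.
if __name__ == "__main__":
    sampler = DistributedUnimaxSampler(
        [100, 1000, 10000], total_character_budget=6000, num_epochs=2, num_replicas=2, rank=0
    )
    sampler.set_epoch(0)
    print("language indices drawn for rank 0:", list(sampler))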