|
import math |
|
import random |
|
from typing import Callable, List, Union |
|
|
|
from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler |
|
|
|
|
|
class SubsetSampler(Sampler):
    """
    Iterate over a fixed sequence of indices in their stored order (no shuffling).

    Args:
        indices (list): a sequence of indices; it must support ``len()`` and integer indexing.
    """

    def __init__(self, indices):
        super().__init__(indices)
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        # The position range is fixed when iteration starts; elements are read lazily.
        positions = range(len(self.indices))
        return (self.indices[pos] for pos in positions)
|
|
|
|
|
class PerfectBatchSampler(Sampler):
    """
    Samples a mini-batch of indices for a balanced class batching.

    Each mini-batch is built row by row: a row holds exactly one index from each of the
    ``num_classes_in_batch`` contributing classes, and a batch is
    ``batch_size // num_classes_in_batch`` such rows. Iteration stops as soon as any
    contributing class runs out of samples.

    Args:
        dataset_items (list): dataset items to sample from; each item is indexed with ``label_key``.
        classes (list): list of classes of dataset_items to sample from.
        batch_size (int): total number of samples to be sampled in a mini-batch.
        num_classes_in_batch (int): number of distinct classes in each mini-batch. When smaller
            than ``len(classes)``, a random subset of classes is drawn for every batch.
        num_gpus (int): number of GPU in the data parallel mode.
        shuffle (bool): if True, samples randomly, otherwise samples sequentially.
        drop_last (bool): if True, drops last incomplete batch.
        label_key (str): key used to read the class label from each dataset item.

    Raises:
        AssertionError: if ``batch_size`` is not divisible by ``num_classes_in_batch * num_gpus``.
        KeyError: if a class listed in ``classes`` has no item in ``dataset_items``.
    """

    def __init__(
        self,
        dataset_items,
        classes,
        batch_size,
        num_classes_in_batch,
        num_gpus=1,
        shuffle=True,
        drop_last=False,
        label_key="class_name",
    ):
        super().__init__(dataset_items)
        assert (
            batch_size % (num_classes_in_batch * num_gpus) == 0
        ), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."

        # Group dataset positions by their label.
        label_indices = {}
        for idx, item in enumerate(dataset_items):
            label_indices.setdefault(item[label_key], []).append(idx)

        # One sampler per requested class, in the order given by ``classes``.
        if shuffle:
            self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
        else:
            self._samplers = [SubsetSampler(label_indices[key]) for key in classes]

        self._batch_size = batch_size
        self._drop_last = drop_last
        self._dp_devices = num_gpus
        self._num_classes_in_batch = num_classes_in_batch

    def __iter__(self):
        batch = []
        # When only a subset of classes goes into each batch, choose which samplers
        # contribute; a new random subset is drawn after every yielded batch.
        if self._num_classes_in_batch != len(self._samplers):
            valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
        else:
            valid_samplers_idx = None

        iters = [iter(s) for s in self._samplers]
        done = False

        while True:
            # Build one row: a single index from each contributing class.
            b = []
            for i, it in enumerate(iters):
                if valid_samplers_idx is not None and i not in valid_samplers_idx:
                    continue
                idx = next(it, None)
                if idx is None:
                    # A contributing class is exhausted: stop and discard this partial row.
                    done = True
                    break
                b.append(idx)
            if done:
                break
            batch += b
            if len(batch) == self._batch_size:
                yield batch
                batch = []
                if valid_samplers_idx is not None:
                    valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)

        if not self._drop_last:
            if len(batch) > 0:
                groups = len(batch) // self._num_classes_in_batch
                if groups % self._dp_devices == 0:
                    yield batch
                else:
                    # Trim to a multiple of num_gpus rows so every data-parallel device
                    # receives the same number of rows; may trim to nothing.
                    batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
                    if len(batch) > 0:
                        yield batch

    def __len__(self):
        """Return the number of batches ``__iter__`` yields.

        Exact when every class contributes to every batch. When
        ``num_classes_in_batch < len(classes)`` the real count depends on the random class
        selection, and this value is only an estimate.
        """
        rows_per_batch = self._batch_size // self._num_classes_in_batch

        def _num_batches(num_items):
            full, leftover = divmod(num_items, rows_per_batch)
            # ``__iter__`` only yields the trailing partial batch when drop_last is False and
            # at least one row per data-parallel device survives its truncation step.
            keep_partial = (not self._drop_last) and leftover >= self._dp_devices
            return full + int(keep_partial)

        return min(_num_batches(len(s)) for s in self._samplers)
|
|
|
|
|
def identity(x):
    """Return ``x`` itself, unchanged (a no-op key function)."""
    return x
|
|
|
|
|
class SortedSampler(Sampler):
    """Samples elements sequentially, always in the same order.

    Taken from https://github.com/PetrochukM/PyTorch-NLP

    Yields the positions of ``data`` ordered by ``sort_key(element)``. The ordering is
    computed once at construction and is stable: elements with equal keys keep their
    original relative order.

    Args:
        data (iterable): Iterable data.
        sort_key (callable): Specifies a function of one argument that is used to extract a
            numerical comparison key from each list element.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

    """

    def __init__(self, data, sort_key: Callable = identity):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        # Evaluate the key exactly once per element, then order positions by those keys.
        keys = [self.sort_key(element) for element in self.data]
        self.sorted_indexes = sorted(range(len(keys)), key=keys.__getitem__)

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        return len(self.data)
|
|
|
|
|
class BucketBatchSampler(BatchSampler):
    """Bucket batch sampler

    Adapted from https://github.com/PetrochukM/PyTorch-NLP

    Draws indices from ``sampler`` in buckets of ``batch_size * bucket_size_multiplier``
    (capped at ``len(sampler)`` when the sampler has a length). It then:

    1. Sorts each bucket by ``sort_key`` so that similar items share a mini-batch.
    2. Splits the sorted bucket into mini-batches, applying ``drop_last`` inside each bucket.
    3. Yields those mini-batches in random order.

    Args:
        sampler (torch.utils.data.Sampler): Sampler providing the indices into ``data`` to batch.
        data (list): List of data samples, indexed by the values ``sampler`` yields.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If `True` the sampler will drop the last batch if its size would be less
            than `batch_size`.
        sort_key (callable, optional): Callable to specify a comparison key for sorting.
        bucket_size_multiplier (int, optional): Buckets are of size
            `batch_size * bucket_size_multiplier`.

    Example:
        >>> sampler = WeightedRandomSampler(weights, len(weights))
        >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
    """

    def __init__(
        self,
        sampler,
        data,
        batch_size,
        drop_last,
        sort_key: Union[Callable, List] = identity,
        bucket_size_multiplier=100,
    ):
        super().__init__(sampler, batch_size, drop_last)
        self.data = data
        # NOTE(review): annotated as Union[Callable, List], but SortedSampler calls it as a
        # function, so in practice it must be callable.
        self.sort_key = sort_key
        _bucket_size = batch_size * bucket_size_multiplier
        if hasattr(sampler, "__len__"):
            # A bucket never needs to be larger than the whole sampler. Clamp to at least 1,
            # because BatchSampler rejects a non-positive batch size, which an empty sampler
            # would otherwise produce here.
            _bucket_size = max(1, min(_bucket_size, len(sampler)))
        self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)

    def __iter__(self):
        for idxs in self.bucket_sampler:
            # Order this bucket's samples by sort_key.
            bucket_data = [self.data[idx] for idx in idxs]
            sorted_sampler = SortedSampler(bucket_data, self.sort_key)
            # Split the sorted bucket into mini-batches of bucket positions, then visit those
            # mini-batches in random order.
            for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
                # Map bucket positions back to the indices produced by the outer sampler.
                sorted_idxs = [idxs[i] for i in batch_idx]
                yield sorted_idxs

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        return math.ceil(len(self.sampler) / self.batch_size)
|
|