import numpy as np

from torch.utils.data.sampler import BatchSampler, Sampler


class GroupedBatchSampler(BatchSampler):
|
|
"""
|
|
Wraps another sampler to yield a mini-batch of indices.
|
|
It enforces that the batch only contain elements from the same group.
|
|
It also tries to provide mini-batches which follows an ordering which is
|
|
as close as possible to the ordering from the original sampler.
|
|
"""
|
|
|
|
def __init__(self, sampler, group_ids, batch_size):
|
|
"""
|
|
Args:
|
|
sampler (Sampler): Base sampler.
|
|
group_ids (list[int]): If the sampler produces indices in range [0, N),
|
|
`group_ids` must be a list of `N` ints which contains the group id of each sample.
|
|
The group ids must be a set of integers in the range [0, num_groups).
|
|
batch_size (int): Size of mini-batch.
|
|
"""
|
|
if not isinstance(sampler, Sampler):
|
|
raise ValueError(
|
|
"sampler should be an instance of "
|
|
"torch.utils.data.Sampler, but got sampler={}".format(sampler)
|
|
)
|
|
self.sampler = sampler
|
|
self.group_ids = np.asarray(group_ids)
|
|
assert self.group_ids.ndim == 1
|
|
self.batch_size = batch_size
|
|
groups = np.unique(self.group_ids).tolist()
|
|
|
|
|
|
self.buffer_per_group = {k: [] for k in groups}
|
|
|
|
def __iter__(self):
|
|
for idx in self.sampler:
|
|
group_id = self.group_ids[idx]
|
|
group_buffer = self.buffer_per_group[group_id]
|
|
group_buffer.append(idx)
|
|
if len(group_buffer) == self.batch_size:
|
|
yield group_buffer[:]
|
|
del group_buffer[:]
|
|
|
|
def __len__(self):
|
|
raise NotImplementedError("len() of GroupedBatchSampler is not well-defined.")
|
|
|