import bisect
import warnings

from torch._utils import _accumulate
from torch import randperm


class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, which provides the size of the dataset, and ``__getitem__``,
    which supports integer indexing in the range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
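

# A minimal sketch (not part of the original module) of how a custom dataset
# subclasses Dataset: override ``__len__`` and ``__getitem__`` as the
# docstring above requires. ``RangeSquares`` is a hypothetical name used
# purely for illustration.
#
#     class RangeSquares(Dataset):
#         def __init__(self, n):
#             self.n = n
#
#         def __len__(self):
#             return self.n
#
#         def __getitem__(self, index):
#             return index, index ** 2  # (sample, target) pair
#
#     ds = RangeSquares(4)
#     ds[2]     # -> (2, 4)
#     ds + ds   # __add__ above yields a ConcatDataset of length 8

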
class TensorDataset(Dataset):
    """Dataset wrapping data and target tensors.

    Each sample will be retrieved by indexing both tensors along the first
    dimension.

    Arguments:
        data_tensor (Tensor): contains sample data.
        target_tensor (Tensor): contains sample targets (labels).
    """

    def __init__(self, data_tensor, target_tensor):
        # Both tensors must hold the same number of samples along dim 0.
        assert data_tensor.size(0) == target_tensor.size(0)
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        return self.data_tensor.size(0)
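

# A usage sketch (illustrative only, assuming ``import torch``): both tensors
# must share size(0), and indexing yields an (input, target) pair drawn from
# the same row of each tensor.
#
#     inputs = torch.randn(100, 3)        # 100 samples, 3 features each
#     targets = torch.zeros(100).long()   # 100 integer labels
#     ds = TensorDataset(inputs, targets)
#     x, y = ds[0]   # first row of ``inputs``, first entry of ``targets``
#     len(ds)        # -> 100

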
class ConcatDataset(Dataset):
    """Dataset to concatenate multiple datasets.

    Useful to assemble different existing datasets, possibly large-scale
    ones, since the concatenation is done on the fly rather than by copying
    data.

    Arguments:
        datasets (iterable): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        # Running total of dataset lengths, e.g. [3, 2, 4] -> [3, 5, 9].
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Locate the dataset containing the global index ``idx``, then
        # translate it into an index local to that dataset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
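

# An illustrative sketch of the index arithmetic above (assuming
# ``import torch``): for datasets of lengths [3, 2], ``cumulative_sizes`` is
# [3, 5], and ``bisect_right`` maps a global index to its containing dataset.
#
#     a = TensorDataset(torch.zeros(3, 1), torch.zeros(3))
#     b = TensorDataset(torch.ones(2, 1), torch.ones(2))
#     cat = ConcatDataset([a, b])
#     len(cat)   # -> 5
#     cat[4]     # dataset_idx = bisect_right([3, 5], 4) = 1,
#                # sample_idx = 4 - 3 = 1, i.e. b[1]

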
class Subset(Dataset):
    """Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for the subset
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)
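

# A brief sketch: a Subset is just a view of ``dataset`` through ``indices``,
# with no data copied (assumes the TensorDataset example names from above).
#
#     sub = Subset(ds, [7, 2])
#     len(sub)   # -> 2
#     sub[0]     # -> ds[7]

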
def random_split(dataset, lengths):
    """Randomly split a dataset into non-overlapping new datasets of given
    lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (iterable): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of "
                         "the input dataset!")

    indices = randperm(sum(lengths))
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(_accumulate(lengths), lengths)]
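
# A usage sketch (illustrative only): split sizes must sum to len(dataset),
# and each returned split is a Subset over a disjoint slice of one random
# permutation, so the splits never overlap.
#
#     ds = TensorDataset(torch.randn(10, 2), torch.zeros(10))
#     train, val = random_split(ds, [8, 2])
#     len(train), len(val)   # -> (8, 2)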