|
from torchvision import datasets, transforms |
|
from base import BaseDataLoader |
|
|
|
|
|
class MnistDataLoader(BaseDataLoader): |
|
""" |
|
MNIST data loading demo using BaseDataLoader |
|
""" |
|
def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True): |
|
trsfm = transforms.Compose([ |
|
transforms.ToTensor(), |
|
transforms.Normalize((0.1307,), (0.3081,)) |
|
]) |
|
self.data_dir = data_dir |
|
self.dataset = datasets.MNIST(self.data_dir, train=training, download=True, transform=trsfm) |
|
super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) |
|
|