File size: 658 Bytes
b72a776 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from torchvision import datasets, transforms
from base import BaseDataLoader
class MnistDataLoader(BaseDataLoader):
    """
    Demo data loader that feeds the torchvision MNIST dataset to BaseDataLoader.

    The constructor builds the dataset with a tensor-conversion and
    normalisation transform, stores it on the instance, and then hands it,
    together with the batching options, to ``BaseDataLoader.__init__``.
    """

    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
        """
        Create the MNIST dataset and initialise the base loader with it.

        :param data_dir: root directory given to ``datasets.MNIST``; also
            kept on the instance as ``self.data_dir``.
        :param batch_size: forwarded unchanged to ``BaseDataLoader``.
        :param shuffle: forwarded unchanged to ``BaseDataLoader``.
        :param validation_split: forwarded unchanged to ``BaseDataLoader``.
        :param num_workers: forwarded unchanged to ``BaseDataLoader``.
        :param training: passed to ``datasets.MNIST`` as its ``train`` flag.
        """
        self.data_dir = data_dir

        # Per-sample preprocessing: convert to a tensor, then normalise.
        # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean and
        # standard deviation -- confirm before reusing for another dataset.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        preprocessing = transforms.Compose([to_tensor, normalize])

        # download=True is always requested here, regardless of `training`.
        self.dataset = datasets.MNIST(
            self.data_dir,
            train=training,
            download=True,
            transform=preprocessing,
        )

        # Positional order is kept exactly as the base class is called in
        # the original: dataset, batch_size, shuffle, validation_split,
        # num_workers.
        super().__init__(
            self.dataset,
            batch_size,
            shuffle,
            validation_split,
            num_workers,
        )
|