File size: 857 Bytes
803ef9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import os
import copy
from PIL import Image
import numpy as np
import torch
import torch.utils.data as data
from torchvision import transforms, datasets
# Default ``root`` directory for the MNIST wrapper; the dataset files are
# downloaded into (and read from) this directory.
DATA_ROOTS = "data"
class MNIST(data.Dataset):
    """MNIST digits served as RGB PIL images with integer labels.

    Wraps ``torchvision.datasets.mnist.MNIST`` (downloading it into ``root``
    when needed) but indexes its raw ``data`` / ``targets`` tensors directly.
    As a result, the wrapped dataset's own ``transform`` is never used; only
    ``image_transforms`` is applied.
    """

    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
        """Create (and download if necessary) the MNIST split.

        Args:
            root: Directory holding the MNIST files. It is created if missing.
            train: Forwarded to torchvision's ``MNIST(train=...)`` to choose
                the split.
            image_transforms: Optional callable applied to each RGB PIL image
                in ``__getitem__``.
        """
        super().__init__()
        # exist_ok=True replaces the racy isdir()-then-makedirs() pattern.
        os.makedirs(root, exist_ok=True)
        self.image_transforms = image_transforms
        self.dataset = datasets.mnist.MNIST(root, train=train, download=True)

    def __getitem__(self, index):
        """Return ``(img, target)`` for sample ``index``.

        ``img`` is a 3-channel RGB PIL image, or whatever ``image_transforms``
        returns for it. ``target`` is the label converted to a Python ``int``.
        """
        img, target = self.dataset.data[index], int(self.dataset.targets[index])
        # torchvision stores MNIST pixels as a 2-D uint8 tensor, which Pillow
        # infers as mode "L". Passing ``mode=`` explicitly is deprecated in
        # recent Pillow releases, so let Pillow infer it.
        img = Image.fromarray(img.numpy()).convert('RGB')
        if self.image_transforms is not None:
            img = self.image_transforms(img)
        return img, target

    def __len__(self):
        """Return the number of samples in the wrapped torchvision dataset."""
        return len(self.dataset)