from typing import Optional, Tuple

import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import transforms


class MNISTDataModule(LightningDataModule):
    """Example of a LightningDataModule for the MNIST dataset.

    A DataModule implements 5 key methods:
        - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
        - setup (things to do on every accelerator in distributed mode)
        - train_dataloader (the training dataloader)
        - val_dataloader (the validation dataloader(s))
        - test_dataloader (the test dataloader(s))

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """

    def __init__(
        self,
        data_dir: str = "data/",
        train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000),
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()

        # this line allows accessing init params with the `self.hparams` attribute
        self.save_hyperparameters(logger=False)

        # data transformations
        self.transforms = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    @property
    def num_classes(self) -> int:
        return 10

    def prepare_data(self):
        """Download data if needed.

        This method is called only from a single process, so do not use it to
        assign state (self.x = y).
        """
        MNIST(self.hparams.data_dir, train=True, download=True)
        MNIST(self.hparams.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by Lightning for both `trainer.fit()` and
        `trainer.test()`, so be careful if you do a random split! The `stage`
        argument can be used to differentiate whether it is called before
        `trainer.fit()` or `trainer.test()`.
        """
        # load datasets only if they're not loaded already
        if not self.data_train and not self.data_val and not self.data_test:
            trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
            testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)
            dataset = ConcatDataset(datasets=[trainset, testset])
            # seed the generator so the split is identical across repeated calls
            self.data_train, self.data_val, self.data_test = random_split(
                dataset=dataset,
                lengths=self.hparams.train_val_test_split,
                generator=torch.Generator().manual_seed(42),
            )

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )
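

if __name__ == "__main__":
    # Minimal standalone usage sketch (an addition, not part of the original
    # module): it exercises the datamodule without a `pytorch_lightning.Trainer`,
    # which would normally call `prepare_data()` and `setup()` for you.
    # Assumes MNIST can be downloaded into the default `data/` directory.
    dm = MNISTDataModule(batch_size=32)
    dm.prepare_data()  # downloads MNIST train/test sets if missing
    dm.setup()  # performs the seeded 55k/5k/10k random split

    # fetch one training batch and sanity-check shapes
    images, labels = next(iter(dm.train_dataloader()))
    print(images.shape)  # torch.Size([32, 1, 28, 28])
    print(labels.shape)  # torch.Size([32])
    print(dm.num_classes)  # 10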