|
from typing import Optional, Tuple |
|
|
|
import torch |
|
from pytorch_lightning import LightningDataModule |
|
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split |
|
from torchvision.datasets import MNIST |
|
from torchvision.transforms import transforms |
|
|
|
|
|
class MNISTDataModule(LightningDataModule):
    """Example of a LightningDataModule for the MNIST dataset.

    A DataModule implements 5 key methods:
        - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
        - setup (things to do on every accelerator in distributed mode)
        - train_dataloader (the training dataloader)
        - val_dataloader (the validation dataloader(s))
        - test_dataloader (the test dataloader(s))

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """

    def __init__(
        self,
        data_dir: str = "data/",
        train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000),
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        """Store the configuration and build the image transform pipeline.

        Args:
            data_dir: Root directory handed to ``MNIST`` for download/loading.
            train_val_test_split: Sizes of the train, validation and test
                subsets, passed as ``lengths`` to ``random_split`` over the
                concatenated train + test MNIST sets.
            batch_size: Batch size used by every dataloader.
            num_workers: ``num_workers`` used by every dataloader.
            pin_memory: ``pin_memory`` used by every dataloader.
        """
        super().__init__()

        # Exposes the constructor arguments as `self.hparams.<name>`, which is
        # how the rest of this class reads its configuration.
        # NOTE(review): `logger=False` presumably keeps these values from being
        # sent to the experiment logger -- confirm against the Lightning docs.
        self.save_hyperparameters(logger=False)

        # Convert images to tensors, then normalize the single channel.
        # NOTE(review): 0.1307 / 0.3081 look like the usual MNIST mean / std.
        self.transforms = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

        # Populated by `setup()`; `None` means "not split yet".
        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    @property
    def num_classes(self) -> int:
        """Number of target classes."""
        return 10

    def prepare_data(self) -> None:
        """Download data if needed.

        This method is called only from a single GPU.
        Do not use it to assign state (self.x = y).
        """
        MNIST(self.hparams.data_dir, train=True, download=True)
        MNIST(self.hparams.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning twice for `trainer.fit()` and
        `trainer.test()`, so be careful if you do a random split!
        The `stage` can be used to differentiate whether it's called before
        `trainer.fit()` or `trainer.test()`.

        Args:
            stage: Stage identifier supplied by the caller; not used here.
        """
        # Split only once. Compare against `None` explicitly instead of relying
        # on truthiness: truthiness of a dataset goes through `__len__`, which
        # would conflate "not created yet" with "created but empty".
        already_split = not (
            self.data_train is None
            and self.data_val is None
            and self.data_test is None
        )
        if already_split:
            return

        trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
        testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)

        # Pool the official train and test sets, then re-split them with a
        # fixed seed so the partition is the same on every call.
        dataset = ConcatDataset(datasets=[trainset, testset])
        self.data_train, self.data_val, self.data_test = random_split(
            dataset=dataset,
            lengths=self.hparams.train_val_test_split,
            generator=torch.Generator().manual_seed(42),
        )

    def _make_dataloader(self, dataset: Optional[Dataset], shuffle: bool) -> DataLoader:
        """Build a DataLoader for `dataset` using the shared hyperparameters."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled dataloader over `self.data_train`."""
        return self._make_dataloader(self.data_train, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the unshuffled dataloader over `self.data_val`."""
        return self._make_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Return the unshuffled dataloader over `self.data_test`."""
        return self._make_dataloader(self.data_test, shuffle=False)
|
|