File size: 1,048 Bytes
d2e7940 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import os
import pytest
import torch
from src.datamodules.mnist_datamodule import MNISTDataModule
@pytest.mark.parametrize("batch_size", [32, 128])
def test_mnist_datamodule(batch_size):
    """Exercise the MNIST datamodule lifecycle for a given ``batch_size``.

    Checks, in order:
      * after ``prepare_data()`` the train/val/test splits are still falsy and
        the directories ``data/MNIST`` and ``data/MNIST/raw`` exist;
      * after ``setup()`` all three splits are populated and together hold
        exactly 70,000 samples;
      * each of the three dataloader factories returns a truthy loader;
      * the first training batch has ``batch_size`` inputs and labels, with
        ``torch.float32`` inputs and ``torch.int64`` labels.
    """
    dm = MNISTDataModule(batch_size=batch_size)

    dm.prepare_data()
    # The splits must not be assigned until setup() runs.
    assert not (dm.data_train or dm.data_val or dm.data_test)
    # NOTE(review): "data" is assumed to match the datamodule's default data
    # directory — confirm against MNISTDataModule's constructor.
    mnist_root = os.path.join("data", "MNIST")
    assert os.path.exists(mnist_root)
    assert os.path.exists(os.path.join(mnist_root, "raw"))

    dm.setup()
    assert dm.data_train and dm.data_val and dm.data_test
    split_sizes = (len(dm.data_train), len(dm.data_val), len(dm.data_test))
    assert sum(split_sizes) == 70_000

    # Each factory is called once, train -> val -> test.
    for make_loader in (dm.train_dataloader, dm.val_dataloader, dm.test_dataloader):
        assert make_loader()

    inputs, labels = next(iter(dm.train_dataloader()))
    assert len(inputs) == batch_size
    assert len(labels) == batch_size
    assert inputs.dtype == torch.float32
    assert labels.dtype == torch.int64
|