from typing import Any, Dict, Optional

import sys

import pyrootutils
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

root = pyrootutils.setup_root(__file__, pythonpath=True)
sys.path.append(str(root))

from sgm.data.video_dataset_latent import VideoDataset


class VideoDataModule(LightningDataModule):
    """
    A DataModule implements 6 key methods:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc...
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc...
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
    """

    def __init__(
        self,
        train: DictConfig,
        validation: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        skip_val_loader: bool = False,
    ):
        super().__init__()

        # store the dataset/loader configs; the actual datasets are built in `setup()`
        self.train_config = train
        assert "datapipeline" in self.train_config and "loader" in self.train_config, (
            "train config requires the fields `datapipeline` and `loader`"
        )

        self.val_config = validation
        if not skip_val_loader:
            if self.val_config is not None:
                assert (
                    "datapipeline" in self.val_config and "loader" in self.val_config
                ), "validation config requires the fields `datapipeline` and `loader`"
            else:
                print(
                    "Warning: no validation datapipeline defined, "
                    "falling back to the training datapipeline"
                )
                self.val_config = train

        self.test_config = test
        if self.test_config is not None:
            assert (
                "datapipeline" in self.test_config and "loader" in self.test_config
            ), "test config requires the fields `datapipeline` and `loader`"

    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.train_datapipeline`, `self.val_datapipeline`,
        `self.test_datapipeline`.

        This method is called by lightning with both `trainer.fit()` and `trainer.test()`,
        so be careful not to execute things like random split twice!
""" print("Preparing datasets") self.train_datapipeline = VideoDataset(**self.train_config.datapipeline) if self.val_config: self.val_datapipeline = VideoDataset(**self.val_config.datapipeline) if self.test_config: self.test_datapipeline = VideoDataset(**self.test_config.datapipeline) def train_dataloader(self): return DataLoader(self.train_datapipeline, **self.train_config.loader) def val_dataloader(self): if self.val_datapipeline: return DataLoader(self.val_datapipeline, **self.val_config.loader) else: return None def test_dataloader(self): if self.test_datapipeline: return DataLoader(self.test_datapipeline, **self.test_config.loader) else: return None def teardown(self, stage: Optional[str] = None): """Clean up after fit or test.""" pass def state_dict(self): """Extra things to save to checkpoint.""" return {} def load_state_dict(self, state_dict: Dict[str, Any]): """Things to do when loading checkpoint.""" pass if __name__ == "__main__": import hydra import omegaconf import pyrootutils import cv2 root = pyrootutils.setup_root(__file__, pythonpath=True) cfg = omegaconf.OmegaConf.load( root / "configs" / "datamodule" / "image_datamodule.yaml" ) # cfg.data_dir = str(root / "data") data = hydra.utils.instantiate(cfg) data.prepare_data() data.setup() print(data.data_train.__getitem__(0)[0].shape) batch = next(iter(data.train_dataloader())) identity, target = batch image_identity = (identity[0].permute(1, 2, 0).numpy() + 1) / 2 * 255 image_other = (target[0].permute(1, 2, 0).numpy() + 1) / 2 * 255 cv2.imwrite("image_identity.png", image_identity[:, :, ::-1]) cv2.imwrite("image_other.png", image_other[:, :, ::-1])