from typing import Any, Dict, Optional

import sys

import pyrootutils
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

# Make the repository root importable before pulling in project modules.
root = pyrootutils.setup_root(__file__, pythonpath=True)
sys.path.append(str(root))  # sys.path entries should be strings, not Path objects

from sgm.data.video_dataset_latent import VideoDataset
class VideoDataModule(LightningDataModule):
    """
    A DataModule implements 5 key methods:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc...
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc...
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
    """
    def __init__(
        self,
        train: DictConfig,
        validation: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        skip_val_loader: bool = False,
    ):
        super().__init__()
        # this line allows access to init params with the `self.hparams` attribute
        # and also ensures init params will be stored in the ckpt
        self.save_hyperparameters(logger=False)
        self.train_config = train
        assert (
            "datapipeline" in self.train_config and "loader" in self.train_config
        ), "train config requires the fields `datapipeline` and `loader`"

        self.val_config = validation
        if not skip_val_loader:
            if self.val_config is not None:
                assert (
                    "datapipeline" in self.val_config and "loader" in self.val_config
                ), "validation config requires the fields `datapipeline` and `loader`"
            else:
                print(
                    "Warning: no validation datapipeline defined, "
                    "falling back to the training datapipeline"
                )
                self.val_config = train

        self.test_config = test
        if self.test_config is not None:
            assert (
                "datapipeline" in self.test_config and "loader" in self.test_config
            ), "test config requires the fields `datapipeline` and `loader`"

        # Populated in `setup`; initialised to None here so the *_dataloader
        # hooks can safely detect a missing pipeline.
        self.train_datapipeline = None
        self.val_datapipeline = None
        self.test_datapipeline = None
    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.train_datapipeline`,
        `self.val_datapipeline`, `self.test_datapipeline`.

        This method is called by Lightning with both `trainer.fit()` and
        `trainer.test()`, so be careful not to execute things like random
        split twice!
        """
        print("Preparing datasets")
        self.train_datapipeline = VideoDataset(**self.train_config.datapipeline)
        if self.val_config:
            self.val_datapipeline = VideoDataset(**self.val_config.datapipeline)
        if self.test_config:
            self.test_datapipeline = VideoDataset(**self.test_config.datapipeline)
    def train_dataloader(self):
        return DataLoader(self.train_datapipeline, **self.train_config.loader)

    def val_dataloader(self):
        if self.val_datapipeline is not None:
            return DataLoader(self.val_datapipeline, **self.val_config.loader)
        return None

    def test_dataloader(self):
        if self.test_datapipeline is not None:
            return DataLoader(self.test_datapipeline, **self.test_config.loader)
        return None
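
    # In normal use Lightning calls these hooks itself; a minimal sketch,
    # assuming a LightningModule `model` and a train config `train_cfg`
    # defined elsewhere (both hypothetical here):
    #
    #   trainer = pytorch_lightning.Trainer(max_steps=10)
    #   trainer.fit(model, datamodule=VideoDataModule(train=train_cfg))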
    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass
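
# A minimal sketch of driving the module directly, without Hydra. The
# `filelist` key and its value are hypothetical stand-ins for whatever
# `VideoDataset` expects in your checkout:
#
#   from omegaconf import OmegaConf
#
#   train_cfg = OmegaConf.create(
#       {
#           "datapipeline": {"filelist": "data/train.txt"},
#           "loader": {"batch_size": 1, "num_workers": 0},
#       }
#   )
#   dm = VideoDataModule(train=train_cfg, skip_val_loader=True)
#   dm.setup()
#   batch = next(iter(dm.train_dataloader()))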
if __name__ == "__main__":
    import cv2
    import hydra
    import omegaconf

    # Smoke test: instantiate the datamodule from its Hydra config and dump
    # one decoded image pair. `root` is already set up at module level.
    cfg = omegaconf.OmegaConf.load(
        root / "configs" / "datamodule" / "image_datamodule.yaml"
    )
    # cfg.data_dir = str(root / "data")
    data = hydra.utils.instantiate(cfg)
    data.prepare_data()
    data.setup()
    print(data.train_datapipeline[0][0].shape)

    batch = next(iter(data.train_dataloader()))
    identity, target = batch
    # Map images from [-1, 1] back to [0, 255]; cv2.imwrite expects uint8
    # arrays in BGR channel order, hence the cast and the channel flip.
    image_identity = (identity[0].permute(1, 2, 0).numpy() + 1) / 2 * 255
    image_other = (target[0].permute(1, 2, 0).numpy() + 1) / 2 * 255
    cv2.imwrite("image_identity.png", image_identity[:, :, ::-1].astype("uint8"))
    cv2.imwrite("image_other.png", image_other[:, :, ::-1].astype("uint8"))