|
from functools import partial
|
|
import numpy as np
|
|
|
|
import torch
|
|
import pytorch_lightning as pl
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
import os, sys

# Make relative paths in configs resolve against this script's own directory.
# NOTE(review): this mutates the process-wide CWD at import time — any code
# importing this module inherits the changed working directory; confirm that
# is intended before reusing this module as a library.
os.chdir(sys.path[0])

# Expose the parent directory so sibling packages (lvdm, utils) are importable.
sys.path.append("..")
|
|
from lvdm.data.base import Txt2ImgIterableBaseDataset
|
|
from utils.utils import instantiate_from_config
|
|
|
|
|
|
def worker_init_fn(_):
    """Per-worker DataLoader initializer: de-correlate numpy RNG across workers.

    For ``Txt2ImgIterableBaseDataset`` instances it additionally assigns each
    worker a disjoint, contiguous slice of ``dataset.valid_ids`` so that the
    workers of one DataLoader do not emit duplicate records.

    Args:
        _: worker id passed by DataLoader (unused; the id is taken from
           ``torch.utils.data.get_worker_info()`` instead).

    Returns:
        None (the return value of ``np.random.seed``); DataLoader ignores it.
    """
    # Only valid inside a worker process: get_worker_info() is None otherwise.
    worker_info = torch.utils.data.get_worker_info()

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        # Split the record ids evenly; any remainder
        # (num_records % num_workers) is dropped, as in the original split.
        split_size = dataset.num_records // worker_info.num_workers
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
        # Pick a random word of the current MT19937 state as the seed base
        # (this draw itself advances the RNG, which is intentional).
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        base_seed = int(np.random.get_state()[1][current_id[0]])
    else:
        base_seed = int(np.random.get_state()[1][0])

    # Bug fix: the state words are uint32, so base_seed + worker_id can exceed
    # 2**32 - 1, which np.random.seed rejects with ValueError. Wrap into range.
    return np.random.seed((base_seed + worker_id) % (2 ** 32))
|
|
|
|
|
|
class WrappedDataset(Dataset):
    """Adapter that exposes any object implementing ``__len__`` and
    ``__getitem__`` as a map-style PyTorch ``Dataset``."""

    def __init__(self, dataset):
        # Hold a reference only; no copying or validation is performed.
        self.data = dataset

    def __len__(self):
        # Delegate straight to the wrapped object.
        return self.data.__len__()

    def __getitem__(self, idx):
        # Delegate straight to the wrapped object.
        return self.data.__getitem__(idx)
|
|
|
|
|
|
class DataModuleFromConfig(pl.LightningDataModule):
    """LightningDataModule whose datasets are built from instantiation configs.

    Each of ``train``/``validation``/``test``/``predict`` is an optional config
    dict understood by ``instantiate_from_config``. A ``*_dataloader`` hook is
    attached only for the splits whose config was provided, so Lightning never
    calls a loader that has no dataset behind it.

    Args:
        batch_size: batch size used by every dataloader.
        train, validation, test, predict: per-split dataset configs (optional).
        wrap: if True, wrap each instantiated dataset in ``WrappedDataset``.
        num_workers: DataLoader worker count; defaults to ``2 * batch_size``.
        shuffle_test_loader: shuffle flag bound to the test dataloader.
        use_worker_init_fn: force ``worker_init_fn`` even for map-style datasets.
        shuffle_val_dataloader: shuffle flag bound to the validation dataloader.
        train_img: unused here; accepted for config compatibility.
        test_max_n_samples: if set, truncate the test set to its first N samples.
    """

    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False, train_img=None,
                 test_max_n_samples=None):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        # Heuristic default: two workers per sample in a batch.
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader

        self.img_loader = None  # NOTE(review): never populated here; train_img is unused
        self.wrap = wrap
        self.test_max_n_samples = test_max_n_samples
        self.collate_fn = None

    def prepare_data(self):
        # Nothing to download or pre-process; datasets are built in setup().
        pass

    def setup(self, stage=None):
        """Instantiate every configured dataset, optionally wrapping each one."""
        self.datasets = dict(
            (k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _maybe_worker_init_fn(self, dataset):
        """Return ``worker_init_fn`` when the dataset (or config) requires it."""
        if isinstance(dataset, Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            return worker_init_fn
        return None

    def _train_dataloader(self):
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        init_fn = self._maybe_worker_init_fn(self.datasets['train'])
        # Iterable datasets define their own ordering; shuffling is invalid there.
        return DataLoader(self.datasets["train"], batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=not is_iterable_dataset,
                          worker_init_fn=init_fn, collate_fn=self.collate_fn,
                          )

    def _val_dataloader(self, shuffle=False):
        init_fn = self._maybe_worker_init_fn(self.datasets['validation'])
        return DataLoader(self.datasets["validation"],
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=init_fn,
                          shuffle=shuffle,
                          collate_fn=self.collate_fn,
                          )

    def _test_dataloader(self, shuffle=False):
        # The iterable-ness of the *train* split decides the worker init fn,
        # falling back to the test split when no train config exists.
        # Bug fix: narrowed the original bare ``except:`` (which also swallowed
        # SystemExit/KeyboardInterrupt) to the KeyError the lookup can raise.
        try:
            is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        except KeyError:
            is_iterable_dataset = isinstance(self.datasets['test'], Txt2ImgIterableBaseDataset)

        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None

        # Never shuffle an iterable dataset, whatever the caller asked for.
        shuffle = shuffle and (not is_iterable_dataset)
        if self.test_max_n_samples is not None:
            # Evaluate on only the first N samples of the test set.
            dataset = torch.utils.data.Subset(self.datasets["test"], list(range(self.test_max_n_samples)))
        else:
            dataset = self.datasets["test"]
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle,
                          collate_fn=self.collate_fn,
                          )

    def _predict_dataloader(self, shuffle=False):
        init_fn = self._maybe_worker_init_fn(self.datasets['predict'])
        return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn,
                          collate_fn=self.collate_fn,
                          )
|
|
|