import torch from PIL import Image from torchvision.datasets import ImageFolder from torchvision.transforms.functional import to_tensor from torchvision.transforms import Normalize from src.data.dataset.metric_dataset import CenterCrop class LocalCachedDataset(ImageFolder): def __init__(self, root, resolution=256): super().__init__(root) self.transform = CenterCrop(resolution) self.cache_root = None def load_latent(self, latent_path): pk_data = torch.load(latent_path) mean = pk_data['mean'].to(torch.float32) logvar = pk_data['logvar'].to(torch.float32) logvar = torch.clamp(logvar, -30.0, 20.0) std = torch.exp(0.5 * logvar) latent = mean + torch.randn_like(mean) * std return latent def __getitem__(self, idx: int): image_path, target = self.samples[idx] latent_path = image_path.replace(self.root, self.cache_root) + ".pt" raw_image = Image.open(image_path).convert('RGB') raw_image = self.transform(raw_image) raw_image = to_tensor(raw_image) if self.cache_root is not None: latent = self.load_latent(latent_path) else: latent = raw_image return raw_image, latent, target class ImageNet256(LocalCachedDataset): def __init__(self, root, ): super().__init__(root, 256) self.cache_root = root + "_256_latent" class ImageNet512(LocalCachedDataset): def __init__(self, root, ): super().__init__(root, 512) self.cache_root = root + "_512_latent" class PixImageNet(ImageFolder): def __init__(self, root, resolution=256): super().__init__(root) self.transform = CenterCrop(resolution) self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) def __getitem__(self, idx: int): image_path, target = self.samples[idx] raw_image = Image.open(image_path).convert('RGB') raw_image = self.transform(raw_image) raw_image = to_tensor(raw_image) normalized_image = self.normalize(raw_image) return raw_image, normalized_image, target class PixImageNet64(PixImageNet): def __init__(self, root, ): super().__init__(root, 64) class PixImageNet128(PixImageNet): def __init__(self, root, ): super().__init__(root, 128) class PixImageNet256(PixImageNet): def __init__(self, root, ): super().__init__(root, 256) class PixImageNet512(PixImageNet): def __init__(self, root, ): super().__init__(root, 512)