# diffae/dataset.py
import os
from io import BytesIO
from pathlib import Path

import lmdb
import pandas as pd
import torch
import torchvision.transforms.functional as Ftrans
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10, LSUNClass

class ImageDataset(Dataset):
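    """
    Generic folder-of-images dataset: collects files matching `exts` under
    `folder` (recursively if `has_subdir`), then applies resize, center crop,
    optional horizontal-flip augmentation, and [-1, 1] normalization.
    """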
def __init__(
self,
folder,
image_size,
exts=['jpg'],
do_augment: bool = True,
do_transform: bool = True,
do_normalize: bool = True,
sort_names=False,
has_subdir: bool = True,
):
super().__init__()
self.folder = folder
self.image_size = image_size
        # store relative paths (shorter strings save memory and sort faster)
if has_subdir:
self.paths = [
p.relative_to(folder) for ext in exts
for p in Path(f'{folder}').glob(f'**/*.{ext}')
]
else:
self.paths = [
p.relative_to(folder) for ext in exts
for p in Path(f'{folder}').glob(f'*.{ext}')
]
if sort_names:
self.paths = sorted(self.paths)
transform = [
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
]
if do_augment:
transform.append(transforms.RandomHorizontalFlip())
if do_transform:
transform.append(transforms.ToTensor())
if do_normalize:
transform.append(
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
self.transform = transforms.Compose(transform)
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = os.path.join(self.folder, self.paths[index])
img = Image.open(path)
        # convert to RGB in case the image has an alpha channel or is grayscale
        img = img.convert('RGB')
if self.transform is not None:
img = self.transform(img)
return {'img': img, 'index': index}
class SubsetDataset(Dataset):
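    """
    View of the first `size` items of another dataset.
    """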
def __init__(self, dataset, size):
assert len(dataset) >= size
self.dataset = dataset
self.size = size
def __len__(self):
return self.size
def __getitem__(self, index):
assert index < self.size
return self.dataset[index]
class BaseLMDB(Dataset):
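    """
    Read-only LMDB image store. Keys have the form
    '{original_resolution}-{zero-padded index}' and map to encoded image
    bytes; the special key 'length' stores the number of entries.
    """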
def __init__(self, path, original_resolution, zfill: int = 5):
self.original_resolution = original_resolution
self.zfill = zfill
self.env = lmdb.open(
path,
max_readers=32,
readonly=True,
lock=False,
readahead=False,
meminit=False,
)
if not self.env:
raise IOError('Cannot open lmdb dataset', path)
with self.env.begin(write=False) as txn:
self.length = int(
txn.get('length'.encode('utf-8')).decode('utf-8'))
def __len__(self):
return self.length
def __getitem__(self, index):
with self.env.begin(write=False) as txn:
            key = f'{self.original_resolution}-{str(index).zfill(self.zfill)}'
            img_bytes = txn.get(key.encode('utf-8'))
buffer = BytesIO(img_bytes)
img = Image.open(buffer)
return img
def make_transform(
image_size,
flip_prob=0.5,
crop_d2c=False,
):
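    """
    Build the standard preprocessing pipeline: either a D2C crop followed by
    a resize, or a resize followed by a center crop; then random horizontal
    flip, tensor conversion, and [-1, 1] normalization.
    """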
if crop_d2c:
transform = [
d2c_crop(),
transforms.Resize(image_size),
]
else:
transform = [
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
]
transform.append(transforms.RandomHorizontalFlip(p=flip_prob))
transform.append(transforms.ToTensor())
transform.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
transform = transforms.Compose(transform)
return transform
class FFHQlmdb(Dataset):
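    """
    FFHQ images stored in LMDB. The first 10k images form the test split and
    the remaining 60k the train split.
    """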
def __init__(self,
path=os.path.expanduser('datasets/ffhq256.lmdb'),
image_size=256,
original_resolution=256,
split=None,
as_tensor: bool = True,
do_augment: bool = True,
do_normalize: bool = True,
**kwargs):
self.original_resolution = original_resolution
self.data = BaseLMDB(path, original_resolution, zfill=5)
self.length = len(self.data)
if split is None:
self.offset = 0
elif split == 'train':
# last 60k
self.length = self.length - 10000
self.offset = 10000
elif split == 'test':
# first 10k
self.length = 10000
self.offset = 0
else:
raise NotImplementedError()
transform = [
transforms.Resize(image_size),
]
if do_augment:
transform.append(transforms.RandomHorizontalFlip())
if as_tensor:
transform.append(transforms.ToTensor())
if do_normalize:
transform.append(
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
self.transform = transforms.Compose(transform)
def __len__(self):
return self.length
def __getitem__(self, index):
assert index < self.length
index = index + self.offset
img = self.data[index]
if self.transform is not None:
img = self.transform(img)
return {'img': img, 'index': index}
class Crop:
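    """
    Fixed-box crop: (x1, x2) index rows (top/bottom edges) and (y1, y2)
    index columns (left/right edges).
    """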
def __init__(self, x1, x2, y1, y2):
self.x1 = x1
self.x2 = x2
self.y1 = y1
self.y2 = y2
def __call__(self, img):
return Ftrans.crop(img, self.x1, self.y1, self.x2 - self.x1,
self.y2 - self.y1)
def __repr__(self):
return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
self.x1, self.x2, self.y1, self.y2)
def d2c_crop():
    # fixed 128x128 face crop from the D2C paper, for the CelebA dataset
cx = 89
cy = 121
x1 = cy - 64
x2 = cy + 64
y1 = cx - 64
y2 = cx + 64
return Crop(x1, x2, y1, y2)
class CelebAlmdb(Dataset):
"""
also supports for d2c crop.
"""
def __init__(self,
path,
image_size,
original_resolution=128,
split=None,
as_tensor: bool = True,
do_augment: bool = True,
do_normalize: bool = True,
crop_d2c: bool = False,
**kwargs):
self.original_resolution = original_resolution
self.data = BaseLMDB(path, original_resolution, zfill=7)
self.length = len(self.data)
self.crop_d2c = crop_d2c
if split is None:
self.offset = 0
else:
raise NotImplementedError()
if crop_d2c:
transform = [
d2c_crop(),
transforms.Resize(image_size),
]
else:
transform = [
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
]
if do_augment:
transform.append(transforms.RandomHorizontalFlip())
if as_tensor:
transform.append(transforms.ToTensor())
if do_normalize:
transform.append(
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
self.transform = transforms.Compose(transform)
def __len__(self):
return self.length
def __getitem__(self, index):
assert index < self.length
index = index + self.offset
img = self.data[index]
if self.transform is not None:
img = self.transform(img)
return {'img': img, 'index': index}
class Horse_lmdb(Dataset):
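    """
    LSUN Horse images stored in LMDB.
    """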
def __init__(self,
path=os.path.expanduser('datasets/horse256.lmdb'),
image_size=128,
original_resolution=256,
do_augment: bool = True,
do_transform: bool = True,
do_normalize: bool = True,
**kwargs):
self.original_resolution = original_resolution
print(path)
self.data = BaseLMDB(path, original_resolution, zfill=7)
self.length = len(self.data)
transform = [
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
]
if do_augment:
transform.append(transforms.RandomHorizontalFlip())
if do_transform:
transform.append(transforms.ToTensor())
if do_normalize:
transform.append(
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
self.transform = transforms.Compose(transform)
def __len__(self):
return self.length
def __getitem__(self, index):
img = self.data[index]
if self.transform is not None:
img = self.transform(img)
return {'img': img, 'index': index}
class Bedroom_lmdb(Dataset):
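    """
    LSUN Bedroom images stored in LMDB.
    """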
def __init__(self,
path=os.path.expanduser('datasets/bedroom256.lmdb'),
image_size=128,
original_resolution=256,
do_augment: bool = True,
do_transform: bool = True,
do_normalize: bool = True,
**kwargs):
self.original_resolution = original_resolution
print(path)
self.data = BaseLMDB(path, original_resolution, zfill=7)
self.length = len(self.data)
transform = [
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
]
if do_augment:
transform.append(transforms.RandomHorizontalFlip())
if do_transform:
transform.append(transforms.ToTensor())
if do_normalize:
transform.append(
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
self.transform = transforms.Compose(transform)
def __len__(self):
return self.length
def __getitem__(self, index):
img = self.data[index]
        if self.transform is not None:
            img = self.transform(img)
return {'img': img, 'index': index}
class CelebAttrDataset(Dataset):
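    """
    CelebA images on disk paired with the 40 binary attribute annotations
    from `list_attr_celeba.txt`; labels follow the +1/-1 CelebA convention.
    """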
id_to_cls = [
'5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
'Wearing_Necklace', 'Wearing_Necktie', 'Young'
]
cls_to_id = {v: k for k, v in enumerate(id_to_cls)}
def __init__(self,
folder,
image_size=64,
attr_path=os.path.expanduser(
'datasets/celeba_anno/list_attr_celeba.txt'),
ext='png',
only_cls_name: str = None,
only_cls_value: int = None,
do_augment: bool = False,
do_transform: bool = True,
do_normalize: bool = True,
d2c: bool = False):
super().__init__()
self.folder = folder
self.image_size = image_size
self.ext = ext
        # store relative paths (shorter strings save memory and sort faster)
paths = [
str(p.relative_to(folder))
for p in Path(f'{folder}').glob(f'**/*.{ext}')
]
        # the attribute file indexes images as '<name>.jpg', so normalize extensions
        paths = [str(each).split('.')[0] + '.jpg' for each in paths]
if d2c:
transform = [
d2c_crop(),
transforms.Resize(image_size),
]
else:
transform = [
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
]
if do_augment:
transform.append(transforms.RandomHorizontalFlip())
if do_transform:
transform.append(transforms.ToTensor())
if do_normalize:
transform.append(
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
self.transform = transforms.Compose(transform)
with open(attr_path) as f:
            # skip the first line (the image count)
            f.readline()
            # sep=r'\s+' replaces the deprecated delim_whitespace=True
            self.df = pd.read_csv(f, sep=r'\s+')
self.df = self.df[self.df.index.isin(paths)]
if only_cls_name is not None:
self.df = self.df[self.df[only_cls_name] == only_cls_value]
def pos_count(self, cls_name):
return (self.df[cls_name] == 1).sum()
def neg_count(self, cls_name):
return (self.df[cls_name] == -1).sum()
def __len__(self):
return len(self.df)
def __getitem__(self, index):
row = self.df.iloc[index]
name = row.name.split('.')[0]
name = f'{name}.{self.ext}'
path = os.path.join(self.folder, name)
img = Image.open(path)
labels = [0] * len(self.id_to_cls)
for k, v in row.items():
labels[self.cls_to_id[k]] = int(v)
if self.transform is not None:
img = self.transform(img)
return {'img': img, 'index': index, 'labels': torch.tensor(labels)}
class CelebD2CAttrDataset(CelebAttrDataset):
"""
the dataset is used in the D2C paper.
it has a specific crop from the original CelebA.
"""
def __init__(self,
folder,
image_size=64,
attr_path=os.path.expanduser(
'datasets/celeba_anno/list_attr_celeba.txt'),
ext='jpg',
only_cls_name: str = None,
only_cls_value: int = None,
do_augment: bool = False,
do_transform: bool = True,
do_normalize: bool = True,
d2c: bool = True):
super().__init__(folder,
image_size,
attr_path,
ext=ext,
only_cls_name=only_cls_name,
only_cls_value=only_cls_value,
do_augment=do_augment,
do_transform=do_transform,
do_normalize=do_normalize,
d2c=d2c)
class CelebAttrFewshotDataset(Dataset):
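    """
    K-shot CelebA attribute dataset: reads a precomputed few-shot split from
    a CSV under `data/celeba_fewshots/`.
    """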
def __init__(
self,
cls_name,
K,
img_folder,
img_size=64,
ext='png',
seed=0,
only_cls_name: str = None,
only_cls_value: int = None,
all_neg: bool = False,
do_augment: bool = False,
do_transform: bool = True,
do_normalize: bool = True,
d2c: bool = False,
) -> None:
self.cls_name = cls_name
self.K = K
self.img_folder = img_folder
self.ext = ext
if all_neg:
path = f'data/celeba_fewshots/K{K}_allneg_{cls_name}_{seed}.csv'
else:
path = f'data/celeba_fewshots/K{K}_{cls_name}_{seed}.csv'
self.df = pd.read_csv(path, index_col=0)
if only_cls_name is not None:
self.df = self.df[self.df[only_cls_name] == only_cls_value]
if d2c:
transform = [
d2c_crop(),
transforms.Resize(img_size),
]
else:
transform = [
transforms.Resize(img_size),
transforms.CenterCrop(img_size),
]
if do_augment:
transform.append(transforms.RandomHorizontalFlip())
if do_transform:
transform.append(transforms.ToTensor())
if do_normalize:
transform.append(
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
self.transform = transforms.Compose(transform)
def pos_count(self, cls_name):
return (self.df[cls_name] == 1).sum()
def neg_count(self, cls_name):
return (self.df[cls_name] == -1).sum()
def __len__(self):
return len(self.df)
def __getitem__(self, index):
row = self.df.iloc[index]
name = row.name.split('.')[0]
name = f'{name}.{self.ext}'
path = os.path.join(self.img_folder, name)
img = Image.open(path)
        # label has shape (1,)
label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1)
if self.transform is not None:
img = self.transform(img)
return {'img': img, 'index': index, 'labels': label}
class CelebD2CAttrFewshotDataset(CelebAttrFewshotDataset):
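    """
    D2C-cropped variant of `CelebAttrFewshotDataset`.
    """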
def __init__(self,
cls_name,
K,
img_folder,
img_size=64,
ext='jpg',
seed=0,
only_cls_name: str = None,
only_cls_value: int = None,
all_neg: bool = False,
do_augment: bool = False,
do_transform: bool = True,
do_normalize: bool = True,
is_negative=False,
d2c: bool = True) -> None:
super().__init__(cls_name,
K,
img_folder,
img_size,
ext=ext,
seed=seed,
only_cls_name=only_cls_name,
only_cls_value=only_cls_value,
all_neg=all_neg,
do_augment=do_augment,
do_transform=do_transform,
do_normalize=do_normalize,
d2c=d2c)
self.is_negative = is_negative
class CelebHQAttrDataset(Dataset):
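    """
    CelebA-HQ images in LMDB paired with the 40 binary attribute annotations.
    """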
id_to_cls = [
'5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
'Wearing_Necklace', 'Wearing_Necktie', 'Young'
]
cls_to_id = {v: k for k, v in enumerate(id_to_cls)}
def __init__(self,
path=os.path.expanduser('datasets/celebahq256.lmdb'),
image_size=None,
attr_path=os.path.expanduser(
'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
original_resolution=256,
do_augment: bool = False,
do_transform: bool = True,
do_normalize: bool = True):
super().__init__()
self.image_size = image_size
self.data = BaseLMDB(path, original_resolution, zfill=5)
transform = [
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
]
if do_augment:
transform.append(transforms.RandomHorizontalFlip())
if do_transform:
transform.append(transforms.ToTensor())
if do_normalize:
transform.append(
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
self.transform = transforms.Compose(transform)
with open(attr_path) as f:
            # skip the first line (the image count)
            f.readline()
            # sep=r'\s+' replaces the deprecated delim_whitespace=True
            self.df = pd.read_csv(f, sep=r'\s+')
def pos_count(self, cls_name):
return (self.df[cls_name] == 1).sum()
def neg_count(self, cls_name):
return (self.df[cls_name] == -1).sum()
def __len__(self):
return len(self.df)
def __getitem__(self, index):
row = self.df.iloc[index]
img_name = row.name
img_idx, ext = img_name.split('.')
img = self.data[img_idx]
labels = [0] * len(self.id_to_cls)
for k, v in row.items():
labels[self.cls_to_id[k]] = int(v)
if self.transform is not None:
img = self.transform(img)
return {'img': img, 'index': index, 'labels': torch.tensor(labels)}
class CelebHQAttrFewshotDataset(Dataset):
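    """
    K-shot CelebA-HQ attribute dataset: reads a precomputed few-shot split
    from a CSV under `data/celebahq_fewshots/`.
    """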
def __init__(self,
cls_name,
K,
path,
image_size,
original_resolution=256,
do_augment: bool = False,
do_transform: bool = True,
do_normalize: bool = True):
super().__init__()
self.image_size = image_size
self.cls_name = cls_name
self.K = K
self.data = BaseLMDB(path, original_resolution, zfill=5)
transform = [
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
]
if do_augment:
transform.append(transforms.RandomHorizontalFlip())
if do_transform:
transform.append(transforms.ToTensor())
if do_normalize:
transform.append(
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
self.transform = transforms.Compose(transform)
self.df = pd.read_csv(f'data/celebahq_fewshots/K{K}_{cls_name}.csv',
index_col=0)
def pos_count(self, cls_name):
return (self.df[cls_name] == 1).sum()
def neg_count(self, cls_name):
return (self.df[cls_name] == -1).sum()
def __len__(self):
return len(self.df)
def __getitem__(self, index):
row = self.df.iloc[index]
img_name = row.name
img_idx, ext = img_name.split('.')
img = self.data[img_idx]
        # label has shape (1,)
label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1)
if self.transform is not None:
img = self.transform(img)
return {'img': img, 'index': index, 'labels': label}
class Repeat(Dataset):
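    """
    Present a dataset as if it had length `new_len` by wrapping indices
    around the original length.
    """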
def __init__(self, dataset, new_len) -> None:
super().__init__()
self.dataset = dataset
self.original_len = len(dataset)
self.new_len = new_len
def __len__(self):
return self.new_len
def __getitem__(self, index):
index = index % self.original_len
return self.dataset[index]
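

# A minimal smoke-test sketch, assuming images exist under the hypothetical
# path 'datasets/ffhq' (the path and batch size are placeholders, not part
# of this module):
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    _ds = ImageDataset('datasets/ffhq', image_size=128, exts=['jpg', 'png'])
    _loader = DataLoader(_ds, batch_size=16, shuffle=True)
    _batch = next(iter(_loader))
    # images are normalized to [-1, 1]; shapes are (16, 3, 128, 128) and (16,)
    print(_batch['img'].shape, _batch['index'].shape)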