""" | |
convert load-all-images-into-memory-before-training dataset | |
to load-when-training-dataset | |
""" | |
from torchvision.datasets import CIFAR10, STL10, MNIST, USPS, SVHN | |
import os | |
import tqdm | |
def convert(datasets_of_split, new_dir): | |
img_idx = {} | |
for d in datasets_of_split: | |
for x, y in tqdm.tqdm(d, total=len(d), dynamic_ncols=True): | |
# print(type(x), type(y)) | |
# break | |
# y = str(y) | |
if y not in img_idx: | |
img_idx[y] = -1 | |
img_idx[y] += 1 | |
p = os.path.join(new_dir, f'{y:06d}', f'{img_idx[y]:06d}' + '.png') | |
os.makedirs(os.path.dirname(p), exist_ok=True) | |
x.save(p) | |
if __name__ == '__main__': | |
# convert( | |
# [CIFAR10('/data/zql/datasets/CIFAR10', True, download=True), CIFAR10('/data/zql/datasets/CIFAR10', False, download=True)], | |
# '/data/zql/datasets/CIFAR10-single' | |
# ) | |
# convert( | |
# [STL10('/data/zql/datasets/STL10', 'train', download=False), STL10('/data/zql/datasets/STL10', 'test', download=False)], | |
# '/data/zql/datasets/STL10-single' | |
# ) | |
# convert( | |
# [MNIST('/data/zql/datasets/MNIST', True, download=True), MNIST('/data/zql/datasets/MNIST', False, download=True)], | |
# '/data/zql/datasets/MNIST-single' | |
# ) | |
convert( | |
[SVHN('/data/zql/datasets/SVHN', 'train', download=True), SVHN('/data/zql/datasets/SVHN', 'test', download=True)], | |
'/data/zql/datasets/SVHN-single' | |
) | |
# convert( | |
# [USPS('/data/zql/datasets/USPS', True, download=False), USPS('/data/zql/datasets/USPS', False, download=False)], | |
# '/data/zql/datasets/USPS-single' | |
# ) |