"""
convert load-all-images-into-memory-before-training dataset
to load-when-training-dataset


"""


import os

import tqdm
from torchvision.datasets import CIFAR10, MNIST, STL10, SVHN, USPS


def convert(datasets_of_split, new_dir):
    """Dump every (image, label) pair of the given splits to
    new_dir/<label>/<per-class index>.png, e.g. new_dir/000003/000042.png,
    so images can be read from disk on demand during training."""
    # Running per-class counter used to build unique, zero-padded file names.
    img_idx = {}

    for d in datasets_of_split:
        for x, y in tqdm.tqdm(d, total=len(d), dynamic_ncols=True):
            if y not in img_idx:
                img_idx[y] = -1
            img_idx[y] += 1

            # x is a PIL image, y an integer class label.
            p = os.path.join(new_dir, f'{y:06d}', f'{img_idx[y]:06d}.png')
            os.makedirs(os.path.dirname(p), exist_ok=True)

            x.save(p)


if __name__ == '__main__':
    # convert(
    #     [CIFAR10('/data/zql/datasets/CIFAR10', True, download=True), CIFAR10('/data/zql/datasets/CIFAR10', False, download=True)],
    #     '/data/zql/datasets/CIFAR10-single'
    # )

    # convert(
    #     [STL10('/data/zql/datasets/STL10', 'train', download=False), STL10('/data/zql/datasets/STL10', 'test', download=False)],
    #     '/data/zql/datasets/STL10-single'
    # )

    # convert(
    #     [MNIST('/data/zql/datasets/MNIST', True, download=True), MNIST('/data/zql/datasets/MNIST', False, download=True)],
    #     '/data/zql/datasets/MNIST-single'
    # )

    convert(
        [SVHN('/data/zql/datasets/SVHN', 'train', download=True), SVHN('/data/zql/datasets/SVHN', 'test', download=True)],
        '/data/zql/datasets/SVHN-single'
    )

    # convert(
    #     [USPS('/data/zql/datasets/USPS', True, download=False), USPS('/data/zql/datasets/USPS', False, download=False)],
    #     '/data/zql/datasets/USPS-single'
    # )
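
# -----------------------------------------------------------------------------
# Usage sketch (not part of the conversion itself): the folder tree produced by
# convert() -- one zero-padded sub-directory per class, PNG files inside --
# matches the layout torchvision.datasets.ImageFolder expects, so the converted
# data can be loaded lazily at training time. The path and transform below are
# assumptions for illustration only.
#
#     from torch.utils.data import DataLoader
#     from torchvision import transforms
#     from torchvision.datasets import ImageFolder
#
#     dataset = ImageFolder('/data/zql/datasets/SVHN-single',
#                           transform=transforms.ToTensor())
#     loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
#
# Note: ImageFolder assigns class indices by sorting folder names; with the
# zero-padded names written by convert(), that sorted order matches the
# original 0-9 labels of these datasets.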