File size: 1,336 Bytes
e7059c0
 
f7421f7
 
 
 
 
 
 
 
 
48f2751
f7421f7
48f2751
f7421f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
## For debugging only; do not use this script for real preprocessing runs.

import os
import torchvision.transforms as transforms
from datasets import load_dataset
from PIL import Image

# Stateless conversion step, built once at import time instead of
# re-creating a Compose pipeline for every example passed through ``map``.
_TO_TENSOR = transforms.ToTensor()


def preprocess_image(example, image_size):
    """Convert one dataset example's image to an RGB tensor.

    Args:
        example: A dataset row with an ``'image'`` entry (a PIL image, as
            the ``.convert('RGB')`` call below requires) and a ``'label'``
            entry.
        image_size: Intended side length of the output image. Currently
            unused: resizing and normalization are disabled in this
            debugging script, so images keep their original resolution.

    Returns:
        A dict with ``'image'`` set to the ``ToTensor`` output of the RGB
        image and ``'label'`` passed through unchanged.
    """
    # convert('RGB') forces three channels regardless of the source mode
    # (e.g. grayscale), so every example ends up with the same channel count.
    rgb_image = example['image'].convert('RGB')
    return {'image': _TO_TENSOR(rgb_image), 'label': example['label']}

def main(dataset_name='zh-plus/tiny-imagenet',
         train_output_dir='preprocessed_train_dataset',
         val_output_dir='preprocessed_val_dataset'):
    """Load an image dataset, convert its images to tensors, and save it.

    Args:
        dataset_name: Identifier passed to ``load_dataset``. The loaded
            dataset must provide ``'train'`` and ``'valid'`` splits.
        train_output_dir: Directory the processed ``'train'`` split is
            saved to with ``save_to_disk``.
        val_output_dir: Directory the processed ``'valid'`` split is
            saved to with ``save_to_disk``.
    """
    dataset = load_dataset(dataset_name)

    # Take the first training image's width as the reference resolution.
    # NOTE(review): this assumes every image is square and the same size as
    # the first one; nothing here verifies that.
    image_size = dataset['train'][0]['image'].size[0]

    # fn_kwargs hands image_size to preprocess_image through map's own
    # argument mechanism instead of wrapping it in a closure-capturing lambda.
    # Both splits are mapped before either is saved, as before.
    processed = {
        split: dataset[split].map(
            preprocess_image, fn_kwargs={'image_size': image_size}
        )
        for split in ('train', 'valid')
    }

    processed['train'].save_to_disk(train_output_dir)
    processed['valid'].save_to_disk(val_output_dir)

if __name__ == "__main__":
    # Run the preprocessing only when executed as a script, not on import.
    main()