TeacherPuffy commited on
Commit
f7421f7
·
verified ·
1 Parent(s): 91c2486

Create preprocess.py

Browse files
Files changed (1) hide show
  1. preprocess.py +37 -0
preprocess.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torchvision.transforms as transforms
3
+ from datasets import load_dataset
4
+ from PIL import Image
5
+
6
+ # Preprocess the images
7
+ def preprocess_image(example, image_size):
8
+ image = example['image'].convert('RGB') # Directly use the PIL image
9
+ transform = transforms.Compose([
10
+ transforms.Resize((image_size, image_size)),
11
+ transforms.ToTensor(),
12
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
13
+ ])
14
+ image = transform(image)
15
+ return {'image': image, 'label': example['label']}
16
+
17
+ # Main function
18
+ def main():
19
+ # Load the dataset
20
+ dataset = load_dataset('zh-plus/tiny-imagenet')
21
+ train_dataset = dataset['train']
22
+ val_dataset = dataset['valid']
23
+
24
+ # Determine the fixed resolution of the images
25
+ example_image = train_dataset[0]['image'] # Directly use the PIL image
26
+ image_size = example_image.size[0] # Assuming the images are square
27
+
28
+ # Preprocess the dataset
29
+ train_dataset = train_dataset.map(lambda x: preprocess_image(x, image_size))
30
+ val_dataset = val_dataset.map(lambda x: preprocess_image(x, image_size))
31
+
32
+ # Save the preprocessed datasets
33
+ train_dataset.save_to_disk('preprocessed_train_dataset')
34
+ val_dataset.save_to_disk('preprocessed_val_dataset')
35
+
36
+ if __name__ == '__main__':
37
+ main()