# MLPScaling / preprocess.py
# TeacherPuffy's picture
# Update preprocess.py
# e7059c0 verified
## Debugging, don't use.
import os
import torchvision.transforms as transforms
from datasets import load_dataset
from PIL import Image
# Preprocess the images
def preprocess_image(example, image_size):
    """Convert the image of a single dataset example into a tensor.

    Args:
        example: Mapping with an ``'image'`` entry (an object providing
            ``convert``, used here as a PIL image) and a ``'label'`` entry.
        image_size: Currently unused; the resize step that would consume
            it is disabled.

    Returns:
        A dict whose ``'image'`` is the result of ``transforms.ToTensor``
        applied to the RGB-converted image and whose ``'label'`` is the
        example's label, passed through unchanged.
    """
    rgb_image = example['image'].convert('RGB')
    # Resizing to (image_size, image_size) and mean/std normalization are
    # disabled, so tensor conversion is the only step in the pipeline.
    pipeline = transforms.Compose([transforms.ToTensor()])
    return {'image': pipeline(rgb_image), 'label': example['label']}
# Main function
def main():
    """Preprocess the ``zh-plus/tiny-imagenet`` dataset and save it to disk.

    Loads the dataset, takes its ``'train'`` and ``'valid'`` splits, runs
    ``preprocess_image`` over every example of both splits, and writes the
    results to ``'preprocessed_train_dataset'`` and
    ``'preprocessed_val_dataset'``.
    """
    splits = load_dataset('zh-plus/tiny-imagenet')
    train_split = splits['train']
    val_split = splits['valid']

    # Take the resolution from the first training image; only the first
    # entry of ``size`` is read, so the images are assumed to be square.
    first_image = train_split[0]['image']
    image_size = first_image.size[0]

    def convert_example(example):
        # Bind the probed resolution so ``map`` can call this with one arg.
        return preprocess_image(example, image_size)

    # Map both splits first, then persist them, in that order.
    processed_train = train_split.map(convert_example)
    processed_val = val_split.map(convert_example)

    processed_train.save_to_disk('preprocessed_train_dataset')
    processed_val.save_to_disk('preprocessed_val_dataset')
# Run the preprocessing only when executed as a script, not on import.
if __name__ == '__main__':
    main()