|
|
|
|
|
import os |
|
import torchvision.transforms as transforms |
|
from datasets import load_dataset |
|
from PIL import Image |
|
|
|
|
|
def preprocess_image(example, image_size=None):
    """Convert one dataset example's image to an RGB tensor, optionally resized.

    Args:
        example: A dataset row with an ``'image'`` entry (an object exposing
            ``.convert('RGB')``, e.g. a PIL image) and a ``'label'`` entry.
        image_size: Side length in pixels. When given, the image is resized to
            ``(image_size, image_size)`` so every output shares one spatial
            shape. ``None`` skips resizing.

    Returns:
        dict with ``'image'`` (the tensor produced by
        ``transforms.ToTensor()``) and the unchanged ``'label'``.
    """
    # Force 3 channels so every example yields the same channel count.
    image = example['image'].convert('RGB')

    steps = []
    if image_size is not None:
        # Fixed square target so all examples end up the same height/width.
        steps.append(transforms.Resize((image_size, image_size)))
    steps.append(transforms.ToTensor())
    transform = transforms.Compose(steps)

    # NOTE(review): when used through Dataset.map the tensor is presumably
    # serialized into Arrow (nested lists); set the dataset format to torch
    # when loading it back — confirm against the consumer.
    return {'image': transform(image), 'label': example['label']}
|
|
|
|
|
def main(dataset_name='zh-plus/tiny-imagenet',
         train_output_dir='preprocessed_train_dataset',
         val_output_dir='preprocessed_val_dataset'):
    """Load a dataset, preprocess its train/valid splits, and save them to disk.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        train_output_dir: Directory the preprocessed ``'train'`` split is
            written to via ``save_to_disk``.
        val_output_dir: Directory the preprocessed ``'valid'`` split is
            written to via ``save_to_disk``.
    """
    dataset = load_dataset(dataset_name)
    train_dataset = dataset['train']
    val_dataset = dataset['valid']

    # Target size comes from the first training image's ``size[0]``
    # (the width, for PIL images) and is applied to both splits.
    image_size = train_dataset[0]['image'].size[0]

    # Bind image_size through fn_kwargs (the datasets mechanism for extra
    # map arguments) instead of wrapping preprocess_image in lambdas.
    map_kwargs = {'fn_kwargs': {'image_size': image_size}}
    train_dataset = train_dataset.map(preprocess_image, **map_kwargs)
    val_dataset = val_dataset.map(preprocess_image, **map_kwargs)

    train_dataset.save_to_disk(train_output_dir)
    val_dataset.save_to_disk(val_output_dir)
|
|
|
if __name__ == '__main__':
    # Run the preprocessing pipeline only when executed as a script,
    # never as a side effect of importing this module.
    main()