File size: 2,335 Bytes
803ef9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import os
import copy
import json
import operator
import numpy as np
from PIL import Image
from os.path import join
from itertools import chain
from scipy.io import loadmat
from collections import defaultdict
import torch
import torch.utils.data as data
from torchvision import transforms
# Default dataset directory used by VGGFlower: it must contain
# imagelabels.mat and the image_XXXXX.jpg files directly at its top level.
DATA_ROOTS = 'data/VGGFlower'
# Setup:
#   wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
#   tar -xvzf 102flowers.tgz
#   rename the extracted image directory (presumably `jpg/`) to VGGFlower
#   cd VGGFlower
#   wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat
class VGGFlower(data.Dataset):
    """Oxford 102 Category Flower dataset with a fixed per-class 80/20 split.

    Expects ``imagelabels.mat`` and the ``image_XXXXX.jpg`` files directly
    under ``root``. Within each class the images are shuffled with a fixed
    seed (42); the first 80% form the training split and the remainder the
    test split, so ``train=True`` and ``train=False`` instances built from
    the same ``root`` are disjoint and together cover every image.

    Args:
        root: directory containing ``imagelabels.mat`` and the images.
        train: if True select the 80% training portion, else the rest.
        image_transforms: optional callable applied to each PIL image.
    """

    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
        super().__init__()
        self.root = root
        self.train = train
        self.image_transforms = image_transforms
        self.paths, self.labels = self.load_images()

    def load_images(self):
        """Build the image paths and labels belonging to this split.

        Returns:
            ``(paths, labels)``: parallel lists. Labels are the class ids
            exactly as stored in ``imagelabels.mat`` (1-based; shifted to
            0-based in ``__getitem__``).
        """
        # A fresh RNG with a fixed seed, consumed in the same class order for
        # both splits, makes the train and test partitions complementary.
        rs = np.random.RandomState(42)
        imagelabels_path = os.path.join(self.root, 'imagelabels.mat')
        with open(imagelabels_path, 'rb') as f:
            all_labels = loadmat(f)['labels'][0]

        # Group paths by class. Entry i (0-based) of the label array belongs
        # to image_{i+1:05d}.jpg, expected directly under root (not in the
        # archive's 'jpg/' subdirectory).
        paths_by_label = defaultdict(list)
        for i, label in enumerate(all_labels):
            paths_by_label[label].append(
                os.path.join(self.root, 'image_{:05d}.jpg'.format(i + 1)))

        split_paths, split_labels = [], []
        for label, class_paths in paths_by_label.items():
            num = len(class_paths)
            indexer = np.arange(num)
            rs.shuffle(indexer)
            shuffled = [class_paths[j] for j in indexer]
            cut = int(0.8 * num)
            selected = shuffled[:cut] if self.train else shuffled[cut:]
            split_paths.extend(selected)
            split_labels.extend([label] * len(selected))
        return split_paths, split_labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(image, label)`` with the label shifted to 0-based."""
        path = self.paths[index]
        # Stored class ids are 1-based; models expect 0-based targets.
        label = int(self.labels[index]) - 1
        # The context manager closes the file handle even if decoding fails;
        # convert() returns a new, fully loaded image independent of the file.
        with Image.open(path) as img:
            image = img.convert(mode='RGB')
        if self.image_transforms:
            image = self.image_transforms(image)
        return image, label