import os
from torchvision.datasets import VisionDataset
from PIL import Image
from sklearn.model_selection import train_test_split


class CustomDataset(VisionDataset):
    """Image-folder dataset with a seeded, stratified train/val/test split.

    Expects one sub-directory per class under ``root_path``.
    """

    def __init__(self, root_path, subset="train", transform=None, target_transform=None, split_ratios=(0.7, 0.15, 0.15), seed=42):
        super().__init__(root_path, transform=transform, target_transform=target_transform)
        # VisionDataset.__init__ already stores the root directory as self.root.
        self.subset = subset  # One of "train", "val", or "test"
        self.split_ratios = split_ratios  # (train, val, test) fractions
        self.seed = seed

        self.classes, self.class_idx = self._find_classes()
        self.samples = self._make_dataset()

    def _find_classes(self):
        classes = [d.name for d in os.scandir(self.root) if d.is_dir()]
        classes.sort()
        class_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_idx

    def _make_dataset(self):
        samples = []
        for target_class in sorted(self.class_idx.keys()):
            class_index = self.class_idx[target_class]
            target_dir = os.path.join(self.root, target_class)
            for root, _, fnames in sorted(os.walk(target_dir)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    samples.append((path, class_index))

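        # Two-stage stratified split: first carve off the training set, then
        # divide the held-out remainder between validation and test.
        # (The first call's "test_samples" is that held-out remainder.)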
        train_samples, test_samples = train_test_split(
            samples, test_size=1 - self.split_ratios[0], random_state=self.seed, stratify=[s[1] for s in samples]
        )
        val_samples, test_samples = train_test_split(
            test_samples, test_size=self.split_ratios[2] / (self.split_ratios[1] + self.split_ratios[2]),
            random_state=self.seed, stratify=[s[1] for s in test_samples]
        )

        if self.subset == "train":
            return train_samples
        elif self.subset == "val":
            return val_samples
        elif self.subset == "test":
            return test_samples
        else:
            raise ValueError(f"Unknown subset: {self.subset}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, target = self.samples[index]
        img = Image.open(path).convert("RGB")

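        if self.transform is not None:
            img = self.transform(img)
        # Apply the label transform as well; it is accepted by __init__.
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target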
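

# Illustrative sketch, not part of the original class: the arithmetic behind
# the second train_test_split call in _make_dataset. After the training set is
# removed, the remaining pool holds the val and test shares together, so the
# test fraction has to be rescaled relative to that pool. The helper name
# `second_stage_test_fraction` is hypothetical and exists only to show the
# calculation.
def second_stage_test_fraction(split_ratios=(0.7, 0.15, 0.15)):
    """Return the share of the held-out pool that goes to the test set."""
    _, val_ratio, test_ratio = split_ratios
    return test_ratio / (val_ratio + test_ratio)


# With the default ratios the held-out pool is divided evenly between
# validation and test, which is why the second split uses test_size=0.5.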