Spaces:
Running
Running
import os | |
from torchvision.datasets import VisionDataset | |
from PIL import Image | |
from sklearn.model_selection import train_test_split | |
class CustomDataset(VisionDataset):
    """Image-folder dataset with a reproducible, stratified train/val/test split.

    The root directory is expected to contain one sub-directory per class.
    Every file found (recursively) under a class directory becomes a sample
    labelled with that class's index. The complete sample list is then split
    with scikit-learn's ``train_test_split`` (stratified by label, seeded by
    ``seed``) and only the part selected by ``subset`` is kept.

    Args:
        root_path: Directory whose immediate sub-directories are the classes.
        subset: Which part of the split to expose: "train", "val" or "test".
        transform: Optional callable applied to the loaded RGB image.
        target_transform: Optional callable applied to the integer class index.
        split_ratios: ``(train, val, test)`` fractions. Element 0 sizes the
            training split; elements 1 and 2 divide the remainder between
            validation and test.
        seed: Random state passed to both ``train_test_split`` calls.

    Raises:
        ValueError: If ``subset`` is not one of the supported names, or if no
            files are found under ``root_path``.
    """

    _SUBSETS = ("train", "val", "test")

    def __init__(self, root_path, subset="train", transform=None, target_transform=None,
                 split_ratios=(0.7, 0.15, 0.15), seed=42):
        # Reject a bad subset name up front, before paying for a directory walk.
        if subset not in self._SUBSETS:
            raise ValueError(f"Unknown subset: {subset}")
        # VisionDataset.__init__ stores ``self.root`` (expanding "~" for string
        # paths), so it is deliberately not overwritten with the raw argument.
        super().__init__(root_path, transform=transform, target_transform=target_transform)
        self.subset = subset  # Can be "train", "val", or "test"
        self.split_ratios = split_ratios
        self.seed = seed
        self.classes, self.class_idx = self._find_classes()
        self.samples = self._make_dataset()

    def _find_classes(self):
        """Return the sorted class names and a name -> index mapping.

        Classes are the names of the directories directly under ``self.root``.
        """
        classes = sorted(entry.name for entry in os.scandir(self.root) if entry.is_dir())
        class_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_idx

    def _collect_samples(self):
        """Return ``(path, class_index)`` for every file under each class directory."""
        samples = []
        # ``self.classes`` is already sorted and ``class_idx`` was built by
        # enumerating it, so enumerate() yields the same (index, name) pairs.
        for class_index, target_class in enumerate(self.classes):
            target_dir = os.path.join(self.root, target_class)
            # Sorting directories and file names keeps the sample order, and
            # therefore the seeded split, identical across runs and platforms.
            for dirpath, _, fnames in sorted(os.walk(target_dir)):
                samples.extend(
                    (os.path.join(dirpath, fname), class_index) for fname in sorted(fnames)
                )
        return samples

    def _make_dataset(self):
        """Build the sample list for ``self.subset``.

        Returns:
            A list of ``(path, class_index)`` tuples belonging to the subset.

        Raises:
            ValueError: If no files are found, or the subset name is unknown.
        """
        if self.subset not in self._SUBSETS:
            raise ValueError(f"Unknown subset: {self.subset}")

        samples = self._collect_samples()
        if not samples:
            raise ValueError(f"No files found under {self.root}")

        # First split: training set versus everything else.
        train_samples, holdout_samples = train_test_split(
            samples,
            test_size=1 - self.split_ratios[0],
            random_state=self.seed,
            stratify=[label for _, label in samples],
        )
        if self.subset == "train":
            # The second split is not needed for the training subset.
            return train_samples

        # Second split: divide the held-out part between val and test,
        # using the test ratio relative to the val + test total.
        val_ratio, test_ratio = self.split_ratios[1], self.split_ratios[2]
        val_samples, test_samples = train_test_split(
            holdout_samples,
            test_size=test_ratio / (val_ratio + test_ratio),
            random_state=self.seed,
            stratify=[label for _, label in holdout_samples],
        )
        return val_samples if self.subset == "val" else test_samples

    def __len__(self):
        """Return the number of samples in the selected subset."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            ``(image, target)``: the image converted to RGB with ``transform``
            applied if set, and the class index with ``target_transform``
            applied if set.
        """
        path, target = self.samples[index]
        # Open the file ourselves so the handle is closed deterministically;
        # convert() produces the RGB image before the ``with`` block exits.
        with open(path, "rb") as fh:
            img = Image.open(fh).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target