|
from torch.utils.data import Dataset |
|
import os |
|
import pathlib |
|
import torch |
|
|
|
from PIL import Image |
|
from torch.utils.data import Dataset |
|
from torchvision import transforms |
|
from typing import Tuple, Dict, List |
|
|
|
import torch.utils.data as data |
|
import numpy as np |
|
|
|
import random |
|
|
|
|
|
|
|
|
|
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Find the class folder names in a target directory.

    Every immediate sub-directory of ``directory`` is treated as one class.
    The folder name itself is the label: it is converted with ``int`` rather
    than numbered by position, so every class folder must have an integer
    name.

    Args:
        directory (str): target directory to load class names from.

    Returns:
        Tuple[List[str], Dict[str, int]]:
            (sorted_class_names, {class_name: int(class_name), ...}).
            Names are sorted as strings, so "10" comes before "2".

    Raises:
        FileNotFoundError: if ``directory`` contains no sub-directories.
        ValueError: if a sub-directory name is not an integer literal.

    Example:
        find_classes("food_images/train")  # holds folders "0", "1", "10"
        >>> (["0", "1", "10"], {"0": 0, "1": 1, "10": 10})
    """
    # Use scandir as a context manager so the directory handle is released
    # deterministically instead of whenever the iterator is collected.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())

    if not classes:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")

    class_to_idx: Dict[str, int] = {}
    for cls_name in classes:
        try:
            class_to_idx[cls_name] = int(cls_name)
        except ValueError as err:
            # Same exception type as a bare int() failure, but it names the
            # offending folder so the data layout problem is easy to locate.
            raise ValueError(
                f"Class folder {cls_name!r} in {directory} is not an integer label."
            ) from err

    return classes, class_to_idx
|
|
|
|
|
class SamData(Dataset):
    """Dataset over ``<targ_dir>/<folder>/<file>.jpg`` images tagged with ``sa_`` ids.

    Each sample is an image plus the integer id parsed from its path. The
    ``<folder>/<file>`` part of every path must contain ``sa_`` followed by a
    6-character fold id, and the integer sample id must start 13 characters
    after that ``sa_`` and run up to ``.jpg``. For example
    ``sa_000000/sa_1234.jpg`` yields fold ``"000000"`` and id ``1234``.

    Attributes are filled in ``__init__`` and are index-aligned:
        paths: sorted ``pathlib.Path`` objects matching ``*/*.jpg``.
        indexes: integer sample id parsed from each path.
        folds: 6-character fold string parsed from each path.
        transform: optional callable applied to the image in ``__getitem__``.
    """

    # Naming-scheme layout, as offsets from the first "sa_" in "<folder>/<file>".
    _TAG = "sa_"
    _EXTENSION = ".jpg"
    _FOLD_START = 3  # fold id begins right after "sa_"
    _FOLD_END = 9  # fold id is 6 characters wide
    _ID_START = 13  # "sa_" + 6-char fold + 1 separator character + "sa_"

    def __init__(self, targ_dir: str, transform=None) -> None:
        """Collect the image paths under ``targ_dir`` and parse their ids.

        Args:
            targ_dir (str): directory whose sub-folders hold the ``.jpg`` files.
            transform: optional callable applied to each loaded image.

        Raises:
            ValueError: if a matched path does not follow the ``sa_`` naming
                scheme described on the class.
        """
        super().__init__()

        self.paths = sorted(pathlib.Path(targ_dir).glob("*/*.jpg"))

        self.indexes = []
        self.folds = []
        for path in self.paths:
            sample_id, fold = self._parse_name(path)
            self.indexes.append(sample_id)
            self.folds.append(fold)

        self.transform = transform

    @classmethod
    def _parse_name(cls, path: pathlib.Path) -> Tuple[int, str]:
        """Return ``(sample_id, fold)`` parsed from one globbed image path."""
        # Only the "<folder>/<file>" tail is inspected: an "sa_" or ".jpg"
        # occurring inside targ_dir itself must not shift the offsets.
        tail = f"{path.parent.name}/{path.name}"
        tag_at = tail.find(cls._TAG)
        ext_at = tail.find(cls._EXTENSION)
        if tag_at < 0 or ext_at < 0:
            raise ValueError(
                f"Cannot parse a sample id from {str(path)!r}: "
                f"expected {cls._TAG!r} and {cls._EXTENSION!r} in {tail!r}."
            )

        digits = tail[tag_at + cls._ID_START:ext_at]
        try:
            sample_id = int(digits)
        except ValueError as err:
            raise ValueError(
                f"Cannot parse a sample id from {str(path)!r}: "
                f"{digits!r} is not an integer."
            ) from err

        fold = tail[tag_at + cls._FOLD_START:tag_at + cls._FOLD_END]
        return sample_id, fold

    def load_image(self, index: int) -> Image.Image:
        "Opens an image via a path and returns it."
        # NOTE: per the Pillow docs Image.open is lazy; pixel data is decoded
        # (and the file released) only once the image is actually loaded.
        return Image.open(self.paths[index])

    def __len__(self) -> int:
        "Returns the total number of samples."
        return len(self.paths)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """Return one sample as ``(image, sample_id)``.

        ``sample_id`` is the integer parsed from the file path, not a class
        label. The image is passed through ``self.transform`` when one is set;
        otherwise the object returned by ``load_image`` is returned unchanged.
        """
        img = self.load_image(index)
        sample_id = self.indexes[index]

        if self.transform:
            img = self.transform(img)
        return img, sample_id