|
""" A dataset reader that extracts images from folders |
|
|
|
Folders are scanned recursively to find image files. Labels are based |
|
on the folder hierarchy, just leaf folders by default. |
|
|
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
import os |
|
from typing import Dict, List, Optional, Set, Tuple, Union |
|
|
|
from timm.utils.misc import natural_key |
|
|
|
from .class_map import load_class_map |
|
from .img_extensions import get_img_extensions |
|
from .reader import Reader |
|
|
|
|
|
def find_images_and_targets(
        folder: str,
        types: Optional[Union[List, Tuple, Set]] = None,
        class_to_idx: Optional[Dict] = None,
        leaf_name_only: bool = True,
        sort: bool = True
):
    """ Walk folder recursively to discover images and map them to classes by folder names.

    Args:
        folder: root of folder to recursively search
        types: types (file extensions) to search for in path. Compared case-insensitively
            against the extension produced by os.path.splitext (i.e. including the leading dot).
            Falls back to get_img_extensions(as_set=True) when empty or None.
        class_to_idx: specify mapping for class (folder name) to class index if set.
            Files whose label is not a key of this mapping are dropped from the result.
        leaf_name_only: use only leaf-name of folder walk for class names, otherwise the
            path relative to `folder` with path separators replaced by '_'
        sort: re-sort found images by name (for consistent ordering)

    Returns:
        A list of image and target tuples, class_to_idx mapping
    """
    # File extensions are lower-cased before the membership test below, so caller supplied
    # types must be lower-cased too; otherwise an entry such as '.JPG' could never match.
    types = get_img_extensions(as_set=True) if not types else {t.lower() for t in types}
    labels = []
    filenames = []
    for root, _, files in os.walk(folder, topdown=False, followlinks=True):
        # Files located directly in `folder` get the empty string as their label.
        rel_path = os.path.relpath(root, folder) if (root != folder) else ''
        label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
        for f in files:
            ext = os.path.splitext(f)[1]
            if ext.lower() in types:
                filenames.append(os.path.join(root, f))
                labels.append(label)
    if class_to_idx is None:
        # No mapping given: build the class index from the naturally sorted unique labels.
        sorted_labels = sorted(set(labels), key=natural_key)
        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
    if sort:
        images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
    return images_and_targets, class_to_idx
|
|
|
|
|
class ReaderImageFolder(Reader):
    """ Reader exposing the images found under a root folder as (file object, target) samples.

    The sample list and the class mapping are produced once, at construction time,
    by find_images_and_targets().
    """

    def __init__(
            self,
            root,
            class_map='',
            input_key=None,
    ):
        """ Scan `root` for images and build the sample list.

        Args:
            root: folder that is searched for images
            class_map: when non-empty, passed together with `root` to load_class_map()
                to obtain the class -> index mapping; otherwise the mapping is derived
                from the discovered labels
            input_key: optional ';' separated string of file extensions to search for;
                the default image extensions are used when it is empty or None

        Raises:
            RuntimeError: if no matching image was found.
        """
        super().__init__()

        self.root = root
        # None lets find_images_and_targets() build the mapping from the folder names.
        mapping = load_class_map(class_map, root) if class_map else None
        # None lets find_images_and_targets() fall back to its default extensions.
        extensions = input_key.split(';') if input_key else None
        self.samples, self.class_to_idx = find_images_and_targets(
            root,
            types=extensions,
            class_to_idx=mapping,
        )
        if not self.samples:
            raise RuntimeError(
                f'Found 0 images in subfolders of {root}. '
                f'Supported image extensions are {", ".join(get_img_extensions())}')

    def __getitem__(self, index):
        """ Return an open binary file object and the target for the sample at `index`.

        The file object is handed to the caller still open; nothing in this method closes it.
        """
        sample_path, class_index = self.samples[index]
        stream = open(sample_path, 'rb')
        return stream, class_index

    def __len__(self):
        """ Number of samples found. """
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        """ Return the filename of the sample at `index`.

        Args:
            index: position of the sample in self.samples
            basename: return only the final path component (takes precedence over `absolute`)
            absolute: return the path exactly as stored in self.samples instead of
                the path relative to self.root
        """
        stored_path = self.samples[index][0]
        if basename:
            return os.path.basename(stored_path)
        if absolute:
            return stored_path
        return os.path.relpath(stored_path, self.root)
|
|