|
""" Quick n Simple Image Folder, Tarfile based DataSet |
|
|
|
Hacked together by / Copyright 2019, Ross Wightman |
|
""" |
|
import io |
|
import logging |
|
from typing import Optional |
|
|
|
import torch |
|
import torch.utils.data as data |
|
from PIL import Image |
|
|
|
from .readers import create_reader |
|
|
|
_logger = logging.getLogger(__name__)

# Once this many sample loads have failed in a row, ImageDataset re-raises the
# last error instead of skipping ahead to the next index.
_ERROR_RETRY = 50


class ImageDataset(data.Dataset):
    """Map-style dataset yielding ``(img, target)`` pairs obtained from a reader.

    Indexing the reader gives ``(file_obj, target)``. The file object is either
    read as raw bytes (``load_bytes=True``) or opened with PIL and, when
    ``input_img_mode`` is set, converted to that mode. A sample that fails to
    load is skipped in favour of the next index (wrapping around) until
    ``_ERROR_RETRY`` consecutive failures have occurred.

    Args:
        root: Dataset root, forwarded to ``create_reader`` when a reader is built here.
        reader: A reader instance, a reader name string, or None for the default reader.
        split: Split name forwarded to ``create_reader``.
        class_map: Class mapping forwarded to ``create_reader``.
        load_bytes: If True, return the file's raw bytes instead of a PIL image.
        input_img_mode: PIL mode to convert images to; a falsy value disables conversion.
        transform: Optional callable applied to the loaded image (or bytes).
        target_transform: Optional callable applied to non-None targets.
    """

    def __init__(
            self,
            root,
            reader=None,
            split='train',
            class_map=None,
            load_bytes=False,
            input_img_mode='RGB',
            transform=None,
            target_transform=None,
    ):
        # A reader name (or None, meaning the default) is resolved to a reader instance.
        if reader is None or isinstance(reader, str):
            reader = create_reader(
                reader or '',
                root=root,
                split=split,
                class_map=class_map
            )
        self.reader = reader
        self.load_bytes = load_bytes
        self.input_img_mode = input_img_mode
        self.transform = transform
        self.target_transform = target_transform
        # Failure counter that persists across calls; reset after every successful load.
        self._consecutive_errors = 0

    def _load(self, img):
        """Read raw bytes from ``img`` or open it with PIL (converting if configured)."""
        if self.load_bytes:
            return img.read()
        img = Image.open(img)
        if self.input_img_mode:
            # Image.open only parses the header; convert() forces the full decode,
            # so corrupt/truncated files fail here, inside the retry handling.
            img = img.convert(self.input_img_mode)
        return img

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index``, skipping samples that fail to load.

        On a load failure the next index (modulo ``len(self.reader)``) is tried.
        When ``_ERROR_RETRY`` consecutive failures have accumulated, the last
        exception is re-raised. A ``None`` target is returned as ``-1``.
        """
        # Iterative retry: avoids recursion depth and a long chain of
        # exception contexts on the final re-raise.
        while True:
            img, target = self.reader[index]
            try:
                img = self._load(img)
            except Exception as e:
                _logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
                self._consecutive_errors += 1
                if self._consecutive_errors >= _ERROR_RETRY:
                    raise
                index = (index + 1) % len(self.reader)
            else:
                break
        self._consecutive_errors = 0

        if self.transform is not None:
            img = self.transform(img)

        if target is None:
            target = -1
        elif self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.reader)

    def filename(self, index, basename=False, absolute=False):
        """Return the file name of sample ``index`` as reported by the reader."""
        return self.reader.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        """Return all sample file names as reported by the reader."""
        return self.reader.filenames(basename, absolute)
|
|
|
|
|
class IterableImageDataset(data.IterableDataset):
    """Iterable-style dataset streaming ``(img, target)`` pairs from a reader.

    When ``reader`` is a string, a reader is built with ``create_reader`` and all
    of the streaming options (training mode, batch size, sample count, seed,
    repeats, download, input mode/keys, max steps) are forwarded to it. This
    class does no decoding itself: it only applies ``transform`` and
    ``target_transform`` to whatever the reader yields.
    """

    def __init__(
            self,
            root,
            reader=None,
            split='train',
            class_map=None,
            is_training=False,
            batch_size=1,
            num_samples=None,
            seed=42,
            repeats=0,
            download=False,
            input_img_mode='RGB',
            input_key=None,
            target_key=None,
            transform=None,
            target_transform=None,
            max_steps=None,
    ):
        # Explicit raise rather than `assert` so the check survives `python -O`;
        # AssertionError is kept so callers see the same exception type as before.
        if reader is None:
            raise AssertionError('A reader instance or reader name is required.')
        if isinstance(reader, str):
            self.reader = create_reader(
                reader,
                root=root,
                split=split,
                class_map=class_map,
                is_training=is_training,
                batch_size=batch_size,
                num_samples=num_samples,
                seed=seed,
                repeats=repeats,
                download=download,
                input_img_mode=input_img_mode,
                input_key=input_key,
                target_key=target_key,
                max_steps=max_steps,
            )
        else:
            self.reader = reader
        self.transform = transform
        self.target_transform = target_transform
        # Not used by this class; kept so the attribute still exists for any external readers of it.
        self._consecutive_errors = 0

    def __iter__(self):
        for img, target in self.reader:
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)
            yield img, target

    def __len__(self):
        # Streaming readers may not know their length; report 0 in that case.
        if hasattr(self.reader, '__len__'):
            return len(self.reader)
        return 0

    def set_epoch(self, count):
        """Forward the epoch counter to the reader, if it supports one."""
        if hasattr(self.reader, 'set_epoch'):
            self.reader.set_epoch(count)

    def set_loader_cfg(
            self,
            num_workers: Optional[int] = None,
    ):
        """Forward data-loader settings to the reader, if it supports them."""
        if hasattr(self.reader, 'set_loader_cfg'):
            self.reader.set_loader_cfg(num_workers=num_workers)

    def filename(self, index, basename=False, absolute=False):
        # Explicit raise: a bare `assert False` is stripped under `python -O`,
        # which would make this silently return None.
        raise AssertionError('Filename lookup by index not supported, use filenames().')

    def filenames(self, basename=False, absolute=False):
        """Return all sample file names as reported by the reader."""
        return self.reader.filenames(basename, absolute)
|
|
|
|
|
class AugMixDataset(data.Dataset):
    """Dataset wrapper to perform AugMix or other clean/augmentation mixes.

    The wrapped dataset's ``transform`` is expected to be a 3-element list/tuple
    ``(preprocess, augmentation, normalize)``. ``preprocess`` is installed back
    on the wrapped dataset; ``augmentation`` and ``normalize`` are kept here.
    Each item is returned as ``(views, target)``, where ``views`` is a tuple of
    ``num_splits`` entries: the normalized preprocessed sample, followed by
    ``num_splits - 1`` results of ``normalize(augmentation(sample))``.
    """

    def __init__(self, dataset, num_splits=2):
        self.augmentation = None
        self.normalize = None
        self.dataset = dataset
        if self.dataset.transform is not None:
            self._set_transforms(self.dataset.transform)
        self.num_splits = num_splits

    def _set_transforms(self, x):
        # Explicit raise instead of `assert` so validation survives `python -O`;
        # AssertionError is kept so callers see the same exception type.
        if not (isinstance(x, (list, tuple)) and len(x) == 3):
            raise AssertionError('Expecting a tuple/list of 3 transforms')
        self.dataset.transform, self.augmentation, self.normalize = x

    @property
    def transform(self):
        return self.dataset.transform

    @transform.setter
    def transform(self, x):
        self._set_transforms(x)

    def _normalize(self, x):
        return x if self.normalize is None else self.normalize(x)

    def __getitem__(self, i):
        x, y = self.dataset[i]
        # First view is the clean sample; the rest are augmented copies.
        # NOTE(review): with num_splits > 1 this requires transforms to have been set,
        # otherwise self.augmentation is None and calling it fails.
        x_list = [self._normalize(x)]
        for _ in range(self.num_splits - 1):
            x_list.append(self._normalize(self.augmentation(x)))
        return tuple(x_list), y

    def __len__(self):
        return len(self.dataset)
|
|