# Source: MiVOLO/mivolo/data/dataset/age_gender_loader.py (commit 319d3b5, "sync")
"""
Code adapted from timm https://github.com/huggingface/pytorch-image-models
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
"""
import logging
from contextlib import suppress
from functools import partial
from itertools import repeat
import numpy as np
import torch
import torch.utils.data
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.dataset import IterableImageDataset
from timm.data.loader import PrefetchLoader, _worker_init
from timm.data.transforms_factory import create_transform
# Module-level logger named after this module; used for normalization-stat warnings below.
_logger = logging.getLogger(__name__)
def fast_collate(batch, target_dtype=torch.uint8):
    """Collate ``(image, target, ...)`` samples into a uint8 image batch and a target tensor.

    Optimized for uint8 images given either as ``np.ndarray`` or ``torch.Tensor``.
    Every image must have the same shape as the first sample's image.

    Args:
        batch: sequence of tuples whose element 0 is the image and element 1 the target.
        target_dtype: dtype of the returned target tensor.

    Returns:
        ``(images, targets)``: ``images`` is a ``torch.uint8`` tensor of shape
        ``(len(batch), *image.shape)``; ``targets`` is a 1-D tensor of ``target_dtype``.

    Raises:
        ValueError: if the first image is neither a numpy array nor a torch tensor.
    """
    assert isinstance(batch[0], tuple)
    first_image = batch[0][0]
    if not isinstance(first_image, (np.ndarray, torch.Tensor)):
        raise ValueError(f"Incorrect batch type: {type(first_image)}")

    targets = torch.tensor([sample[1] for sample in batch], dtype=target_dtype)
    images = torch.zeros((len(batch), *first_image.shape), dtype=torch.uint8)
    for i, sample in enumerate(batch):
        image = sample[0]
        if isinstance(image, np.ndarray):
            # In-place add onto zeros: raises (rather than silently casting) for non-uint8-compatible input.
            images[i] += torch.from_numpy(image)
        else:
            images[i].copy_(image)
    return images, targets
def adapt_to_chs(x, n):
    """Adapt normalization statistics ``x`` (mean or std) to ``n`` channels.

    - A non-sequence (scalar) is repeated ``n`` times, as a tuple.
    - A tuple/list that already has ``n`` entries is returned as is.
    - A tuple/list with exactly ``n / 2`` entries is concatenated with itself, as a tuple.
    - Any other length is replaced by its mean repeated ``n`` times, as a tuple.

    A warning is logged whenever the statistics had to be reshaped.
    """
    if not isinstance(x, (tuple, list)):
        return tuple(repeat(x, n))
    if len(x) == n:
        return x
    if len(x) * 2 == n:
        # Doubled channels: reuse the same stats for both halves. Return a tuple
        # (not an np.ndarray) so all reshaping branches yield the same type.
        x = tuple(x) * 2
        _logger.warning("Pretrained mean/std different shape than model (doubled channes), using concat: %s.", x)
    else:
        x = (np.mean(x).item(),) * n
        _logger.warning("Pretrained mean/std different shape than model, using avg value %s.", x)
    return x
class PrefetchLoaderForMultiInput(PrefetchLoader):
    """Wrap a loader: move each ``(input, target)`` batch to ``device`` and normalize it there.

    Normalization is ``(input.to(img_dtype) - mean * 255) / (std * 255)`` with the
    stats broadcast over shape ``(1, channels, 1, 1)``. Inputs are therefore assumed to
    be NCHW with raw 0-255 pixel values (e.g. uint8 batches) -- confirm against the
    collate function in use. When the target device is CUDA, the copy/normalize of the
    next batch runs on a side stream while the previous batch is being consumed.
    """

    def __init__(
        self,
        loader,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        channels=3,
        device=torch.device("cpu"),
        img_dtype=torch.float32,
    ):
        # NOTE(review): the parent __init__ is deliberately not called; the instance
        # only has the attributes assigned below.
        device = torch.device(device)  # also accept strings such as "cuda:0"
        mean = adapt_to_chs(mean, channels)
        std = adapt_to_chs(std, channels)
        normalization_shape = (1, channels, 1, 1)
        self.loader = loader
        self.device = device
        self.img_dtype = img_dtype
        self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
        self.std = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
        # A side CUDA stream is only useful when batches are copied to a CUDA device.
        self.is_cuda = torch.cuda.is_available() and device.type == "cuda"

    def __iter__(self):
        """Yield normalized ``(input, target)`` pairs, prefetching one batch ahead.

        Yields nothing for an empty underlying loader.
        """
        if self.is_cuda:
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        else:
            stream = None
            stream_context = suppress  # suppress() with no arguments is a no-op context

        pending = None  # batch prepared on the previous iteration, not yet yielded
        for next_input, next_target in self.loader:
            with stream_context():
                next_input = next_input.to(device=self.device, non_blocking=True)
                next_target = next_target.to(device=self.device, non_blocking=True)
                next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)

            if pending is not None:
                yield pending

            if stream is not None:
                # Make the consumer's stream wait until the side-stream work on this batch is done.
                torch.cuda.current_stream().wait_stream(stream)

            pending = (next_input, next_target)

        if pending is not None:
            yield pending
def create_loader(
    dataset,
    input_size,
    batch_size,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    num_workers=1,
    crop_pct=None,
    crop_mode=None,
    pin_memory=False,
    img_dtype=torch.float32,
    device=torch.device("cpu"),
    persistent_workers=True,
    worker_seeding="all",
    target_type=torch.int64,
):
    """Build a non-shuffling evaluation loader that yields normalized batches on ``device``.

    Side effect: overwrites ``dataset.transform`` with an evaluation transform built by
    timm's ``create_transform`` (``is_training=False``, ``use_prefetcher=True``).

    Args:
        dataset: map-style dataset; ``IterableImageDataset`` is not supported.
        input_size: passed to ``create_transform``; ``input_size[0]`` is used as the
            channel count (presumably a ``(C, H, W)`` tuple -- confirm with callers).
        batch_size: samples per batch (the last partial batch is kept).
        mean, std: normalization stats, adapted to the channel count.
        num_workers: DataLoader worker processes.
        crop_pct, crop_mode: forwarded to ``create_transform``.
        pin_memory: forwarded to the DataLoader.
        img_dtype: dtype of the normalized images.
        device: device the batches are moved to.
        persistent_workers: keep workers alive between epochs (ignored when
            ``num_workers == 0``, where DataLoader does not allow it).
        worker_seeding: forwarded to timm's ``_worker_init``.
        target_type: dtype of the collated targets.

    Returns:
        A ``PrefetchLoaderForMultiInput`` wrapping the DataLoader.

    Raises:
        ValueError: if ``dataset`` is an ``IterableImageDataset``.
    """
    # Reject unsupported datasets before mutating them.
    if isinstance(dataset, IterableImageDataset):
        raise ValueError("Incorrect dataset type: IterableImageDataset")

    dataset.transform = create_transform(
        input_size,
        is_training=False,
        use_prefetcher=True,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        crop_mode=crop_mode,
    )

    loader_class = torch.utils.data.DataLoader
    loader_args = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        sampler=None,
        # partial (unlike a lambda) is picklable, so workers also start under "spawn".
        collate_fn=partial(fast_collate, target_dtype=target_type),
        pin_memory=pin_memory,
        drop_last=False,
        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
        # DataLoader raises ValueError for persistent_workers=True with num_workers=0.
        persistent_workers=persistent_workers and num_workers > 0,
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        loader_args.pop("persistent_workers")  # only in PyTorch 1.7+
        loader = loader_class(dataset, **loader_args)

    return PrefetchLoaderForMultiInput(
        loader,
        mean=mean,
        std=std,
        channels=input_size[0],
        device=device,
        img_dtype=img_dtype,
    )