"""
Code adapted from timm https://github.com/huggingface/pytorch-image-models
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
"""
import logging
from contextlib import suppress
from functools import partial
from itertools import repeat

import numpy as np
import torch
import torch.utils.data
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.dataset import IterableImageDataset
from timm.data.loader import PrefetchLoader, _worker_init
from timm.data.transforms_factory import create_transform

_logger = logging.getLogger(__name__)


def fast_collate(batch, target_dtype=torch.uint8):
    """A fast collation function optimized for uint8 images (np array or torch) and target_dtype targets (labels)"""
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    if isinstance(batch[0][0], np.ndarray):
        targets = torch.tensor([b[1] for b in batch], dtype=target_dtype)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i] += torch.from_numpy(batch[i][0])
        return tensor, targets
    else:
        raise ValueError(f"Incorrect batch type: {type(batch[0][0])}")


def adapt_to_chs(x, n):
    if not isinstance(x, (tuple, list)):
        x = tuple(repeat(x, n))
    elif len(x) != n:
        # doubled channels
        if len(x) * 2 == n:
            x = np.concatenate((x, x))
            _logger.warning(f"Pretrained mean/std different shape than model (doubled channels), using concat: {x}.")
        else:
            x_mean = np.mean(x).item()
            x = (x_mean,) * n
            _logger.warning(f"Pretrained mean/std different shape than model, using avg value {x}.")
    else:
        assert len(x) == n, "normalization stats must match image channels"
    return x


class PrefetchLoaderForMultiInput(PrefetchLoader):
    def __init__(
        self,
        loader,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        channels=3,
        device=torch.device("cpu"),
        img_dtype=torch.float32,
    ):
        mean = adapt_to_chs(mean, channels)
        std = adapt_to_chs(std, channels)
        normalization_shape = (1, channels, 1, 1)

        self.loader = loader
        self.device = device
        self.img_dtype = img_dtype
        self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
        self.std = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
        # prefetch on a side CUDA stream only when a CUDA device is actually in use
        self.is_cuda = torch.cuda.is_available() and device.type == "cuda"

    def __iter__(self):
        first = True
        if self.is_cuda:
            # stage batches on a side CUDA stream so host-to-device copies and
            # normalization overlap with compute on the previous batch
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        else:
            stream = None
            stream_context = suppress

        for next_input, next_target in self.loader:
            with stream_context():
                next_input = next_input.to(device=self.device, non_blocking=True)
                next_target = next_target.to(device=self.device, non_blocking=True)
                next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)

            if not first:
                # yield the batch that was staged on the previous iteration
                yield input, target  # noqa: F823, F821
            else:
                first = False

            if stream is not None:
                torch.cuda.current_stream().wait_stream(stream)

            input = next_input
            target = next_target

        # flush the last staged batch
        yield input, target


def create_loader(
    dataset,
    input_size,
    batch_size,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    num_workers=1,
    crop_pct=None,
    crop_mode=None,
    pin_memory=False,
    img_dtype=torch.float32,
    device=torch.device("cpu"),
    persistent_workers=True,
    worker_seeding="all",
    target_type=torch.int64,
):
    transform = create_transform(
        input_size,
        is_training=False,
        use_prefetcher=True,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        crop_mode=crop_mode,
    )
    dataset.transform = transform

    if isinstance(dataset, IterableImageDataset):
        # iterable datasets are not supported here: fast_collate and the prefetch loader
        # expect a map-style dataset
        raise ValueError("Incorrect dataset type: IterableImageDataset")

    loader_class = torch.utils.data.DataLoader
    loader_args = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        sampler=None,
        collate_fn=lambda batch: fast_collate(batch, target_dtype=target_type),
        pin_memory=pin_memory,
        drop_last=False,
        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
        persistent_workers=persistent_workers,
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        loader_args.pop("persistent_workers")  # only in Pytorch 1.7+
        loader = loader_class(dataset, **loader_args)

    loader = PrefetchLoaderForMultiInput(
        loader,
        mean=mean,
        std=std,
        channels=input_size[0],
        device=device,
        img_dtype=img_dtype,
    )
    return loader
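

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): it exercises the helpers
    # above with synthetic data. The channel count, image size and labels below are
    # illustrative assumptions, not values prescribed by mivolo.
    doubled_mean = adapt_to_chs(IMAGENET_DEFAULT_MEAN, 6)  # 3 stats vs. 6 channels -> concatenated
    print("adapted mean:", doubled_mean)

    # fast_collate expects a list of (uint8 np.ndarray in CHW layout, label) tuples,
    # which is what a timm transform built with use_prefetcher=True produces.
    fake_batch = [(np.zeros((3, 224, 224), dtype=np.uint8), label) for label in range(4)]
    images, targets = fast_collate(fake_batch, target_dtype=torch.int64)
    print("collated:", images.shape, images.dtype, targets)

    # create_loader additionally needs a map-style timm dataset (e.g. an ImageDataset over an
    # image folder) and wraps the resulting DataLoader in PrefetchLoaderForMultiInput;
    # it is omitted here to keep the sketch self-contained.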