|
import collections
import re
import signal
import sys
import threading
import traceback

try:
    from collections.abc import Mapping, Sequence
except ImportError:  # Python 2: the ABCs live directly in ``collections``
    from collections import Mapping, Sequence

if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue

import numpy as np
import torch
import torch.multiprocessing as multiprocessing
from torch._C import _set_worker_signal_handlers, \
    _remove_worker_pids, _error_if_any_worker_fails
try:
    from torch._C import _set_worker_pids
except ImportError:
    # fallback for torch builds that only provide _update_worker_pids
    from torch._C import _update_worker_pids as _set_worker_pids
from torch._six import string_classes, int_classes

from .sampler import SequentialSampler, RandomSampler, BatchSampler
|
|
|
|
|
class ExceptionWrapper(object):
    r"""Carry an exception across a thread/process boundary.

    Only the exception class and a fully formatted traceback string are
    stored, so the wrapper can be put on the data-loading queues and later
    re-raised as ``exc_type(exc_msg)`` by the consumer.

    Args:
        exc_info: the ``(type, value, traceback)`` triple from ``sys.exc_info()``.
    """

    def __init__(self, exc_info):
        formatted_lines = traceback.format_exception(*exc_info)
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(formatted_lines)
|
|
|
|
|
# Whether default_collate should build tensor batches in shared memory.
# Switched to True inside DataLoader worker processes by _worker_loop.
_use_shared_memory = False
|
|
|
|
|
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    """Main loop of a DataLoader worker subprocess.

    Receives ``(idx, batch_indices)`` tasks from ``index_queue``, builds the
    batch with ``collate_fn([dataset[i] for i in batch_indices])`` and puts
    ``(idx, batch)`` on ``data_queue``. If loading or collating raises, the
    failure is sent instead as ``(idx, ExceptionWrapper(...))`` so the main
    process can re-raise it. A ``None`` task ends the loop.

    Args:
        dataset: indexable dataset to read samples from.
        index_queue: queue of ``(idx, batch_indices)`` tasks, or ``None`` to stop.
        data_queue: queue that receives the ``(idx, result)`` pairs.
        collate_fn: callable merging a list of samples into one batch.
        seed: seed applied to both the torch and NumPy RNGs of this process.
        init_fn: optional callable invoked with ``worker_id`` after seeding.
        worker_id: integer id of this worker, passed to ``init_fn``.
    """
    global _use_shared_memory
    # Inside a worker, default_collate stacks tensors into shared memory.
    _use_shared_memory = True

    # NOTE(review): presumably installs C-level handlers for fatal signals
    # (e.g. SIGBUS/SIGSEGV) in this worker -- confirm in torch._C.
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    torch.manual_seed(seed)
    np.random.seed(seed)

    if init_fn is not None:
        init_fn(worker_id)

    while True:
        task = index_queue.get()
        if task is None:
            return
        idx, batch_indices = task
        try:
            outcome = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            outcome = ExceptionWrapper(sys.exc_info())
        data_queue.put((idx, outcome))
|
|
|
|
|
def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id): |
|
if pin_memory: |
|
torch.cuda.set_device(device_id) |
|
|
|
while True: |
|
try: |
|
r = in_queue.get() |
|
except Exception: |
|
if done_event.is_set(): |
|
return |
|
raise |
|
if r is None: |
|
break |
|
if isinstance(r[1], ExceptionWrapper): |
|
out_queue.put(r) |
|
continue |
|
idx, batch = r |
|
try: |
|
if pin_memory: |
|
batch = pin_memory_batch(batch) |
|
except Exception: |
|
out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) |
|
else: |
|
out_queue.put((idx, batch)) |
|
|
|
# NumPy dtype name -> legacy torch tensor type. default_collate uses this to
# turn a batch of NumPy scalars into a single tensor.
numpy_type_map = dict(
    float64=torch.DoubleTensor,
    float32=torch.FloatTensor,
    float16=torch.HalfTensor,
    int64=torch.LongTensor,
    int32=torch.IntTensor,
    int16=torch.ShortTensor,
    int8=torch.CharTensor,
    uint8=torch.ByteTensor,
)
|
|
|
|
|
def default_collate(batch):
    """Puts each data field into a tensor with outer dimension batch size.

    Only the type of ``batch[0]`` is inspected to choose how to collate:

    * tensors are stacked along a new dimension 0 (directly into shared
      memory when running inside a DataLoader worker process);
    * NumPy arrays are converted with ``torch.from_numpy`` and stacked;
      arrays of string/object dtype raise ``TypeError``;
    * NumPy scalars become a 1-D tensor of the type given by
      ``numpy_type_map`` for their dtype;
    * Python ints become a ``LongTensor``, floats a ``DoubleTensor``;
    * strings are returned unchanged (the batch list itself);
    * mappings are collated key by key, using the keys of ``batch[0]``;
    * sequences are transposed and each position collated recursively.

    Raises:
        TypeError: if the element type is not supported.
        RuntimeError: if sequence samples do not all have the same length.
    """
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem = batch[0]
    elem_type = type(elem)
    if torch.is_tensor(elem):
        out = None
        if _use_shared_memory:
            # Inside a worker: stack straight into a shared-memory tensor so
            # the batch does not need an extra copy to reach the main process.
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            # Arrays of strings / objects cannot be turned into tensors.
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # NumPy scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(elem, int_classes):
        return torch.LongTensor(batch)
    elif isinstance(elem, float):
        return torch.DoubleTensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, Sequence):
        # zip() would silently drop trailing fields of longer samples, so
        # reject ragged batches explicitly.
        elem_size = len(elem)
        if any(len(sample) != elem_size for sample in batch):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg.format(type(batch[0]))))
|
|
|
|
|
def pin_memory_batch(batch):
    """Recursively copy every tensor in ``batch`` into pinned memory.

    Tensors are replaced by ``tensor.pin_memory()``. Strings and any other
    non-container values are returned as-is. Mappings are rebuilt as plain
    ``dict`` objects, and other sequences (tuples, lists, ...) as plain
    ``list`` objects, with their contents processed recursively.
    """
    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        # Strings are Sequences; handle them before the Sequence branch.
        return batch
    elif isinstance(batch, Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch
|
|
|
|
|
# Whether the SIGCHLD handler for DataLoader worker failures has already been
# installed. Only one handler needs to be set for all DataLoaders in a process.
_SIGCHLD_handler_set = False
|
|
|
|
|
def _set_SIGCHLD_handler():
    """Install, at most once per process, a SIGCHLD handler that calls
    ``_error_if_any_worker_fails()``.

    A previously installed callable SIGCHLD handler is chained and invoked
    after that check. This does nothing on Windows or when called outside
    the main thread. It also does nothing if the handler has already been
    installed.
    """
    global _SIGCHLD_handler_set
    # Windows has no SIGCHLD.
    if sys.platform == 'win32':
        return
    # signal.signal() may only be called from the main thread.
    if not isinstance(threading.current_thread(), threading._MainThread):
        return
    if _SIGCHLD_handler_set:
        return

    existing = signal.getsignal(signal.SIGCHLD)
    # SIG_DFL / SIG_IGN / None are not callable, so there is nothing to chain.
    chained = existing if callable(existing) else None

    def _on_sigchld(signum, frame):
        _error_if_any_worker_fails()
        if chained is not None:
            chained(signum, frame)

    signal.signal(signal.SIGCHLD, _on_sigchld)
    _SIGCHLD_handler_set = True
|
|
|
|
|
class DataLoaderIter(object):
    """Iterates once over the DataLoader's dataset, as specified by the sampler.

    With ``num_workers == 0``, batches are loaded synchronously in the
    calling process.

    Otherwise ``num_workers`` daemon subprocesses run ``_worker_loop``. They
    consume ``(send_idx, indices)`` tasks from ``index_queue`` and put
    ``(idx, batch)`` results on ``worker_result_queue``. When pinning memory
    or using a timeout, a daemon thread running ``_worker_manager_loop``
    relays those results into a ``queue.Queue`` (``data_queue``). Otherwise
    the result queue itself is used as ``data_queue``.

    Results may arrive out of order; they are buffered in ``reorder_dict``
    and returned strictly in index order. At most ``2 * num_workers`` tasks
    are outstanding at any time.
    """

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        # Pinning is silently disabled when CUDA is not available.
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        # Set during shutdown so the manager thread treats a failing get() as
        # a normal exit rather than an error.
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queue = multiprocessing.SimpleQueue()
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0  # tasks sent but not yet received
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0  # index of the next task to send to the workers
            self.rcvd_idx = 0  # index of the next batch to return, in order
            self.reorder_dict = {}  # out-of-order results, keyed by task index

            # Worker i is seeded with base_seed + i (see _worker_loop).
            base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0]
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
                          base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                # A thread-safe queue.Queue is needed here: its get() accepts
                # the timeout used by _get_batch.
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # torch.cuda is only queried when pinning is requested.
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # workers are terminated when this process exits
                w.start()

            # Register this iterator's worker pids (keyed by id(self)).
            # NOTE(review): presumably so that _error_if_any_worker_fails,
            # called from the SIGCHLD handler, can detect dead workers.
            _set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # Prime the prefetch pipeline: up to two tasks per worker.
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def _get_batch(self):
        # Block for the next (idx, batch) pair. With a positive timeout,
        # data_queue is a queue.Queue (see __init__), so a stall surfaces as
        # RuntimeError instead of hanging forever.
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # raises StopIteration at the end
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # The next batch in order may already have arrived out of order.
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        # Nothing pending and nothing buffered: the epoch is over.
        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # Buffer out-of-order results until their turn comes.
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        # Send the next batch of indices to the workers, if the sampler has
        # any left. The prefetch window is capped at 2 * num_workers tasks.
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        # Advance to the next expected index and refill the task queue. Then
        # either return the batch or re-raise a worker failure as
        # exc_type(exc_msg); exc_msg is the worker's formatted traceback.
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # Pickling is refused outright: the iterator owns multiprocessing
        # queues, worker Process objects and a threading.Event.
        raise NotImplementedError("DataLoaderIterator cannot be pickled")

    def _shutdown_workers(self):
        try:
            if not self.shutdown:
                self.shutdown = True
                self.done_event.set()
                # Drain results that are already queued.
                while not self.data_queue.empty():
                    self.data_queue.get()
                # One stop sentinel per worker; _worker_loop exits on None.
                for _ in self.workers:
                    self.index_queue.put(None)
                # done_event alone is meant to stop _worker_manager_loop; it
                # also returns when it reads None, so send one to be safe.
                self.worker_result_queue.put(None)
        finally:
            # Unregister the worker pids no matter what happened above.
            if self.worker_pids_set:
                _remove_worker_pids(id(self))
                self.worker_pids_set = False

    def __del__(self):
        # Only the multi-process path has workers and queues to tear down.
        # NOTE(review): if __init__ raised before self.shutdown /
        # self.worker_pids_set were assigned, _shutdown_workers will raise
        # AttributeError here.
        if self.num_workers > 0:
            self._shutdown_workers()
|
|
|
|
|
class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            Must be non-negative. (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch. (default: :func:`default_collate`)
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them. Ignored when CUDA
            is not available. (default: False)
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value, in seconds,
            for collecting a batch from workers. Should always be non-negative.
            (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    Raises:
        ValueError: if ``timeout`` or ``num_workers`` is negative, if
            ``batch_sampler`` is combined with ``batch_size > 1``, ``shuffle``,
            ``sampler`` or ``drop_last``, or if ``sampler`` is combined with
            ``shuffle``.

    .. note:: By default, each worker will have both its PyTorch seed and its
              NumPy seed set to ``base_seed + worker_id``, where ``base_seed``
              is a long generated by the main process using its RNG. You may
              use ``torch.initial_seed()`` to access this value in
              :attr:`worker_init_fn`, which can be used to set other seeds
              (e.g. Python's ``random`` module) before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            # batch_sampler fully determines batching; conflicting options
            # would otherwise be silently ignored.
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        # sampler stays None when the caller supplied a batch_sampler.
        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        # Each iteration gets a fresh iterator (and, if num_workers > 0, a
        # fresh set of worker processes).
        return DataLoaderIter(self)

    def __len__(self):
        # Number of batches, not number of samples.
        return len(self.batch_sampler)
|
|