|
import threading
|
|
import random
|
|
|
|
import torch
|
|
import torch.multiprocessing as multiprocessing
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.data import SequentialSampler
|
|
from torch.utils.data import RandomSampler
|
|
from torch.utils.data import BatchSampler
|
|
from torch.utils.data import _utils
|
|
from torch.utils.data.dataloader import _DataLoaderIter
|
|
|
|
from torch.utils.data._utils import collate
|
|
from torch.utils.data._utils import signal_handling
|
|
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
|
|
from torch.utils.data._utils import ExceptionWrapper
|
|
from torch.utils.data._utils import IS_WINDOWS
|
|
from torch.utils.data._utils.worker import ManagerWatchdog
|
|
|
|
from torch._six import queue
|
|
|
|
def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
    """Worker-process loop that collates batches at a randomly chosen scale.

    Multi-scale variant of the stock DataLoader worker loop: before collating
    each batch it picks a random scale index (only when more than one scale is
    configured and the dataset is in training mode) and appends that index to
    the collated sample list, so the consumer knows which scale was used.

    Args:
        dataset: dataset to read from; expected to expose ``train`` and
            ``set_scale(idx)`` (NOTE(review): assumed from usage below —
            confirm against the project's dataset class).
        index_queue: queue of ``(batch_idx, sample_indices)`` work items;
            ``None`` is the shutdown sentinel.
        data_queue: queue the collated batches (or wrapped exceptions) are
            written to.
        done_event: event set by the parent once iteration is finished.
        collate_fn: merges a list of samples into a batch; must return a
            list, since the scale index is appended to it.
        scale: sequence of available scales.
        seed: per-worker RNG seed.
        init_fn: optional ``worker_init_fn(worker_id)`` callback.
        worker_id: index of this worker process.
    """
    try:
        # Tensors put on data_queue should live in shared memory so the
        # parent process can read them without an extra copy.
        collate._use_shared_memory = True
        signal_handling._set_worker_signal_handlers()

        # Keep each worker single-threaded and deterministically seeded.
        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        # Don't block interpreter shutdown on the queue's feeder thread.
        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        # Lets the worker exit if the parent process dies unexpectedly.
        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue

            if r is None:
                # Sentinel: parent asked us to shut down.
                assert done_event.is_set()
                return
            elif done_event.is_set():
                # Shutdown started; drain leftover work items without
                # producing output.
                continue

            idx, batch_indices = r
            try:
                # Pick a random scale for this batch (training mode only).
                idx_scale = 0
                if len(scale) > 1 and dataset.train:
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)

                samples = collate_fn([dataset[i] for i in batch_indices])
                # Tell the consumer which scale this batch was built at.
                samples.append(idx_scale)
            except Exception:
                # Ship the exception to the parent instead of dying silently.
                # (fix: `sys` was previously used here without being
                # imported, turning any data error into a NameError)
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples

    except KeyboardInterrupt:
        # SIGINT cleanup is handled by the parent; exit quietly.
        pass
|
|
|
|
class _MSDataLoaderIter(_DataLoaderIter):
    """Iterator for :class:`MSDataLoader`.

    Mirrors the multiprocessing setup of ``_DataLoaderIter.__init__`` but
    spawns workers that run :func:`_ms_loop`, which appends a randomly
    chosen scale index to every collated batch.
    """

    def __init__(self, loader):
        # NOTE: super().__init__ is deliberately NOT called -- the entire
        # worker setup is replicated here so the worker target can be
        # _ms_loop instead of the stock worker loop.
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        # Only pin memory when CUDA is actually available.
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        # One base seed per epoch; worker i is seeded with base_seed + i.
        # (fix: this value was previously recomputed a second time inside
        # the num_workers branch, silently discarding this one)
        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            # Batches arriving out of order wait here until their turn.
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()

            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        index_queue,
                        self.worker_result_queue,
                        self.done_event,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                # Workers must not outlive the parent process.
                w.daemon = True
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                # A dedicated thread copies batches from the worker result
                # queue into pinned (page-locked) memory for faster H2D.
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(
                        self.worker_result_queue,
                        self.data_queue,
                        torch.cuda.current_device(),
                        self.done_event
                    )
                )
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue

            # Register worker pids so failures surface via SIGCHLD handling.
            _utils.signal_handling._set_worker_pids(
                id(self), tuple(w.pid for w in self.workers)
            )
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            # Prime the pipeline with two batches per worker.
            for _ in range(2 * self.num_workers):
                self._put_indices()
|
|
|
|
|
|
class MSDataLoader(DataLoader):
    """DataLoader whose iterator reports a per-batch scale index.

    The worker count comes from ``cfg.n_threads`` and the list of available
    scales from ``cfg.scale``; all other arguments are forwarded verbatim to
    :class:`torch.utils.data.DataLoader`.
    """

    def __init__(self, cfg, *args, **kwargs):
        """Build the loader from ``cfg`` plus standard DataLoader arguments."""
        super(MSDataLoader, self).__init__(
            *args, num_workers=cfg.n_threads, **kwargs
        )
        self.scale = cfg.scale

    def __iter__(self):
        """Return the multi-scale-aware iterator over the dataset."""
        return _MSDataLoaderIter(self)
|
|
|
|
|