# SEED-Story/src/data/dataloader_utils.py
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import time
import random
import torch
# from lavis.datasets.data_utils import move_to_cuda
from torch.utils.data import DataLoader

class MultiIterLoader:
    """
    A simple wrapper for iterating over multiple iterators.

    Args:
        loaders (List[Iterator]): List of iterator objects to sample from.
        ratios (List[float]): Sampling ratio for each loader. If None, all
            loaders are sampled uniformly.
    """
    def __init__(self, loaders, ratios=None):
        # assert that every loader has a __next__ method
        for loader in loaders:
            assert hasattr(loader, "__next__"), "Loader {} has no __next__ method.".format(loader)
if ratios is None:
ratios = [1.0] * len(loaders)
else:
assert len(ratios) == len(loaders)
ratios = [float(ratio) / sum(ratios) for ratio in ratios]
self.loaders = loaders
self.ratios = ratios
    def __next__(self):
        # randomly pick one loader according to the sampling ratios
        loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
return next(self.loaders[loader_idx])
def __iter__(self):
return self
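
# Usage sketch for MultiIterLoader (a minimal example; `dataset_a` and
# `dataset_b` are hypothetical datasets, not part of this module):
#
#     from torch.utils.data import DataLoader
#     loader_a = iter(DataLoader(dataset_a, batch_size=4))
#     loader_b = iter(DataLoader(dataset_b, batch_size=4))
#     mixed = MultiIterLoader([loader_a, loader_b], ratios=[3.0, 1.0])
#     batch = next(mixed)  # drawn from loader_a roughly 75% of the time
#
# In practice each element is usually an IterLoader (defined below), so that
# no individual loader ever raises StopIteration.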

class PrefetchLoader(object):
    """
    Modified from https://github.com/ChenRocks/UNITER.

    Overlaps compute and CUDA data transfer
    (copied and then modified from NVIDIA Apex).
    """
def __init__(self, loader):
self.loader = loader
self.stream = torch.cuda.Stream()
    def __iter__(self):
        loader_it = iter(self.loader)
        self.preload(loader_it)
        batch = self.next(loader_it)
        while batch is not None:
            if isinstance(batch, tuple):
                task, batch = batch
                yield task, batch
            else:
                yield batch
            batch = self.next(loader_it)
def __len__(self):
return len(self.loader)
    def preload(self, it):
        try:
            self.batch = next(it)
        except StopIteration:
            self.batch = None
            return
        # The asynchronous host-to-device copy is currently disabled, together
        # with the `move_to_cuda` import at the top of this file:
        # self.stream.wait_stream(torch.cuda.current_stream())
        # with torch.cuda.stream(self.stream):
        #     self.batch = move_to_cuda(self.batch)
        #
        # If record_stream() doesn't work, another option is to make sure
        # device inputs are created on the main stream:
        # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use
        # by the main stream at the time we start copying to next_*.
        # copy_ will record the use of the pinned source tensor in this
        # side stream:
        # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
        # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
        # self.next_input = self.next_input_gpu
        # self.next_target = self.next_target_gpu
    def next(self, it):
        # block the main stream until the prefetch stream has finished copying
        torch.cuda.current_stream().wait_stream(self.stream)
batch = self.batch
if batch is not None:
record_cuda_stream(batch)
self.preload(it)
return batch
    def __getattr__(self, name):
        # delegate any other attribute lookup to the wrapped loader
        return getattr(self.loader, name)
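
# Usage sketch for PrefetchLoader (a minimal example; `train_loader` is a
# hypothetical DataLoader, ideally built with pin_memory=True):
#
#     prefetch_loader = PrefetchLoader(train_loader)
#     for batch in prefetch_loader:
#         ...  # train step
#
# Note: with the `move_to_cuda` call in preload() commented out, the wrapper
# currently yields batches unchanged instead of overlapping the transfer.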

def record_cuda_stream(batch):
    """Recursively call record_stream() on every tensor in `batch`, telling the
    caching allocator that the tensors are in use on the current stream so
    their memory is not reused while the prefetch stream still references it."""
    if isinstance(batch, torch.Tensor):
        batch.record_stream(torch.cuda.current_stream())
    elif isinstance(batch, (list, tuple)):
        for t in batch:
            record_cuda_stream(t)
    elif isinstance(batch, dict):
        for t in batch.values():
            record_cuda_stream(t)
    # non-tensor leaves (ints, strings, ...) need no stream bookkeeping

class IterLoader:
    """
    A wrapper that turns a DataLoader into an infinite iterator.

    Modified from:
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
    """
def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
self._dataloader = dataloader
self.iter_loader = iter(self._dataloader)
self._use_distributed = use_distributed
self._epoch = 0
@property
def epoch(self) -> int:
return self._epoch
def __next__(self):
try:
data = next(self.iter_loader)
except StopIteration:
self._epoch += 1
if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
self._dataloader.sampler.set_epoch(self._epoch)
time.sleep(2) # Prevent possible deadlock during epoch transition
self.iter_loader = iter(self._dataloader)
data = next(self.iter_loader)
return data
def __iter__(self):
return self
def __len__(self):
return len(self._dataloader)
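
# Usage sketch for IterLoader (a minimal example; `train_loader` and
# `num_steps` are hypothetical, not part of this module):
#
#     infinite_loader = IterLoader(train_loader, use_distributed=True)
#     for step in range(num_steps):
#         batch = next(infinite_loader)  # restarts silently at each epoch end
#
# With a DistributedSampler and use_distributed=True, the wrapper calls
# sampler.set_epoch() on every restart so shuffling differs between epochs.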