"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import time
import random
import torch

# from lavis.datasets.data_utils import move_to_cuda
from torch.utils.data import DataLoader


class MultiIterLoader:
    """
    A simple wrapper for iterating over multiple iterators.

    Each call to ``next()`` picks one of the wrapped loaders at random
    (weighted by ``ratios``) and returns that loader's next item.

    Args:
        loaders (List[Loader]): List of Iterator loaders. Every entry must
            already be an iterator (i.e. expose ``__next__``).
        ratios (List[float]): List of ratios to sample from each loader.
            If None, all loaders are sampled uniformly. Given ratios are
            normalized so that they sum to 1.
    """

    def __init__(self, loaders, ratios=None):
        # assert all loaders has __next__ method
        for loader in loaders:
            assert hasattr(
                loader, "__next__"
            ), "Loader {} has no __next__ method.".format(loader)

        if ratios is None:
            # Equal (unnormalized) weights; random.choices only needs
            # relative weights, so no normalization is required here.
            ratios = [1.0] * len(loaders)
        else:
            assert len(ratios) == len(loaders)
            # Compute the total once instead of once per element.
            total = sum(ratios)
            ratios = [float(ratio) / total for ratio in ratios]

        self.loaders = loaders
        self.ratios = ratios

    def __next__(self):
        # random sample from each loader by ratio
        loader = random.choices(self.loaders, weights=self.ratios, k=1)[0]
        # StopIteration from an exhausted loader propagates to the caller.
        return next(loader)

    def __iter__(self):
        return self


class PrefetchLoader(object):
    """
    Modified from https://github.com/ChenRocks/UNITER.

    overlap compute and cuda data transfer
    (copied and then modified from nvidia apex)

    The wrapper always keeps one batch fetched ahead of the consumer in
    ``self.batch``. The host-to-device copy on the side stream is currently
    disabled (see the commented-out code in ``preload``), so batches are
    yielded exactly as the wrapped loader produced them.

    Constructing an instance creates a ``torch.cuda.Stream``.

    Args:
        loader: the iterable (typically a DataLoader) to wrap. Attributes not
            found on this wrapper are looked up on ``loader``.
    """

    def __init__(self, loader):
        self.loader = loader
        self.stream = torch.cuda.Stream()

    def __iter__(self):
        loader_it = iter(self.loader)
        self.preload(loader_it)
        batch = self.next(loader_it)
        # ``None`` is the end-of-data sentinel set by ``preload``; a loader
        # that itself yields ``None`` therefore also ends the iteration.
        while batch is not None:
            # Batches (including ``(task, batch)`` tuples) are passed through
            # unchanged; there is nothing to unpack and re-pack here.
            yield batch
            batch = self.next(loader_it)

    def __len__(self):
        return len(self.loader)

    def preload(self, it):
        """Fetch the next batch from ``it`` into ``self.batch`` (None when exhausted)."""
        try:
            self.batch = next(it)
        except StopIteration:
            self.batch = None
            return
        # if record_stream() doesn't work, another option is to make sure
        # device inputs are created on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input,
        #                                        device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target,
        #                                         device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use
        # by the main stream at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        # with torch.cuda.stream(self.stream):
        #     self.batch = move_to_cuda(self.batch)
        # more code for the alternative if record_stream() doesn't work:
        # copy_ will record the use of the pinned source tensor in this
        # side stream.
        # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
        # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
        # self.next_input = self.next_input_gpu
        # self.next_target = self.next_target_gpu

    def next(self, it):
        """Return the prefetched batch (or None) and start fetching the following one."""
        # Make the current stream wait for any work queued on the side stream.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is not None:
            record_cuda_stream(batch)
        self.preload(it)
        return batch

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``loader`` itself
        # is missing (bare instance during copy/unpickle), reading
        # ``self.loader`` below would re-enter this method forever, so fail
        # with a regular AttributeError instead.
        if name == "loader":
            raise AttributeError(name)
        # Delegate everything else to the wrapped loader.
        return getattr(self.loader, name)


def record_cuda_stream(batch):
    """
    Recursively call ``record_stream`` with the current CUDA stream on every
    CUDA tensor found in ``batch``.

    Lists, tuples and dict values are traversed; any other object is ignored.
    Non-CUDA tensors are skipped: the device copy in
    ``PrefetchLoader.preload`` is disabled, so batches may still live on the
    host, where stream recording is not applicable.
    """
    if isinstance(batch, torch.Tensor):
        if batch.is_cuda:
            batch.record_stream(torch.cuda.current_stream())
    elif isinstance(batch, (list, tuple)):
        for t in batch:
            record_cuda_stream(t)
    elif isinstance(batch, dict):
        for t in batch.values():
            record_cuda_stream(t)


class IterLoader:
    """
    A wrapper to convert DataLoader as an infinite iterator.

    Modified from:
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py

    Args:
        dataloader: the loader to cycle through. It is restarted with
            ``iter(dataloader)`` every time it is exhausted.
        use_distributed: when True and the loader's ``sampler`` has a
            ``set_epoch`` method, it is called with the new epoch index
            before each restart.
        epoch_transition_delay: seconds to sleep before restarting the
            loader. Values <= 0 disable the pause.
    """

    def __init__(
        self,
        dataloader: DataLoader,
        use_distributed: bool = False,
        epoch_transition_delay: float = 2.0,
    ):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._use_distributed = use_distributed
        self._epoch_transition_delay = epoch_transition_delay
        self._epoch = 0

    @property
    def epoch(self) -> int:
        """Number of times the wrapped loader has been exhausted so far."""
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            if self._use_distributed:
                # Not every wrapped iterable exposes a sampler; tolerate that.
                sampler = getattr(self._dataloader, "sampler", None)
                if hasattr(sampler, "set_epoch"):
                    sampler.set_epoch(self._epoch)
            if self._epoch_transition_delay > 0:
                # Prevent possible deadlock during epoch transition
                time.sleep(self._epoch_transition_delay)
            self.iter_loader = iter(self._dataloader)
            # If the loader is empty this raises StopIteration to the caller.
            data = next(self.iter_loader)

        return data

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)