# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
from .utils.transforms import *
from .base.batched_sampler import BatchedRandomSampler  # noqa: F401
from .co3d import Co3d  # noqa: F401


def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
    """Build a PyTorch DataLoader for `dataset`, preferring the dataset's own
    sampler (via `make_sampler`) when available and falling back to standard
    (distributed) samplers otherwise."""
    import torch
    from croco.utils.misc import get_world_size, get_rank

    # pytorch dataset: a string is evaluated into a dataset instance, e.g. "Co3d(...)"
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    try:
        # some datasets provide their own sampler (see BatchedRandomSampler)
        sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
                                       rank=rank, drop_last=drop_last)
    except (AttributeError, NotImplementedError):
        # not available for this dataset: fall back to standard samplers
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.DistributedSampler(
                dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
            )
        elif shuffle:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )

    return data_loader
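

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it constructs a
    # Co3d split and wraps it with get_data_loader. The ROOT path and the
    # exact constructor arguments are illustrative assumptions; adapt them to
    # your preprocessed data layout.
    dataset = Co3d(split='train', ROOT='data/co3d_processed', resolution=224)
    loader = get_data_loader(dataset, batch_size=4, num_workers=2, shuffle=True)
    for batch in loader:
        # each batch is expected to hold paired views (dicts of images, intrinsics, poses, ...)
        break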