Spaces:
Runtime error
Runtime error
File size: 617 Bytes
b6d5990 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch.utils.data
class DataProvider:
    """Iterable wrapper that pairs a dataset with a ``torch.utils.data.DataLoader``.

    Iterating over an instance yields whatever the underlying loader yields,
    one batch at a time. ``len()`` reports the size of the wrapped dataset,
    which is generally not the same as the number of batches produced.

    Attributes:
        dataset: The dataset object passed to the constructor.
        dataloader: The ``DataLoader`` built over ``dataset``.
    """

    def __init__(self, cfg, dataset, batch_size=None, shuffle: bool = True) -> None:
        """Build the loader for ``dataset``.

        Args:
            cfg: Configuration object. ``cfg.WORKERS`` is always read (and
                converted with ``int``); ``cfg.BATCH_SIZE`` is read only when
                ``batch_size`` is ``None``.
            dataset: Dataset handed unchanged to ``DataLoader``.
            batch_size: Batch size to use; ``None`` falls back to
                ``cfg.BATCH_SIZE``.
            shuffle: Forwarded to ``DataLoader`` as its ``shuffle`` argument.
        """
        super().__init__()
        self.dataset = dataset
        if batch_size is None:
            batch_size = cfg.BATCH_SIZE
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=int(cfg.WORKERS),
            # Keep the final, possibly smaller, batch.
            drop_last=False,
        )

    def __len__(self) -> int:
        """Return ``len(self.dataset)`` (dataset items, not batches)."""
        return len(self.dataset)

    def __iter__(self):
        """Yield each batch produced by the underlying loader, in order."""
        # A fresh pass over the loader starts every time iteration begins.
        yield from self.dataloader