File size: 617 Bytes
b6d5990
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch.utils.data

class DataProvider:
    """Thin iterable wrapper around a ``torch.utils.data.DataLoader``.

    Builds a DataLoader over ``dataset`` using batch size and worker count
    taken from ``cfg`` (unless overridden) and re-exposes it as a plain
    iterable of batches.
    """

    def __init__(self, cfg, dataset, batch_size=None, shuffle=True, *,
                 drop_last=False):
        """Create the provider and its underlying DataLoader.

        Args:
            cfg: Config object; ``cfg.BATCH_SIZE`` is used when ``batch_size``
                is None, and ``int(cfg.WORKERS)`` sets ``num_workers``.
            dataset: Dataset passed straight to ``DataLoader``.
            batch_size: Batch size override; falls back to ``cfg.BATCH_SIZE``.
            shuffle: Whether the DataLoader reshuffles each pass.
            drop_last: Drop the final incomplete batch (default False keeps
                the previous behaviour).
        """
        super().__init__()
        self.dataset = dataset
        if batch_size is None:
            batch_size = cfg.BATCH_SIZE
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            # WORKERS may come from a config parser as a string; coerce it.
            num_workers=int(cfg.WORKERS),
            drop_last=drop_last,
        )

    def __len__(self):
        """Return the number of samples in the dataset (not the batch count)."""
        return len(self.dataset)

    def __iter__(self):
        """Yield batches from a fresh pass over the underlying DataLoader."""
        yield from self.dataloader