File size: 1,281 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import importlib
from typing import Type
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.dataloader import DataLoader

from .datasets.ab_dataset import ABDataset

from .datasets import * # import all datasets
from .datasets.registery import static_dataset_registery


def get_dataset(dataset_name, root_dir, split, transform=None, ignore_classes=None, idx_map=None) -> ABDataset:
    """Instantiate a registered dataset by name.

    Args:
        dataset_name: key into ``static_dataset_registery``. The first element
            of the registered entry is used as the dataset class.
        root_dir: forwarded to the dataset class constructor.
        split: forwarded to the dataset class constructor.
        transform: forwarded to the dataset class constructor.
        ignore_classes: forwarded to the dataset class constructor. ``None``
            (the default) is replaced by a fresh empty list.
        idx_map: forwarded to the dataset class constructor.

    All constructor arguments are passed positionally in the order
    ``(root_dir, split, transform, ignore_classes, idx_map)``.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: if ``dataset_name`` is not registered.
    """
    if ignore_classes is None:
        # A fresh list per call. The previous ``[]`` default was a single list
        # shared by every call, so any mutation by a dataset class would
        # silently leak into later calls.
        ignore_classes = []

    try:
        registry_entry = static_dataset_registery[dataset_name]
    except KeyError:
        raise KeyError(f"unknown dataset: {dataset_name!r}") from None

    dataset_cls = registry_entry[0]
    return dataset_cls(root_dir, split, transform, ignore_classes, idx_map)


def get_num_limited_dataset(dataset: ABDataset, num_samples: int, discard_label=True):
    """Replace ``dataset.dataset`` with a random subset of at most ``num_samples`` samples.

    Samples are drawn by iterating a shuffled ``DataLoader`` over ``dataset``.
    The collected batches are concatenated, truncated to ``num_samples`` and
    wrapped in a ``TensorDataset``, which is assigned to ``dataset.dataset``.
    If ``dataset`` yields fewer than ``num_samples`` samples, all of them are
    kept.

    Args:
        dataset: dataset whose collated batches expose inputs at ``batch[0]``
            and labels at ``batch[1]``. Its ``dataset`` attribute is
            overwritten. Presumably ABDataset reads its items through this
            attribute; confirm against ABDataset.
        num_samples: maximum number of samples to keep. Must be positive.
        discard_label: if True, the new ``TensorDataset`` holds only the
            inputs; otherwise it holds ``(inputs, labels)``.

    Returns:
        The same ``dataset`` object, mutated in place.

    Raises:
        ValueError: if ``num_samples`` is not positive, or if ``dataset``
            yields no samples.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    # Keep the original half-size batches, but never 0. DataLoader rejects
    # batch_size=0, which previously made num_samples=1 crash.
    batch_size = max(1, num_samples // 2)
    dataloader = DataLoader(dataset, batch_size, shuffle=True)

    xs, ys = [], []
    cur_num_samples = 0
    # A for-loop ends cleanly when the loader is exhausted. The previous
    # next() call leaked StopIteration when the dataset had fewer than
    # num_samples items.
    for batch in dataloader:
        cur_x, cur_y = batch[0], batch[1]
        xs.append(cur_x)
        ys.append(cur_y)
        cur_num_samples += cur_x.size(0)
        if cur_num_samples >= num_samples:
            break

    if not xs:
        raise ValueError("cannot limit an empty dataset")

    # The last batch may overshoot num_samples, so truncate after concatenating.
    x = torch.cat(xs)[:num_samples]
    y = torch.cat(ys)[:num_samples]
    new_dataset = TensorDataset(x) if discard_label else TensorDataset(x, y)

    dataset.dataset = new_dataset
    return dataset