File size: 1,281 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import importlib
from typing import Type
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.dataloader import DataLoader
from .datasets.ab_dataset import ABDataset
from .datasets import * # import all datasets
from .datasets.registery import static_dataset_registery
def get_dataset(dataset_name, root_dir, split, transform=None, ignore_classes=None, idx_map=None) -> ABDataset:
    """Instantiate the dataset registered under ``dataset_name``.

    The dataset class is taken from the first element of the
    ``static_dataset_registery`` entry for ``dataset_name`` and is called
    positionally as ``cls(root_dir, split, transform, ignore_classes, idx_map)``.

    Args:
        dataset_name: Key into ``static_dataset_registery``.
        root_dir: Forwarded to the dataset constructor.
        split: Forwarded to the dataset constructor.
        transform: Forwarded to the dataset constructor (``None`` by default).
        ignore_classes: Forwarded to the dataset constructor. ``None`` (the
            default) is replaced by a new empty list on every call.
        idx_map: Forwarded to the dataset constructor (``None`` by default).

    Returns:
        The newly constructed dataset instance.

    Raises:
        Whatever the registry lookup raises for an unknown ``dataset_name``
        (``KeyError`` if the registry is a plain dict -- not visible here).
    """
    # A literal ``[]`` default would be one list object shared by every call
    # (and by every dataset that stores it); build a fresh list instead.
    if ignore_classes is None:
        ignore_classes = []

    dataset_cls = static_dataset_registery[dataset_name][0]
    return dataset_cls(root_dir, split, transform, ignore_classes, idx_map)
def get_num_limited_dataset(dataset: ABDataset, num_samples: int, discard_label=True):
    """Swap ``dataset.dataset`` for a random subset of at most ``num_samples`` items.

    Batches are drawn from a shuffling ``DataLoader`` over ``dataset`` until
    ``num_samples`` items have been collected; element 0 of each batch is kept
    as the samples and element 1 as the labels. The collected tensors are
    wrapped in a ``TensorDataset`` which is assigned to ``dataset.dataset``.

    NOTE: ``dataset`` is modified in place and the same object is returned.

    Args:
        dataset: Source dataset; must be usable with ``DataLoader`` and yield
            batches that are indexable with tensors at positions 0 and 1.
        num_samples: Number of samples to keep; must be positive. If the
            dataset holds fewer samples, all of them are kept.
        discard_label: If true, the new ``TensorDataset`` holds only the
            samples; otherwise it holds ``(samples, labels)``.

    Returns:
        ``dataset`` itself, with its ``dataset`` attribute replaced.

    Raises:
        ValueError: If ``num_samples`` is not positive, or if the loader
            yields no batch at all.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be a positive integer, got {num_samples}")

    # ``num_samples // 2`` is 0 for ``num_samples == 1``, which DataLoader
    # rejects; clamp so a single-sample request works.
    batch_size = max(1, num_samples // 2)
    loader = DataLoader(dataset, batch_size, shuffle=True)

    xs, ys = [], []
    collected = 0
    # A ``for`` loop ends cleanly when the dataset runs out of samples,
    # instead of leaking StopIteration from a bare ``next()``.
    for batch in loader:
        batch_x, batch_y = batch[0], batch[1]
        xs.append(batch_x)
        ys.append(batch_y)
        collected += batch_x.size(0)
        if collected >= num_samples:
            break

    if not xs:
        raise ValueError("dataset yielded no samples")

    # The last batch may overshoot the requested count; trim both tensors
    # identically so samples and labels stay aligned.
    x = torch.cat(xs)[:num_samples]
    y = torch.cat(ys)[:num_samples]

    new_dataset = TensorDataset(x) if discard_label else TensorDataset(x, y)

    dataset.dataset = new_dataset
    return dataset
|