import importlib

from typing import Type

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.dataloader import DataLoader

from .datasets.ab_dataset import ABDataset
from .datasets import *
from .datasets.registery import static_dataset_registery
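# NOTE: the wildcard import of .datasets is presumably what populates
# static_dataset_registery, assuming each concrete dataset class registers itself
# on import (a common registry pattern; not verified here).
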
def get_dataset(dataset_name, root_dir, split, transform=None, ignore_classes=[], idx_map=None) -> ABDataset:
    dataset_cls = static_dataset_registery[dataset_name][0]
    dataset = dataset_cls(root_dir, split, transform, ignore_classes, idx_map)

    return dataset

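# Example usage (a sketch; 'cifar10' and the paths are hypothetical and depend on
# which datasets the surrounding project actually registers):
#
#   train_set = get_dataset('cifar10', '/data/cifar10', 'train', transform=None)
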
def get_num_limited_dataset(dataset: ABDataset, num_samples: int, discard_label=True):
    # a batch size of at least 1 avoids a DataLoader error when num_samples < 2
    dataloader = iter(DataLoader(dataset, max(num_samples // 2, 1), shuffle=True))
    x, y = [], []
    cur_num_samples = 0

    while cur_num_samples < num_samples:
        try:
            batch = next(dataloader)
        except StopIteration:
            # the wrapped dataset ran out before num_samples could be collected
            raise ValueError(f'dataset has only {cur_num_samples} samples, '
                             f'but {num_samples} were requested')
        cur_x, cur_y = batch[0], batch[1]

        x += [cur_x]
        y += [cur_y]
        cur_num_samples += cur_x.size(0)

    x, y = torch.cat(x)[:num_samples], torch.cat(y)[:num_samples]
    if discard_label:
        new_dataset = TensorDataset(x)
    else:
        new_dataset = TensorDataset(x, y)

    dataset.dataset = new_dataset

    return dataset
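
# Example usage (a sketch; builds on the hypothetical 'cifar10' example above and
# assumes the registered dataset holds at least 100 samples):
#
#   full_set = get_dataset('cifar10', '/data/cifar10', 'train')
#   small_set = get_num_limited_dataset(full_set, num_samples=100, discard_label=True)
#   # small_set.dataset is now a TensorDataset containing exactly 100 unlabeled inputs.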