""" Dataset reader for HF IterableDataset
"""
import math
import os
from itertools import repeat, chain
from typing import Optional

import torch
import torch.distributed as dist
from PIL import Image

try:
    import datasets
    from datasets.distributed import split_dataset_by_node
    from datasets.splits import SplitInfo
except ImportError as e:
    print("Please install Hugging Face datasets package `pip install datasets`.")
    raise e

from .class_map import load_class_map
from .reader import Reader
from .shared_count import SharedCount
|
|
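# Default shuffle buffer size, can be overridden via the HFIDS_SHUFFLE_SIZE env var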
SHUFFLE_SIZE = int(os.environ.get('HFIDS_SHUFFLE_SIZE', 4096)) |
|
|
class ReaderHfids(Reader): |
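    """Dataset reader for a Hugging Face `IterableDataset`.

    Streams samples instead of loading them into memory, splits the stream across
    distributed ranks and DataLoader workers, and yields (image, target) pairs.
    """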
|
    def __init__(
            self,
            name: str,
            root: Optional[str] = None,
            split: str = 'train',
            is_training: bool = False,
            batch_size: int = 1,
            download: bool = False,
            repeats: int = 0,
            seed: int = 42,
            class_map: Optional[dict] = None,
            input_key: str = 'image',
            input_img_mode: str = 'RGB',
            target_key: str = 'label',
            target_img_mode: str = '',
            shuffle_size: Optional[int] = None,
            num_samples: Optional[int] = None,
    ):
        super().__init__()
        self.root = root
        self.split = split
        self.is_training = is_training
        self.batch_size = batch_size
        self.download = download
        self.repeats = repeats
        self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
        self.shuffle_size = shuffle_size or SHUFFLE_SIZE
|
        self.input_key = input_key
        self.input_img_mode = input_img_mode
        self.target_key = target_key
        self.target_img_mode = target_img_mode
|
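        # Instantiate the dataset builder; metadata (info, splits) is available without downloading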
        self.builder = datasets.load_dataset_builder(name, cache_dir=root)
        if download:
            self.builder.download_and_prepare()
|
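        # Look up split metadata so the dataset length can be inferred when not passed in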
        split_info: Optional[SplitInfo] = None
        if self.builder.info.splits and split in self.builder.info.splits:
            if isinstance(self.builder.info.splits[split], SplitInfo):
                split_info = self.builder.info.splits[split]
|
        if num_samples:
            self.num_samples = num_samples
        elif split_info and split_info.num_examples:
            self.num_samples = split_info.num_examples
        else:
            raise ValueError(
                "Dataset length is unknown, please pass `num_samples` explicitly. "
                "The number of steps needs to be known in advance for the learning rate scheduler."
            )
|
        self.remap_class = False
        if class_map:
            self.class_to_idx = load_class_map(class_map)
            self.remap_class = True
        else:
            self.class_to_idx = {}
|
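        # Capture distributed rank / world size if a process group is initialized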
        self.dist_rank = 0
        self.dist_num_replicas = 1
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            self.dist_rank = dist.get_rank()
            self.dist_num_replicas = dist.get_world_size()
|
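        # Per-worker state, filled in by _lazy_init() inside each DataLoader worker process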
        self.worker_info = None
        self.worker_id = 0
        self.num_workers = 1
        self.global_worker_id = 0
        self.global_num_workers = 1
|
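        # The iterable dataset is created lazily per worker; the epoch count is shared
        # across workers via SharedCount so shuffling can be re-seeded each epoch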
        self.ds: Optional[datasets.IterableDataset] = None
        self.epoch = SharedCount()
|
    def set_epoch(self, count):
        # update the shared epoch count, used to re-seed dataset shuffling each epoch
        self.epoch.value = count
|
    def set_loader_cfg(
            self,
            num_workers: Optional[int] = None,
    ):
        if self.ds is not None:
            return
        if num_workers is not None:
            self.num_workers = num_workers
            self.global_num_workers = self.dist_num_replicas * self.num_workers
|
    def _lazy_init(self):
        """ Lazily initialize worker (in worker processes)
        """
        if self.worker_info is None:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                self.worker_info = worker_info
                self.worker_id = worker_info.id
                self.num_workers = worker_info.num_workers
                self.global_num_workers = self.dist_num_replicas * self.num_workers
                # flat worker index across all distributed ranks and their loader workers
                self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id
|
        if self.download:
            dataset = self.builder.as_dataset(split=self.split)
            # shard so each global worker can be assigned its own slice
            ds = dataset.to_iterable_dataset(num_shards=self.global_num_workers)
        else:
            # streaming; the shard count is determined by the number of remote files
            ds = self.builder.as_streaming_dataset(split=self.split)

        if self.is_training:
            # shuffles the shard order and keeps a buffer of examples to sample from
            ds = ds.shuffle(seed=self.common_seed, buffer_size=self.shuffle_size)
|
        # Distributed: when the shard count divides evenly by the world size, whole shards
        # are assigned across nodes; otherwise each node keeps 1 example out of every
        # world_size examples and skips the others.
        # Workers: within a node, the IterableDataset assigns the node's shards as evenly
        # as possible across DataLoader workers.
        self.ds = split_dataset_by_node(ds, rank=self.dist_rank, world_size=self.dist_num_replicas)
|
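    # _num_samples_per_worker scales the dataset length to one worker's share, rounding up
    # so ranks / workers stay in step and, in training, so only full batches are yielded.
    # E.g. (illustrative numbers) num_samples=10_000 across 2 ranks x 4 workers gives
    # ceil(10000 / 8) = 1250; with batch_size=32 that rounds to ceil(1250 / 32) * 32 = 1280.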
    def _num_samples_per_worker(self):
        num_worker_samples = \
            max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
        if self.is_training or self.dist_num_replicas > 1:
            num_worker_samples = math.ceil(num_worker_samples)
        if self.is_training and self.batch_size is not None:
            # round up so each worker yields only full batches during training
            num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
        return int(num_worker_samples)
|
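    # Yields (input, target) pairs; in training the stream repeats indefinitely and is cut
    # off at target_sample_count so every worker contributes the same number of samples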
    def __iter__(self):
        if self.ds is None:
            self._lazy_init()
        self.ds.set_epoch(self.epoch.value)
|
        target_sample_count = self._num_samples_per_worker()
        sample_count = 0
|
        if self.is_training:
            # repeat the dataset indefinitely; the loop below stops at target_sample_count
            ds_iter = chain.from_iterable(repeat(self.ds))
        else:
            ds_iter = iter(self.ds)
        for sample in ds_iter:
            input_data: Image.Image = sample[self.input_key]
            if self.input_img_mode and input_data.mode != self.input_img_mode:
                input_data = input_data.convert(self.input_img_mode)
            target_data = sample[self.target_key]
            if self.target_img_mode:
                assert isinstance(target_data, Image.Image), "target_img_mode is specified but target is not an image"
                if target_data.mode != self.target_img_mode:
                    target_data = target_data.convert(self.target_img_mode)
            elif self.remap_class:
                target_data = self.class_to_idx[target_data]
            yield input_data, target_data
            sample_count += 1
            if self.is_training and sample_count >= target_sample_count:
                break
|
    def __len__(self):
        num_samples = self._num_samples_per_worker() * self.num_workers
        return num_samples
|
    def _filename(self, index, basename=False, absolute=False):
        assert False, "Not supported"
|
    def filenames(self, basename=False, absolute=False):
        """ Return all filenames in dataset, overrides base"""
        if self.ds is None:
            self._lazy_init()
        names = []
        for sample in self.ds:
            if 'file_name' in sample:
                name = sample['file_name']
            elif 'filename' in sample:
                name = sample['filename']
            elif 'id' in sample:
                name = sample['id']
            elif 'image_id' in sample:
                name = sample['image_id']
            else:
                assert False, "No supported name field present"
            names.append(name)
        return names
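

if __name__ == '__main__':
    # Minimal usage sketch, not part of the reader API: 'mnist' is an illustrative public
    # Hub dataset whose samples expose 'image' and 'label' fields; streaming it requires
    # network access. Any classification dataset with matching keys works the same way.
    reader = ReaderHfids(name='mnist', split='train', input_img_mode='L')
    print(f'num samples: {len(reader)}')
    for img, target in reader:
        print(img.size, target)
        break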