pengdadaaa's picture
Upload 741 files
786f6a6 verified
""" Dataset reader for HF IterableDataset
"""
import math
import os
from itertools import repeat, chain
from typing import Optional
import torch
import torch.distributed as dist
from PIL import Image
try:
import datasets
from datasets.distributed import split_dataset_by_node
from datasets.splits import SplitInfo
except ImportError as e:
print("Please install Hugging Face datasets package `pip install datasets`.")
raise e
from .class_map import load_class_map
from .reader import Reader
from .shared_count import SharedCount
# Fallback shuffle-buffer size, used when a reader is built without an explicit
# `shuffle_size`. Overridable through the HFIDS_SHUFFLE_SIZE environment variable.
SHUFFLE_SIZE = int(os.getenv('HFIDS_SHUFFLE_SIZE', 4096))
class ReaderHfids(Reader):
    """Reader yielding `(input, target)` pairs from a Hugging Face `datasets.IterableDataset`.

    The underlying iterable dataset is not created in `__init__`; it is built on
    first use by `_lazy_init()` so that each dataloader worker process sets up
    its own instance.
    """

    # Sample fields probed, in priority order, by `filenames()`.
    _NAME_KEYS = ('file_name', 'filename', 'id', 'image_id')

    def __init__(
            self,
            name: str,
            root: Optional[str] = None,
            split: str = 'train',
            is_training: bool = False,
            batch_size: int = 1,
            download: bool = False,
            repeats: int = 0,
            seed: int = 42,
            class_map: Optional[dict] = None,
            input_key: str = 'image',
            input_img_mode: str = 'RGB',
            target_key: str = 'label',
            target_img_mode: str = '',
            shuffle_size: Optional[int] = None,
            num_samples: Optional[int] = None,
    ):
        """Create the dataset builder and record the reader configuration.

        Args:
            name: Dataset name handed to `datasets.load_dataset_builder`.
            root: Cache directory handed to the builder as `cache_dir`.
            split: Name of the dataset split to read.
            is_training: Enables shuffling, endless repetition of the dataset and
                rounding of the per-worker sample count to full batches.
            batch_size: Batch size used for that rounding in training mode.
            download: If True, download and prepare the dataset up front and read
                the prepared copy; otherwise the dataset is streamed.
            repeats: Multiplier applied to the sample count (values < 1 count as 1).
            seed: Shuffle seed, identical across all workers / distributed instances.
            class_map: Passed to `load_class_map`; when given, targets are remapped
                through the resulting mapping.
            input_key: Sample field holding the input image.
            input_img_mode: PIL mode inputs are converted to ('' disables conversion).
            target_key: Sample field holding the target.
            target_img_mode: If non-empty, targets must be PIL images and are
                converted to this mode.
            shuffle_size: Shuffle buffer size; falls back to `SHUFFLE_SIZE`.
            num_samples: Explicit number of samples in the split; required when
                the builder's split info does not provide it.

        Raises:
            ValueError: If the dataset length is neither passed nor available
                from the builder's split info.
        """
        super().__init__()
        self.root = root
        self.split = split
        self.is_training = is_training
        self.batch_size = batch_size
        self.download = download
        self.repeats = repeats
        self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
        self.shuffle_size = shuffle_size or SHUFFLE_SIZE

        self.input_key = input_key
        self.input_img_mode = input_img_mode
        self.target_key = target_key
        self.target_img_mode = target_img_mode

        self.builder = datasets.load_dataset_builder(name, cache_dir=root)
        if download:
            self.builder.download_and_prepare()

        # Look up the split metadata (if any) to learn the dataset length.
        split_info: Optional[SplitInfo] = None
        splits = self.builder.info.splits
        if splits and split in splits and isinstance(splits[split], SplitInfo):
            split_info = splits[split]

        if num_samples:
            self.num_samples = num_samples
        elif split_info and split_info.num_examples:
            self.num_samples = split_info.num_examples
        else:
            raise ValueError(
                "Dataset length is unknown, please pass `num_samples` explicitely. "
                "The number of steps needs to be known in advance for the learning rate scheduler."
            )

        self.remap_class = False
        if class_map:
            self.class_to_idx = load_class_map(class_map)
            self.remap_class = True
        else:
            self.class_to_idx = {}

        # Distributed world state
        self.dist_rank = 0
        self.dist_num_replicas = 1
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            self.dist_rank = dist.get_rank()
            self.dist_num_replicas = dist.get_world_size()

        # Attributes that are updated in _lazy_init
        self.worker_info = None
        self.worker_id = 0
        self.num_workers = 1
        self.global_worker_id = 0
        self.global_num_workers = 1

        # Initialized lazily on each dataloader worker process
        self.ds: Optional[datasets.IterableDataset] = None
        self.epoch = SharedCount()

    def set_epoch(self, count):
        """Store the current epoch; it is forwarded to the dataset at the start of `__iter__`."""
        # to update the shuffling effective_seed = seed + epoch
        self.epoch.value = count

    def set_loader_cfg(
            self,
            num_workers: Optional[int] = None,
    ):
        """Record the dataloader worker count; a no-op once the dataset has been initialized."""
        if self.ds is not None:
            return
        if num_workers is not None:
            self.num_workers = num_workers
            self.global_num_workers = self.dist_num_replicas * self.num_workers

    def _lazy_init(self):
        """ Lazily initialize worker (in worker processes)
        """
        if self.worker_info is None:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                self.worker_info = worker_info
                self.worker_id = worker_info.id
                self.num_workers = worker_info.num_workers
            self.global_num_workers = self.dist_num_replicas * self.num_workers
            self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id

        if self.download:
            dataset = self.builder.as_dataset(split=self.split)
            # to distribute evenly to workers
            ds = dataset.to_iterable_dataset(num_shards=self.global_num_workers)
        else:
            # in this case the number of shard is determined by the number of remote files
            ds = self.builder.as_streaming_dataset(split=self.split)

        if self.is_training:
            # will shuffle the list of shards and use a shuffle buffer
            ds = ds.shuffle(seed=self.common_seed, buffer_size=self.shuffle_size)

        # Distributed:
        # If the dataset has a number of shards that is a multiple of `dist_num_replicas`
        # (i.e. `ds.n_shards % dist_num_replicas == 0`), the shards are evenly assigned across the nodes.
        # If that is not the case for dataset streaming, each node keeps 1 example out of
        # `dist_num_replicas`, skipping the other examples.
        # Workers:
        # In a node, datasets.IterableDataset assigns the shards assigned to the node as evenly
        # as possible to workers.
        self.ds = split_dataset_by_node(ds, rank=self.dist_rank, world_size=self.dist_num_replicas)

    def _num_samples_per_worker(self):
        """Return the number of samples a single worker is expected to produce."""
        num_worker_samples = \
            max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
        if self.is_training or self.dist_num_replicas > 1:
            num_worker_samples = math.ceil(num_worker_samples)
        if self.is_training and self.batch_size is not None:
            # round up to a whole number of batches
            num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
        return int(num_worker_samples)

    def _cycle_dataset(self):
        """Yield samples from `self.ds` pass after pass.

        Stops if a complete pass yields nothing, so an empty dataset cannot
        turn the training iterator into a busy loop that never returns.
        """
        while True:
            produced = False
            for sample in self.ds:
                produced = True
                yield sample
            if not produced:
                return

    def __iter__(self):
        """Iterate over `(input, target)` pairs.

        In training mode the dataset is repeated and iteration stops after
        `_num_samples_per_worker()` samples; otherwise a single pass is made.

        Raises:
            AssertionError: If `target_img_mode` is set but a target is not a PIL image.
        """
        if self.ds is None:
            self._lazy_init()
        self.ds.set_epoch(self.epoch.value)

        target_sample_count = self._num_samples_per_worker()
        sample_count = 0

        ds_iter = self._cycle_dataset() if self.is_training else iter(self.ds)
        for sample in ds_iter:
            input_data: Image.Image = sample[self.input_key]
            if self.input_img_mode and input_data.mode != self.input_img_mode:
                input_data = input_data.convert(self.input_img_mode)

            target_data = sample[self.target_key]
            if self.target_img_mode:
                # Explicit raise rather than `assert`, so the check survives `python -O`.
                if not isinstance(target_data, Image.Image):
                    raise AssertionError("target_img_mode is specified but target is not an image")
                if target_data.mode != self.target_img_mode:
                    target_data = target_data.convert(self.target_img_mode)
            elif self.remap_class:
                target_data = self.class_to_idx[target_data]

            yield input_data, target_data
            sample_count += 1
            if self.is_training and sample_count >= target_sample_count:
                break

    def __len__(self):
        """Return the per-worker sample count multiplied by the number of workers."""
        num_samples = self._num_samples_per_worker() * self.num_workers
        return num_samples

    def _filename(self, index, basename=False, absolute=False):
        """Unsupported: there is no random access to examples.

        Raises:
            AssertionError: Always.
        """
        # Explicit raise rather than `assert False`, which `python -O` would strip.
        raise AssertionError("Not supported")  # no random access to examples

    def filenames(self, basename=False, absolute=False):
        """ Return all filenames in dataset, overrides base

        The name of each sample is taken from the first present field among
        'file_name', 'filename', 'id' and 'image_id'.

        Raises:
            AssertionError: If a sample has none of those fields.
        """
        if self.ds is None:
            self._lazy_init()
        names = []
        for sample in self.ds:
            for key in self._NAME_KEYS:
                if key in sample:
                    names.append(sample[key])
                    break
            else:
                # Explicit raise: under `python -O` an `assert False` here would be
                # removed and an unbound or stale name would be appended instead.
                raise AssertionError("No supported name field present")
        return names