""" Dataset reader that wraps Hugging Face datasets

Hacked together by / Copyright 2022 Ross Wightman
"""
import io
import math
from typing import Optional

import torch
import torch.distributed as dist
from PIL import Image

try:
    import datasets
except ImportError as e:
    print("Please install Hugging Face datasets package `pip install datasets`.")
    raise e
from .class_map import load_class_map
from .reader import Reader


def get_class_labels(info, label_key='label'):
    """Build a class-name -> class-index mapping from a dataset info object.

    Args:
        info: Object exposing a ``features`` mapping (here, ``dataset.info``).
        label_key: Name of the feature that holds the class label.

    Returns:
        A dict mapping each entry of the label feature's ``names`` to the
        result of its ``str2int``; an empty dict when ``label_key`` is not
        one of ``info.features``.
    """
    # Check the requested key, not a hard-coded 'label', so a custom label_key
    # neither raises KeyError below nor is wrongly reported as missing.
    if label_key not in info.features:
        return {}
    class_label = info.features[label_key]
    class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
    return class_to_idx


class ReaderHfds(Reader):
    """Reader that serves (image file object, label) pairs from a Hugging Face dataset."""

    def __init__(
            self,
            name: str,
            root: Optional[str] = None,
            split: str = 'train',
            class_map: dict = None,
            image_key: str = 'image',
            target_key: str = 'label',
            download: bool = False,
    ):
        """Load the dataset split and set up the label mapping.

        Args:
            name: Passed as the first (path) argument to ``datasets.load_dataset``.
            root: Passed to ``datasets.load_dataset`` as ``cache_dir``.
            split: Split to load; also used to index ``dataset.info.splits``.
            class_map: When truthy, handed to ``load_class_map`` and the result
                is used to remap every label returned by ``__getitem__``.
            image_key: Name of the dataset column holding the image.
            target_key: Name of the dataset column holding the label.
            download: Accepted for signature compatibility; not used here.
        """
        super().__init__()
        self.root = root
        self.split = split
        self.dataset = datasets.load_dataset(
            name,  # 'name' maps to path arg in hf datasets
            split=split,
            cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
        )
        # leave decode for caller, plus we want easy access to original path names...
        self.dataset = self.dataset.cast_column(image_key, datasets.Image(decode=False))

        self.image_key = image_key
        self.label_key = target_key
        self.remap_class = False
        if class_map:
            self.class_to_idx = load_class_map(class_map)
            self.remap_class = True
        else:
            self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
        self.split_info = self.dataset.info.splits[split]
        self.num_samples = self.split_info.num_examples

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        ``image`` is a binary file-like object: an ``io.BytesIO`` over the
        sample's in-memory bytes when present, otherwise the file at the
        sample's path opened in ``'rb'`` mode. The handle is returned open and
        is not closed here. ``label`` is looked up through ``class_to_idx``
        when a class map was supplied at construction.
        """
        item = self.dataset[index]
        image = item[self.image_key]
        if 'bytes' in image and image['bytes']:
            image = io.BytesIO(image['bytes'])
        else:
            assert 'path' in image and image['path'], 'image sample has neither bytes nor a path'
            image = open(image['path'], 'rb')
        label = item[self.label_key]
        if self.remap_class:
            label = self.class_to_idx[label]
        return image, label

    def __len__(self):
        """Return the number of rows in the loaded dataset split."""
        return len(self.dataset)

    def _filename(self, index, basename=False, absolute=False):
        """Return the ``'path'`` entry of the image at ``index``.

        ``basename`` and ``absolute`` are accepted but not applied; the path is
        returned exactly as stored in the dataset.
        """
        item = self.dataset[index]
        return item[self.image_key]['path']