jacklangerman committed
Commit edc7860 • 1 Parent(s): e5605f2
add streaming support

- hoho/hoho.py +30 -5

hoho/hoho.py CHANGED
@@ -3,8 +3,13 @@ import json
 import shutil
 from pathlib import Path
 from typing import Dict
+import warnings
 
 from PIL import ImageFile
+
+from huggingface_hub.utils._headers import build_hf_headers # note: using _headers
+
+
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 LOCAL_DATADIR = None
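The new import takes build_hf_headers from the private huggingface_hub.utils._headers module (the inline comment flags this); the same helper is also re-exported publicly as huggingface_hub.utils.build_hf_headers. It builds the request headers for the Hub, including an authorization entry when a token is available. A minimal sketch of what it provides, assuming you are logged in via `huggingface-cli login`:

    from huggingface_hub.utils import build_hf_headers  # public re-export of the same helper

    headers = build_hf_headers()       # e.g. {'user-agent': '...', 'authorization': 'Bearer hf_...'}
    auth = headers['authorization']    # this value is what the streaming code below passes to curl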
@@ -29,11 +34,11 @@ def setup(local_dir='./data/usm-training-data/data'):
     else:
         LOCAL_DATADIR = local_val_datadir
     print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)")
-
-    # os.system("ls -lahtr")
-    # os.system(f"ls -lahtr {LOCAL_DATADIR}")
 
-
+    if not LOCAL_DATADIR.exists():
+        warnings.warn(f"Data directory {LOCAL_DATADIR} does not exist: creating it...")
+        LOCAL_DATADIR.mkdir(parents=True)
+
     return LOCAL_DATADIR
 
 
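With this change setup() no longer assumes the data directory already exists: a first run on a fresh machine creates it (with a warning) instead of failing later. Usage is unchanged; a sketch, assuming setup() is called through the hoho module as usual:

    from hoho import hoho

    datadir = hoho.setup()   # defaults to './data/usm-training-data/data'
    # if the directory is missing, a UserWarning is emitted and it is created with parents=True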
@@ -286,7 +291,9 @@ def get_params():
 import webdataset as wds
 import numpy as np
 
-def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset'):
+
+SHARD_IDS = {'train': (0, 25), 'val': (25, 26), 'public': (26, 27), 'private': (27, 32)}
+def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset', stream=True):
     if LOCAL_DATADIR is None:
         raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')
 
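SHARD_IDS encodes the dataset's shard layout as half-open index ranges over 32 shards: train gets shards 0-24, val shard 25, the public test set shard 26, and the private test set shards 27-31. Each tuple is unpacked directly into range() when shard names are built below, e.g.:

    SHARD_IDS = {'train': (0, 25), 'val': (25, 26), 'public': (26, 27), 'private': (27, 32)}
    list(range(*SHARD_IDS['val']))    # -> [25], i.e. the single shard hoho_v3_025-of-032.tar.gz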
@@ -295,8 +302,24 @@ def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset
     local_dir = local_dir / split
 
     paths = [str(p) for p in local_dir.rglob('*.tar.gz')]
+    msg = f'no tarfiles found in {local_dir}.'
+    if len(paths) == 0:
+        if stream:
+            if split=='all': split = 'train'
+            warnings.warn('streaming isn\'t using with \'all\': changing `split` to \'train\'')
+            warnings.warn(msg)
+            if split == 'val':
+                names = [f'data/val/inputs/hoho_v3_{i:03}-of-032.tar.gz' for i in range(*SHARD_IDS[split])]
+            elif split == 'train':
+                names = [f'data/train/hoho_v3_{i:03}-of-032.tar.gz' for i in range(*SHARD_IDS[split])]
+
+            auth = build_hf_headers()['authorization']
+            paths = [f"pipe:curl -L -s https://huggingface.co/datasets/usm3d/hoho-train-set/resolve/main/{name} -H 'Authorization: {auth}'" for name in names]
+        else:
+            raise FileNotFoundError(msg)
 
     dataset = wds.WebDataset(paths)
+
     if decode is not None:
         dataset = dataset.decode(decode)
     else:
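The fallback hinges on webdataset's pipe: URL scheme: everything after the pipe: prefix is executed as a shell command and its stdout is read as the tar stream, so each shard is streamed through curl with the Authorization header attached rather than downloaded up front. A standalone sketch of one such URL (the shard name follows the naming pattern above; assumes curl is installed and a valid Hub token is cached):

    import webdataset as wds
    from huggingface_hub.utils import build_hf_headers

    auth = build_hf_headers()['authorization']
    url = ('pipe:curl -L -s '
           'https://huggingface.co/datasets/usm3d/hoho-train-set/resolve/main/data/train/hoho_v3_000-of-032.tar.gz '
           f"-H 'Authorization: {auth}'")
    dataset = wds.WebDataset([url])   # curl is spawned lazily; samples stream as they arrive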
@@ -315,6 +338,8 @@ def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset
         return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
     elif split == 'val':
         return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
+    else:
+        raise NotImplementedError('only train and val are implemented as hf datasets')
 
 
 
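Taken together, the commit lets get_dataset() run without any local tarfiles: when none are found and stream=True (the new default), the train or val shards are streamed straight from the usm3d/hoho-train-set repository. A sketch of the resulting call path, assuming the usual import style:

    from hoho import hoho

    hoho.setup()                                      # sets (and now creates) LOCAL_DATADIR
    ds = hoho.get_dataset(split='val', stream=True)   # no local tars -> falls back to streaming
    sample = next(iter(ds))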