"""Simple data input from .jsonl files.""" |
|
|
|
import hashlib |
|
import json |
|
from multiprocessing.pool import ThreadPool |
|
import os |
|
import tempfile |
|
import urllib.request |
|
|
|
from absl import logging |
|
import big_vision.datasets.core as ds_core |
|
import jax |
|
import numpy as np |
|
import overrides |
|
import tensorflow as tf |
|
|
|
|
|
def cached_download(url, dest=None, verbose=True):
  """Download `url` to a local file and return its path, with caching."""
  dest = dest or os.path.join(tempfile.gettempdir(), "bv")
  os.makedirs(dest, exist_ok=True)
  dest = os.path.join(dest, hashlib.md5(url.encode()).hexdigest())

  if os.path.isfile(dest):
    return dest

  if verbose:
    print(f"\rRetrieving {url} into {dest}", end="", flush=True)

  with urllib.request.urlopen(url) as f:
    data = f.read()
  with open(dest, "wb+") as f:
    f.write(data)
  return dest
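
# Example usage of `cached_download` (hypothetical URL; a minimal sketch):
#
#   path = cached_download("https://example.com/cat.jpg")
#   with open(path, "rb") as f:
#     image_bytes = f.read()
#
# A second call with the same URL returns the already-downloaded file, since
# the destination filename is the MD5 hash of the URL.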


class DataSource(ds_core.DataSource):
  """.jsonl DataSource."""

  def __init__(self, fname, *, fopen_keys=(), download_keys=(),
               start=0, stop=float("inf")):
    """Create a data-source that's jsonl + data files (e.g. images).

    This correctly supports multi-host setups in that each host automatically
    reads only its subset of the dataset. However, currently all hosts
    download all items if `download_keys` is specified.
    TODO: b/lbeyer - This can be improved.

    Args:
      fname: str, the path to the jsonl file that holds the dataset.
      fopen_keys: collection of str or dict, the keys in the dataset whose
        string value actually is a file-path that should be opened and read,
        and whose content is what goes into the batch (e.g. image filenames,
        commonly ["image"]).
        If a dict, the values are folders prefixed to the filenames.
        Supports gs:// for reading from buckets.
      download_keys: collection of str, the keys in the dataset whose string
        value actually is a URL from which the file should be downloaded
        first. Files are downloaded to a persistent tmp folder, using the URL
        hash as the filename. If the file already exists, the download is
        skipped. Must be a subset of `fopen_keys`.
      start: int, index of the first row to use; use for slicing the data.
      stop: int or inf, index of the row after the last one to use.

    Note:
      This simple data input does not allow for nested/hierarchical values,
      or in any way more complicated values like vectors. Use TFDS for that.

      The start/stop arguments are used as in list slicing [start:stop].
    """
    self.examples = []

    with tf.io.gfile.GFile(fname) as f:
      for i, line in enumerate(f):
        if (start or 0) <= i < (stop or float("inf")):
          try:
            self.examples.append(json.loads(line))
          except json.decoder.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in line {i}:\n{line}") from e

    if download_keys:
      for k in download_keys:
        assert k in fopen_keys, (
            f"{k} in download_keys but missing from fopen_keys {fopen_keys}")

      logging.info(
          f"\u001b[33mNOTE\u001b[0m: Downloading {download_keys} "
          f"for dataset {fname} ({len(self.examples)} examples) ...")

      def _dl_one(ex):
        for k in download_keys:
          ex[k] = cached_download(ex[k])

      ThreadPool(100).map(_dl_one, self.examples)
      print("Done")
      logging.info("\u001b[33mNOTE\u001b[0m: Done downloading.")

    # Normalize fopen_keys to a dict mapping key -> folder prefix.
    if isinstance(fopen_keys, (list, tuple)):
      self.fopen_keys = {k: "" for k in fopen_keys}
    else:
      self.fopen_keys = fopen_keys or {}

    # Prefix the filenames with their folder (an empty prefix is a no-op).
    for ex in self.examples:
      for k, dirname in self.fopen_keys.items():
        ex[k] = os.path.join(dirname, ex[k])

  def _indices(self, *, process_split=True, process_index=None):
    indices = np.arange(len(self.examples))

    if not process_split:
      return list(indices)

    pid = jax.process_index() if process_index is None else process_index
    return list(np.array_split(indices, jax.process_count())[pid])
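
  # Illustrative example of the split (hypothetical sizes): with 10 examples
  # and 4 processes, np.array_split yields index chunks of sizes [3, 3, 2, 2],
  # so process 0 reads examples 0-2, process 1 reads 3-5, process 2 reads 6-7,
  # and process 3 reads 8-9.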

  @overrides.overrides
  def get_tfdata(self, ordered=False, *, process_split=True, allow_cache=True):
    del allow_cache
    assert not process_split or len(self.examples) >= jax.process_count(), (
        "Process splitting the data with fewer examples than processes!?")

    my_idxs = self._indices(process_split=process_split)
    if not ordered:
      np.random.shuffle(my_idxs)

    dataset = tf.data.Dataset.from_generator(
        generator=lambda: ({"id": str(i), **self.examples[i]} for i in my_idxs),
        output_signature={
            "id": _guess_signature("0"),
            **{k: _guess_signature(v) for k, v in self.examples[0].items()},
        })

    def _read_files(example):
      for k in self.fopen_keys:
        example[k] = tf.io.read_file(example[k])
      return example
    dataset = dataset.map(_read_files)

    return dataset
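
  # Example of consuming the result (a minimal sketch, assuming `source` is
  # the hypothetical DataSource instance from the example above; batching and
  # decoding are typically added by the surrounding input pipeline):
  #
  #   for ex in source.get_tfdata(ordered=True).take(1).as_numpy_iterator():
  #     ex["image"]  # raw file bytes, since "image" is in fopen_keys
  #     ex["id"]     # the global row index, as a string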

  @property
  @overrides.overrides
  def total_examples(self):
    return len(self.examples)

  @overrides.overrides
  def num_examples_per_process(self):
    return [len(self._indices(process_index=pid))
            for pid in range(jax.process_count())]


def _guess_signature(value):
  return tf.TensorSpec.from_tensor(tf.constant(value))
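
# For example (illustrative values): _guess_signature("cat.jpg") yields
# tf.TensorSpec(shape=(), dtype=tf.string) and _guess_signature(3) yields
# tf.TensorSpec(shape=(), dtype=tf.int32), i.e. scalar specs inferred from the
# first example's values in get_tfdata().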