# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple data input from .jsonl files."""
import hashlib
import json
from multiprocessing.pool import ThreadPool
import os
import tempfile
import urllib.request
from absl import logging
import big_vision.datasets.core as ds_core
import jax
import numpy as np
import overrides
import tensorflow as tf


def cached_download(url, dest=None, verbose=True):
  """Download `url` to local file and return path to that, but with caching."""
  # NOTE: there is a small chance of saving corrupted data if the process is
  # interrupted in the middle of writing the file. Then, reading in the input
  # pipeline will fail, and the fix is to nuke the temp folder.

  # Compute a temp name based on the URL, so we can check if we already
  # downloaded it before.
  dest = dest or os.path.join(tempfile.gettempdir(), "bv")
  os.makedirs(dest, exist_ok=True)
  dest = os.path.join(dest, hashlib.md5(url.encode()).hexdigest())

  # NOTE: we should use last-modified header to know whether to re-download.
  if os.path.isfile(dest):
    return dest

  if verbose:
    print(f"\rRetrieving {url} into {dest}", end="", flush=True)

  with urllib.request.urlopen(url) as f:
    data = f.read()
  with open(dest, "wb+") as f:
    f.write(data)
  return dest
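
# Usage sketch for `cached_download` (illustrative only; the URL below is a
# made-up example, not part of the library):
#
#   path = cached_download("https://example.com/images/cat.jpg")
#   # The first call writes the response bytes to <tempdir>/bv/<md5 of URL>;
#   # later calls with the same URL return the cached path without
#   # re-downloading.
#   with open(path, "rb") as f:
#     image_bytes = f.read()
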
class DataSource(ds_core.DataSource):
  """.jsonl DataSource."""

  def __init__(self, fname, *, fopen_keys=(), download_keys=(),
               start=0, stop=float("inf")):
    """Creates a data-source that is a jsonl file plus data files (e.g. images).

    This correctly supports multi-host setups: each host automatically reads
    only its own subset of the dataset. However, currently all hosts download
    all items if `download_keys` is specified.
    TODO: b/lbeyer - This can be improved.

    Args:
      fname: str, the path to the jsonl file that holds the dataset.
      fopen_keys: collection of str or dict, the keys in the dataset whose
        string value is actually a file path that should be opened and read;
        the file's content is what goes into the batch (e.g. image filenames,
        commonly ["image"]). If a dict, the values are folder names that are
        prefixed to the filenames. Supports gs:// paths for reading from
        buckets.
      download_keys: collection of str, the keys in the dataset whose string
        value is actually a URL from which the file should be downloaded
        first. Files are downloaded to a persistent tmp folder, using the URL
        hash as the filename. If the file already exists, the download is
        skipped. Must be a subset of `fopen_keys`.
      start: int, index of the first row to use; use for slicing the data.
      stop: int or inf, index of the row after the last one to use.

    Note:
      This simple data input does not allow for nested/hierarchical values,
      or in any way more complicated values like vectors. Use TFDS for that.
      The start/stop arguments are used as in list slicing [start:stop].
    """
    self.examples = []
    with tf.io.gfile.GFile(fname) as f:
      for i, line in enumerate(f):
        if (start or 0) <= i < (stop or float("inf")):
          try:
            self.examples.append(json.loads(line))
          except json.decoder.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in line {i}:\n{line}") from e

    if download_keys:
      for k in download_keys:
        assert k in fopen_keys, (
            f"{k} in download_keys but missing from fopen_keys {fopen_keys}")

      # TODO: b/lbeyer - use info from trainer instead, move that to utils.
      logging.info(  # pylint: disable=logging-fstring-interpolation
          f"\u001b[33mNOTE\u001b[0m: Downloading {download_keys} "
          f"for dataset {fname} ({len(self.examples)} examples) ...")

      def _dl_one(ex):
        for k in download_keys:
          ex[k] = cached_download(ex[k])

      ThreadPool(100).map(_dl_one, self.examples)
      print("Done")
      logging.info("\u001b[33mNOTE\u001b[0m: Done downloading.")

    # Normalize.
    if isinstance(fopen_keys, (list, tuple)):
      self.fopen_keys = {k: "" for k in fopen_keys}
    else:
      self.fopen_keys = fopen_keys or {}

    # We need to apply the fopen path prefix here already: when the files are
    # actually read in the TF input pipeline, the values are symbolic tensors,
    # so we could not join the paths there.
    for ex in self.examples:
      for k, dirname in self.fopen_keys.items():
        ex[k] = os.path.join(dirname, ex[k])

  def _indices(self, *, process_split=True, process_index=None):
    """Returns the example indices for one process (all, if not process_split)."""
    indices = np.arange(len(self.examples))
    if not process_split:
      return list(indices)
    pid = jax.process_index() if process_index is None else process_index
    return list(np.array_split(indices, jax.process_count())[pid])

  @overrides.overrides
  def get_tfdata(self, ordered=False, *, process_split=True, allow_cache=True):
    del allow_cache  # We don't cache anything anyways.
    assert not process_split or len(self.examples) >= jax.process_count(), (
        "Process splitting the data with fewer examples than processes!?")

    my_idxs = self._indices(process_split=process_split)
    if not ordered:
      np.random.shuffle(my_idxs)

    dataset = tf.data.Dataset.from_generator(
        generator=lambda: ({"id": str(i), **self.examples[i]} for i in my_idxs),
        output_signature={
            "id": _guess_signature("0"),
            **{k: _guess_signature(v) for k, v in self.examples[0].items()},
        })

    def _read_files(example):
      for k in self.fopen_keys:
        example[k] = tf.io.read_file(example[k])
      return example

    dataset = dataset.map(_read_files)
    return dataset
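
  # Sketch of one element of the resulting dataset (assuming a made-up jsonl
  # line {"image": "img0.jpg", "label": 3} and fopen_keys=["image"]):
  #   {"id": b"0", "image": <raw bytes of img0.jpg>, "label": 3}
  # "id" is the example's row index as a string, and every key listed in
  # fopen_keys is replaced by the file's contents via tf.io.read_file.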

  @property
  @overrides.overrides
  def total_examples(self):
    return len(self.examples)

  @overrides.overrides
  def num_examples_per_process(self):
    return [len(self._indices(process_index=pid))
            for pid in range(jax.process_count())]


def _guess_signature(value):
  # Infers a TensorSpec (dtype/shape) from a constant built from the value.
  return tf.TensorSpec.from_tensor(tf.constant(value))
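

# Minimal usage sketch (illustrative only; the file names and keys below are
# made-up assumptions, not part of the library):
#
#   # data.jsonl contains one JSON object per line, e.g.:
#   #   {"image": "img0.jpg", "label": 3}
#   #   {"image": "img1.jpg", "label": 7}
#   ds = DataSource("/path/to/data.jsonl",
#                   fopen_keys={"image": "/path/to/images"})
#   train_data = ds.get_tfdata(ordered=False)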