"""ImageNet input pipeline.""" |
import collections |
import functools |
import itertools |
import math |
import multiprocessing.pool |
from absl import logging |
from big_vision.datasets import sequence_packing |
import big_vision.datasets.core as ds_core |
import big_vision.pp.builder as pp_builder |
import big_vision.utils as u |
import einops |
import jax |
import numpy as np |
import tensorflow as tf |
def make_for_train(
    data, preprocess_fn, batch_size,
    shuffle_buffer_size=None, cache_raw=False,
    num_parallel_calls=DEFAULT_NUM_PARALLEL_CALLS, prefetch=2,
    *,
    pre_filter_fn=None, post_filter_fn=None,
    pack=None, skip_errors=False,
):
    """Builds the training input pipeline on top of `data`.

    Stages are applied in this fixed order: host options, pre-filter, raw
    cache, shuffle, repeat, preprocess, post-filter, error skipping, packing,
    batching, prefetch. Every stage except host options, repeat and
    preprocess is optional and skipped when its argument is falsy.

    Args:
      data: Source dataset exposing the tf.data-style methods used below.
      preprocess_fn: Function mapped over each example after repeating.
      batch_size: Global batch size; every process batches
        `batch_size // jax.process_count()` examples. Falsy disables batching.
      shuffle_buffer_size: Shuffle buffer size; falsy disables shuffling.
      cache_raw: Whether to cache the data before shuffling/preprocessing.
      num_parallel_calls: Forwarded to `data.map`.
      prefetch: Number of elements to prefetch; falsy disables prefetching.
      pre_filter_fn: Optional predicate applied before preprocessing.
      post_filter_fn: Optional predicate applied after preprocessing.
      pack: Optional packing config passed to `sequence_packing.pack_dataset`.
      skip_errors: Whether to drop failing elements via `ignore_errors`.

    Returns:
      The assembled dataset.
    """
    ds = _add_tpu_host_options(data)

    if pre_filter_fn:
        ds = ds.filter(pre_filter_fn)
    if cache_raw:
        ds = ds.cache()

    # Shuffle before repeating, and reshuffle on every pass over the data.
    if shuffle_buffer_size:
        ds = ds.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)
    ds = ds.repeat(None)

    ds = ds.map(preprocess_fn, num_parallel_calls=num_parallel_calls)
    if post_filter_fn:
        ds = ds.filter(post_filter_fn)
    if skip_errors:
        ds = ds.ignore_errors(log_warning=True)
    if pack:
        ds = sequence_packing.pack_dataset(ds, pack)

    if batch_size:
        # Each process only builds its own share of the global batch.
        per_process_batch = batch_size // jax.process_count()
        ds = ds.batch(per_process_batch, drop_remainder=True)

    return ds.prefetch(prefetch) if prefetch else ds
def training(input_config):
    """Reads the data from a single dataset, or mixes it from multiple.

    The data is read either from one or mixed from multiple datasets,
    depending on the `input_config`:

    * If `input_config.data` has a string `name`, a single dataset is built
      and the per-pipeline options are read from `input_config` itself.
    * Otherwise `input_config.data` is treated as a `{name: weight}` mapping,
      each dataset's settings are read from `input_config[name]`, and the
      datasets are sampled according to the normalized weights.

    Args:
      input_config: Configures the input pipeline. See input_pipeline_test for
        examples.

    Returns:
      A tuple containing (possibly mixed) tf.data.Dataset and a total number
      of training examples.
    """
    per_pipeline_configs = (
        "shuffle_buffer_size", "cache_raw", "num_parallel_calls",
        "pre_filter_fn", "post_filter_fn", "pack", "skip_errors")

    def config_to_kw(config):
        """Extracts the `make_for_train` keyword arguments set in `config`."""
        assert "filter_fn" not in config, (
            "Deprecated; use `pre_filter_fn` instead.")
        return {k: config[k] for k in per_pipeline_configs if k in config}

    batch_size = input_config.batch_size

    # Single-dataset case.
    if isinstance(input_config.data.get("name"), str):
        train_data = ds_core.get(**input_config.data)
        train_ds = make_for_train(
            data=train_data.get_tfdata(ordered=False),
            batch_size=batch_size,
            preprocess_fn=pp_builder.get_preprocess_fn(input_config.get("pp")),
            prefetch=input_config.get("prefetch", 2),
            **config_to_kw(input_config)
        )
        return train_ds, train_data.total_examples

    # Multi-dataset case: per-pipeline options must live on each dataset's
    # own config, not at the top level.
    for k in per_pipeline_configs:
        assert k not in input_config, f"{k} is per-dataset in multi-input."

    def _make(name_and_weight):
        """Builds the unbatched, un-prefetched pipeline of one dataset."""
        name, weight = name_and_weight
        dataset = input_config[name]
        train_data = ds_core.get(**dataset.data)
        dataset = make_for_train(
            data=train_data.get_tfdata(ordered=False),
            # Batching and prefetching happen once, after mixing.
            batch_size=None,
            preprocess_fn=pp_builder.get_preprocess_fn(dataset.get("pp"), name),
            prefetch=0,
            **config_to_kw(dataset)
        )
        if keys := input_config.get("keep_only"):
            dataset = dataset.map(lambda d, keys=keys: {k: d[k] for k in keys})
        return name, dataset, weight, train_data.total_examples

    # Datasets with a falsy (e.g. zero) weight are skipped entirely.
    active = [(name, w) for name, w in input_config.data.items() if w]

    # The `with` block releases the worker threads once all datasets are
    # built; `pool.map` blocks until every result is available.
    with multiprocessing.pool.ThreadPool(len(input_config.data)) as pool:
        built = pool.map(_make, active)

    names = [entry[0] for entry in built]
    datasets = [entry[1] for entry in built]
    weights = [entry[2] for entry in built]
    totals = [entry[3] for entry in built]

    total_weight = sum(weights)
    weights = [w / total_weight for w in weights]

    # `.3g` keeps percentages up to 100 in plain notation (`.1g` would print
    # e.g. 25% as "2e+01%").
    logging.info(
        "NOTE: Total dataset mix size: %d\nContributions:\n%s", sum(totals),
        "\n".join(f"{ds}: {n} ({w * 100:.3g}%)"
                  for ds, n, w in zip(names, totals, weights))
    )

    train_ds = tf.data.Dataset.sample_from_datasets(
        datasets, weights, stop_on_empty_dataset=True)

    # NOTE(review): "pack" is one of the keys rejected at the top level by
    # the assertion above, so this branch is only reachable when assertions
    # are disabled.
    if input_config.get("pack"):
        train_ds = sequence_packing.pack_dataset(
            train_ds, input_config.get("pack"))

    train_ds = train_ds.batch(
        batch_size // jax.process_count(), drop_remainder=True)
    if (pf := input_config.get("prefetch", 2)):
        train_ds = train_ds.prefetch(pf)

    return train_ds, sum(totals)
def make_for_inference(
    data, preprocess_fn, batch_size, num_ex_per_process,
    cache_raw=False, cache_final=False,
    num_parallel_calls=DEFAULT_NUM_PARALLEL_CALLS, prefetch=1,
):
    """Builds the inference/evaluation input pipeline on top of `data`.

    The preprocessed data is followed by padding elements from
    `_get_pad_data`, batched per process, truncated to a fixed number of
    batches, and then repeated.

    Args:
      data: Source dataset exposing the tf.data-style methods used below.
      preprocess_fn: Preprocessing function; wrapped by `_add_internal_fields`
        before being mapped over the data.
      batch_size: Global batch size; every process batches
        `batch_size // jax.process_count()` examples.
      num_ex_per_process: Collection of example counts; its maximum
        determines how many batches are taken.
      cache_raw: Whether to cache the data before preprocessing.
      cache_final: Whether to cache the final, truncated sequence of batches.
      num_parallel_calls: Forwarded to `data.map`.
      prefetch: Number of batches to prefetch; falsy disables prefetching.

    Returns:
      A `(dataset, num_batches)` tuple.
    """
    ds = _add_tpu_host_options(data)
    if cache_raw:
        ds = ds.cache()

    wrapped_pp_fn = _add_internal_fields(preprocess_fn)
    ds = ds.map(wrapped_pp_fn, num_parallel_calls=num_parallel_calls)
    # Append padding so that the last real batch can always be filled up.
    ds = ds.concatenate(_get_pad_data(ds))

    per_process_batch = batch_size // jax.process_count()
    ds = ds.ragged_batch(batch_size=per_process_batch, drop_remainder=True)

    # The largest per-process example count decides the number of batches.
    num_batches = math.ceil(max(num_ex_per_process) / per_process_batch)
    ds = ds.take(num_batches)

    # Caching happens only after the number of batches has been fixed.
    if cache_final:
        ds = ds.cache()
    ds = ds.repeat()
    if prefetch:
        ds = ds.prefetch(prefetch)
    return ds, num_batches
def _get_pad_data(data):
    """Returns a repeating dataset holding one all-zeros element.

    The element follows the structure of `data.element_spec`; every dimension
    that is `None` (or otherwise falsy) in a spec's shape gets size 0.
    """
    def _zero_tensor(spec):
        dims = [dim if dim else 0 for dim in spec.shape]
        return tf.zeros(dims, spec.dtype)

    zero_element = jax.tree.map(_zero_tensor, data.element_spec)
    padding = tf.data.Dataset.from_tensors(zero_element)
    return padding.repeat()
def _add_internal_fields(pp_fn):
    """Wraps `pp_fn` so that its output carries `_mask` and, if known, `_id`.

    For both keys a value already present in the output of `pp_fn` wins.
    Otherwise `_mask` is taken from the input example and defaults to a
    constant `True`, while `_id` is copied over only if the input has one.
    """
    def _wrapped(example):
        out = pp_fn(example)
        fallback_mask = example.get("_mask", tf.constant(True))
        out.setdefault("_mask", fallback_mask)
        if "_id" not in out and "_id" in example:
            out["_id"] = example["_id"]
        return out

    return _wrapped
def _add_tpu_host_options(data):
    """Returns `data` with fixed threading and optimization options applied.

    Sets a private threadpool of 48 threads, limits intra-op parallelism to 1
    and turns off automatic prefetch injection.
    """
    opts = tf.data.Options()
    opts.experimental_optimization.inject_prefetch = False
    opts.threading.max_intra_op_parallelism = 1
    opts.threading.private_threadpool_size = 48
    return data.with_options(opts)
def prefetch_iterator(it, n):
    """Runs iterator `it` ahead for `n` steps. Adapted from flax.

    Args:
      it: Any iterable. It is normalized with `iter()`, so re-iterable
        containers (e.g. lists) are consumed exactly once instead of being
        restarted on every refill.
      n: How many elements to keep buffered ahead of the consumer. A falsy
        value disables buffering and simply yields from `it`.

    Yields:
      The elements of `it`, in order.
    """
    if not n:
        yield from it
        return

    # Without this, `islice` would start over from the first element each
    # time when `it` is a container rather than an iterator.
    it = iter(it)

    # Fill the buffer with *up to* `n` elements.
    queue = collections.deque(itertools.islice(it, n))
    while queue:
        yield queue.popleft()
        # Refill by (at most) one element to keep the buffer at `n`.
        queue.extend(itertools.islice(it, 1))
def threadstart_iterator(it):
    """Fetches the first element of `it` in a background thread.

    The first element is pulled from a single-worker thread pool and then
    yielded, so it is not lost; all remaining elements are yielded directly
    from `it`.

    NOTE(review): this is a generator function, so nothing (including the
    background fetch) runs before the first `next()` on the result.

    Args:
      it: Any iterable; it is normalized with `iter()`.

    Yields:
      The elements of `it`, in order. Yields nothing if `it` is empty.
    """
    it = iter(it)
    # The pool is only needed for the first element; the `with` block makes
    # sure its worker thread is released afterwards.
    with multiprocessing.pool.ThreadPool(processes=1) as pool:
        first_ex_promise = pool.apply_async(next, (it,))
        try:
            first_ex = first_ex_promise.get()
        except StopIteration:
            # Empty input: end cleanly. Letting StopIteration escape a
            # generator body would be turned into a RuntimeError.
            return
    yield first_ex
    yield from it
def tf_to_numpy(x):
    """Converts a TF tensor (dense or ragged) into a numpy array.

    * Dense non-string tensors are returned via `.numpy()`.
    * Dense string tensors are decoded from bytes into `str` elements.
    * Anything else is handled through ragged accessors (`nested_row_splits`,
      `flat_values`, `row_splits`, `values`) - presumably a tf.RaggedTensor.
      If all unknown dimensions turn out to have uniform row lengths, a
      regular array of that shape is returned; otherwise an object array
      with one converted entry per row.
    """
    if isinstance(x, tf.Tensor):
        if x.dtype == tf.string:
            # Produce real strings rather than bytes.
            return np.vectorize(bytes.decode, otypes=[str])(x.numpy())
        return x.numpy()

    # Fill in unknown dimensions whose rows all share the same length.
    shape = list(x.shape)
    for axis, dim in enumerate(shape[1:], start=1):
        if dim is None:
            row_lengths = np.diff(x.nested_row_splits[axis - 1])
            if len(set(row_lengths)) == 1:
                shape[axis] = row_lengths[0]

    if None not in shape:
        return tf_to_numpy(x.flat_values).reshape(shape)

    # Genuinely ragged: convert each row on its own.
    bounds = x.row_splits.numpy()
    pieces = [
        tf_to_numpy(x.values[start:stop])
        for start, stop in zip(bounds[:-1], bounds[1:])
    ]
    return np.fromiter(pieces, dtype=object)
def start_global(
    data, global_devices, n_prefetch=1, keep_on_cpu=frozenset(), warmup=False):
    """Starts the global input pipeline.

    Args:
      data: Iterable of (nested) batches.
      global_devices: Devices passed to `u.make_fsarray_from_local_slice`.
      n_prefetch: Number of converted batches to keep ready ahead of time.
      keep_on_cpu: Names of leaves that are converted with `tf_to_numpy`
        instead of being turned into a global array.
      warmup: Whether to route the iterator through `threadstart_iterator`.

    Returns:
      An iterator over the converted batches.
    """
    def _convert_leaf(name, value):
        if name not in keep_on_cpu:
            return u.make_fsarray_from_local_slice(value, global_devices)
        return tf_to_numpy(value)

    source = iter(data)
    source = threadstart_iterator(source) if warmup else source

    converted = (
        u.tree_map_with_names(_convert_leaf, batch) for batch in source)
    return prefetch_iterator(converted, n_prefetch)
def shard_and_put(x, shard=True, put=True):
    """Converts `x` to numpy and optionally shards it over local devices.

    Args:
      x: Object supporting the buffer protocol (it is wrapped in a
        `memoryview`).
      shard: Whether to split the leading axis into
        `(local_device_count, rest)`.
      put: Whether to place the shards onto the local devices. Only has an
        effect when `shard` is also set.

    Returns:
      A numpy array when `shard` is falsy or `put` is falsy, otherwise the
      result of `jax.device_put_sharded`.
    """
    arr = np.asarray(memoryview(x))
    if not shard:
        return arr

    arr = einops.rearrange(
        arr, "(d l) ... -> d l ...", d=jax.local_device_count())
    if put:
        arr = jax.device_put_sharded(list(arr), jax.local_devices())
    return arr
def start_input_pipeline(data, n_prefetch=1, shard=True):
    """Starts an input pipeline applying `shard_and_put` to every leaf.

    Args:
      data: Iterable of (nested) batches.
      n_prefetch: Number of batches to prefetch. It is also passed as the
        `put` flag, so shards are only placed on devices when it is truthy.
      shard: Forwarded to `shard_and_put`.

    Returns:
      An iterator over the processed batches.
    """
    def _process_leaf(x):
        return shard_and_put(x, shard=shard, put=n_prefetch)

    processed = (jax.tree.map(_process_leaf, batch) for batch in iter(data))
    return prefetch_iterator(processed, n_prefetch)
def start_ragged_input_pipeline(data, n_prefetch=1, shard=True, ragged=None):
    """Starts an input pipeline that leaves the `ragged` leaves untouched.

    Args:
      data: Iterable of (nested) batches.
      n_prefetch: Number of batches to prefetch.
      shard: Forwarded to `shard_and_put` for the non-ragged leaves.
      ragged: Optional container of leaf names that are passed through as-is.

    Returns:
      An iterator over the processed batches.
    """
    def _process_leaf(name, x):
        if name in (ragged or {}):
            return x
        return shard_and_put(x, shard)

    processed = (
        u.tree_map_with_names(_process_leaf, batch) for batch in iter(data))
    return prefetch_iterator(processed, n_prefetch)