"""ImageNet input pipeline.""" |
|
import collections
import functools
import itertools
import math
import multiprocessing.pool

from absl import logging
from big_vision.datasets import sequence_packing
import big_vision.datasets.core as ds_core
import big_vision.pp.builder as pp_builder
import big_vision.utils as u
import einops
import jax
import numpy as np
import tensorflow as tf


DEFAULT_NUM_PARALLEL_CALLS = 100


def make_for_train(
    data, preprocess_fn, batch_size,
    shuffle_buffer_size=None, cache_raw=False,
    num_parallel_calls=DEFAULT_NUM_PARALLEL_CALLS, prefetch=2,
    *,
    pre_filter_fn=None, post_filter_fn=None,
    pack=None, skip_errors=False,
):
  """Makes an input pipeline for training."""
  data = _add_tpu_host_options(data)

  # Use data filtering at your own risk: the examples the training loop sees
  # may end up less stochastic than you expect.
  data = data.filter(pre_filter_fn) if pre_filter_fn else data
  data = data.cache() if cache_raw else data

  # Shuffle first, then repeat: each epoch is fully seen before the next one
  # starts, and the shuffle buffer only needs to be filled once.
  if shuffle_buffer_size:
    data = data.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)
  data = data.repeat(None)  # Repeat indefinitely.

  data = data.map(preprocess_fn, num_parallel_calls=num_parallel_calls)
  data = data.filter(post_filter_fn) if post_filter_fn else data

  data = data.ignore_errors(log_warning=True) if skip_errors else data

  data = sequence_packing.pack_dataset(data, pack) if pack else data

  # Each host batches its local slice of the global batch; drop_remainder
  # keeps the batch shape fully static.
  if batch_size:
    data = data.batch(batch_size // jax.process_count(), drop_remainder=True)
  if prefetch:
    data = data.prefetch(prefetch)
  return data
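
# A minimal usage sketch (hedged: the dataset name and pp-string below are
# made-up placeholders; `ds_core.get` and `pp_builder.get_preprocess_fn` are
# the same helpers `training` uses):
#
#   source = ds_core.get(name="imagenet2012", split="train")
#   train_ds = make_for_train(
#       data=source.get_tfdata(ordered=False),
#       preprocess_fn=pp_builder.get_preprocess_fn(
#           "decode|resize(256)|random_crop(224)|value_range(-1, 1)"),
#       batch_size=4096,  # Global batch size; each host takes its slice.
#       shuffle_buffer_size=250_000,
#   )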


def training(input_config):
  """Reads the data from a single dataset, or mixes multiple datasets.

  Args:
    input_config: Configures the input pipeline. See input_pipeline_test for
      examples.

  Returns:
    A tuple of the (possibly mixed) tf.data.Dataset and the total number of
    training examples.
  """
  per_pipeline_configs = (
      "shuffle_buffer_size", "cache_raw", "num_parallel_calls",
      "pre_filter_fn", "post_filter_fn", "pack", "skip_errors")
  def config_to_kw(config):
    assert "filter_fn" not in config, "Deprecated; use `pre_filter_fn` instead."
    return {k: config[k] for k in per_pipeline_configs if k in config}

  batch_size = input_config.batch_size

  # Handle the common case of reading from a single dataset (no mixing).
  if isinstance(input_config.data.get("name"), str):
    train_data = ds_core.get(**input_config.data)
    train_ds = make_for_train(
        data=train_data.get_tfdata(ordered=False),
        batch_size=batch_size,
        preprocess_fn=pp_builder.get_preprocess_fn(input_config.get("pp")),
        prefetch=input_config.get("prefetch", 2),
        **config_to_kw(input_config)
    )
    return train_ds, train_data.total_examples

  # In the multi-input case these options are per-dataset; raise a helpful
  # error instead of silently ignoring a top-level setting.
  for k in per_pipeline_configs:
    assert k not in input_config, f"{k} is per-dataset in multi-input."

  # Build the individual pipelines in parallel threads, since creating each
  # one can take a noticeable amount of time.
  def _make(name_and_weight):
    name, weight = name_and_weight
    dataset = input_config[name]
    train_data = ds_core.get(**dataset.data)
    dataset = make_for_train(
        data=train_data.get_tfdata(ordered=False),
        # Don't batch the data just yet; it is batched after mixing below.
        batch_size=None,
        preprocess_fn=pp_builder.get_preprocess_fn(dataset.get("pp"), name),
        prefetch=0,  # Prefetching each sub-pipeline leads to huge OOMs.
        **config_to_kw(dataset)
    )
    if keys := input_config.get("keep_only"):
      dataset = dataset.map(lambda d, keys=keys: {k: d[k] for k in keys})
    return name, dataset, weight, train_data.total_examples

  names, datasets, weights, totals = [], [], [], []
  pool = multiprocessing.pool.ThreadPool(len(input_config.data))
  for name, dataset, weight, total in pool.map(
      _make, ((name, w) for name, w in input_config.data.items() if w)):
    names.append(name)
    datasets.append(dataset)
    weights.append(weight)
    totals.append(total)

  # Normalize the weights so they sum to 1.
  weights = [x / sum(weights) for x in weights]

  logging.info(
      "NOTE: Total dataset mix size: %d\nContributions:\n%s", sum(totals),
      "\n".join(f"{ds}: {n} ({w * 100:.1f}%)"
                for ds, n, w in zip(names, totals, weights))
  )

  train_ds = tf.data.Dataset.sample_from_datasets(
      datasets, weights, stop_on_empty_dataset=True)
  if input_config.get("pack"):
    train_ds = sequence_packing.pack_dataset(train_ds, input_config.get("pack"))
  train_ds = train_ds.batch(batch_size // jax.process_count(),
                            drop_remainder=True)
  if (pf := input_config.get("prefetch", 2)):
    train_ds = train_ds.prefetch(pf)

  return train_ds, sum(totals)
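
# A hedged sketch of the two config shapes this function accepts (field names
# are the ones read above; dataset names, pp-strings and numbers are made-up
# placeholders; see input_pipeline_test for authoritative examples).
# Single dataset:
#
#   config = ml_collections.ConfigDict()
#   config.batch_size = 4096
#   config.shuffle_buffer_size = 250_000
#   config.data = dict(name="imagenet2012", split="train")
#   config.pp = "decode|resize(224)|value_range(-1, 1)"
#
# Mixture: `data` maps sub-dataset names to mixing weights, and each name
# keys a per-dataset config carrying its own `data`/`pp`/per-pipeline options:
#
#   config = ml_collections.ConfigDict()
#   config.batch_size = 4096
#   config.data = dict(ds_a=0.7, ds_b=0.3)
#   config.ds_a = dict(data=dict(name="...", split="train"),
#                      pp="...", shuffle_buffer_size=50_000)
#   config.ds_b = dict(data=dict(name="...", split="train"), pp="...")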


def make_for_inference(
    data, preprocess_fn, batch_size, num_ex_per_process,
    cache_raw=False, cache_final=False,
    num_parallel_calls=DEFAULT_NUM_PARALLEL_CALLS, prefetch=1,
):
  """Makes an input pipeline for inference."""
  data = _add_tpu_host_options(data)
  data = data.cache() if cache_raw else data
  data = data.map(_add_internal_fields(preprocess_fn),
                  num_parallel_calls=num_parallel_calls)
  # Append an endless stream of all-zero examples (their "_mask" is False),
  # so that every host can always form full batches.
  data = data.concatenate(_get_pad_data(data))

  local_batch_size = batch_size // jax.process_count()
  # Like `batch`, but features whose shape varies across examples are batched
  # into tf.RaggedTensors; fixed-shape features still become dense tf.Tensors.
  data = data.ragged_batch(batch_size=local_batch_size, drop_remainder=True)

  # Every host must iterate over the same number of batches, so size the
  # pipeline for the host with the most examples; the zero-padding above
  # makes up the difference on the other hosts.
  num_batches = math.ceil(max(num_ex_per_process) / local_batch_size)
  data = data.take(num_batches)

  # Note: cache only after the finite number of batches has been taken.
  data = data.cache() if cache_final else data
  data = data.repeat()
  data = data.prefetch(prefetch) if prefetch else data
  return data, num_batches
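
# A hedged usage sketch (placeholder dataset and pp-string; assumes the data
# source can report per-host example counts, as big_vision's ds_core does):
#
#   val_data = ds_core.get(name="imagenet2012", split="validation")
#   val_ds, num_batches = make_for_inference(
#       data=val_data.get_tfdata(ordered=True),
#       preprocess_fn=pp_builder.get_preprocess_fn("decode|resize(224)"),
#       batch_size=1024,
#       num_ex_per_process=val_data.num_examples_per_process(),
#   )
#
# Iterate exactly `num_batches` steps per evaluation and use the "_mask"
# field to exclude the zero-padded examples from metrics.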


def _get_pad_data(data):
  def zeros_like_spec(spec):
    # For unknown/dynamic dimensions (None), use 0 as the placeholder size.
    return tf.zeros([x or 0 for x in spec.shape], spec.dtype)

  zero = jax.tree.map(zeros_like_spec, data.element_spec)
  return tf.data.Dataset.from_tensors(zero).repeat()


def _add_internal_fields(pp_fn):
  """Wraps pp_fn to add _mask and _id keys."""
  def _pp_fn(example):
    result = pp_fn(example)
    # "_mask" distinguishes real examples from zero-padding; it is consumed
    # (and dropped) further down the line.
    result.setdefault("_mask", example.get("_mask", tf.constant(True)))
    # Not every data source provides an "_id"; pass it through when it does.
    if "_id" in example and "_id" not in result:
      result["_id"] = example["_id"]
    return result
  return _pp_fn


def _add_tpu_host_options(data):
  options = tf.data.Options()
  options.threading.private_threadpool_size = 48
  options.threading.max_intra_op_parallelism = 1

  # Disable automatic prefetch injection, which can silently eat a lot of RAM.
  options.experimental_optimization.inject_prefetch = False

  return data.with_options(options)


def prefetch_iterator(it, n):
  """Runs iterator `it` ahead for `n` steps. Adapted from flax."""
  if not n:
    yield from it
    return
  queue = collections.deque()

  def enqueue(n_steps):  # Enqueues *up to* `n_steps` elements from `it`.
    for data in itertools.islice(it, n_steps):
      queue.append(data)

  enqueue(n)  # Fill up the buffer.
  while queue:
    yield queue.popleft()
    enqueue(1)
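
# Note that the deque is filled synchronously: the benefit comes from staying
# `n` elements ahead of the consumer, so an asynchronous producer underneath
# (e.g. tf.data's own prefetch threads) keeps working while the consumer is
# busy. A hedged usage sketch:
#
#   for batch in prefetch_iterator(iter(train_ds), n=2):
#     train_step(batch)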


def threadstart_iterator(it):
  """Starts an iterator right away in a background thread."""
  # We want to "start" the iterator immediately (e.g. to kick off the
  # underlying dataset's prefetch machinery), so pull the first element in a
  # background thread, then yield it first so nothing is lost.
  pool = multiprocessing.pool.ThreadPool(processes=1)
  first_ex_promise = pool.apply_async(lambda: next(it))

  yield first_ex_promise.get()
  yield from it


def tf_to_numpy(x):
  """Convert any TF types to numpy."""
  if isinstance(x, tf.Tensor):
    if x.dtype != tf.string:
      return x.numpy()
    else:  # Dense string tensors become numpy arrays of python strings.
      return np.vectorize(bytes.decode, otypes=[str])(x.numpy())

  # Anything else is expected to be a RaggedTensor; fail loudly otherwise.
  assert isinstance(x, tf.RaggedTensor), f"Unexpected type: {type(x)}"

  # First try to resolve the ragged dimensions: a dimension reported as None
  # is effectively dense if all its rows happen to have the same length.
  real_shape = list(x.shape)
  for i, s in enumerate(real_shape[1:]):
    if s is not None: continue
    rowlens = np.diff(x.nested_row_splits[i])
    if len(set(rowlens)) == 1:
      real_shape[i + 1] = rowlens[0]

  # If every dimension resolved, return one dense numpy array.
  if None not in real_shape:
    return tf_to_numpy(x.flat_values).reshape(real_shape)

  # Truly ragged: fall back to an object-array of per-row arrays.
  splits = x.row_splits.numpy()
  rows = [tf_to_numpy(x.values[splits[i]:splits[i + 1]])
          for i in range(len(splits) - 1)]
  return np.fromiter(rows, dtype=object)
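
# A small illustration of the two ragged outcomes (hedged sketch):
#
#   tf_to_numpy(tf.ragged.constant([[1, 2], [3, 4]]))
#   # -> dense array([[1, 2], [3, 4]]): all rows have equal length.
#   tf_to_numpy(tf.ragged.constant([[1, 2], [3]]))
#   # -> object-array of np.array([1, 2]) and np.array([3]).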


def start_global(
    data, global_devices, n_prefetch=1, keep_on_cpu=frozenset(), warmup=False):
  """Starts the global input pipeline."""
  def maybe_shard(name, x):
    if name in keep_on_cpu:
      return tf_to_numpy(x)
    return u.make_fsarray_from_local_slice(x, global_devices)

  it = iter(data)
  if warmup:  # Kick off the pipeline (e.g. shuffle-buffer filling) early.
    it = threadstart_iterator(it)

  it = (u.tree_map_with_names(maybe_shard, elem) for elem in it)
  return prefetch_iterator(it, n_prefetch)
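
# A hedged usage sketch: each element of `data` is this host's local slice of
# the global batch; `u.make_fsarray_from_local_slice` assembles the slices
# into jax.Arrays sharded across `global_devices`. The "_id" entry in
# `keep_on_cpu` is a made-up placeholder:
#
#   train_iter = start_global(train_ds, jax.devices(), n_prefetch=1,
#                             keep_on_cpu={"_id"}, warmup=True)
#   for batch in train_iter:
#     train_step(batch)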


# Everything below targets the older pmap-style setup, where each batch is
# reshaped to [local_devices, per_device_batch, ...] and placed on the local
# devices; `start_global` above is the jax.Array-based path.


def shard_and_put(x, shard=True, put=True):
  # np.asarray(memoryview(x)) gets the tensor data into numpy without a copy.
  x = np.asarray(memoryview(x))
  if shard:
    x = einops.rearrange(x, "(d l) ... -> d l ...", d=jax.local_device_count())
  if shard and put:
    x = jax.device_put_sharded(list(x), jax.local_devices())
  return x


def start_input_pipeline(data, n_prefetch=1, shard=True):
  # put=n_prefetch: only device_put when prefetching (n_prefetch > 0).
  fn = functools.partial(shard_and_put, shard=shard, put=n_prefetch)
  it = (jax.tree.map(fn, elem) for elem in iter(data))
  return prefetch_iterator(it, n_prefetch)
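
# A hedged consumption sketch for the pmap path (assumes a pmapped step
# function; names are placeholders):
#
#   it = start_input_pipeline(train_ds, n_prefetch=1)
#   for batch in it:  # Leaves: [local_devices, per_device_batch, ...]
#     params = pmapped_train_step(params, batch)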


def start_ragged_input_pipeline(data, n_prefetch=1, shard=True, ragged=None):
  # Features named in `ragged` stay as-is on the host; everything else is
  # sharded across the local devices as usual.
  def maybe_shard_and_put(name, x):
    return x if name in (ragged or {}) else shard_and_put(x, shard)

  it = (u.tree_map_with_names(maybe_shard_and_put, elem) for elem in iter(data))
  return prefetch_iterator(it, n_prefetch)