# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ImageNet input pipeline."""
import collections
import functools
import itertools
import math
import multiprocessing.pool
from absl import logging
from big_vision.datasets import sequence_packing
import big_vision.datasets.core as ds_core
import big_vision.pp.builder as pp_builder
import big_vision.utils as u
import einops
import jax
import numpy as np
import tensorflow as tf

DEFAULT_NUM_PARALLEL_CALLS = 100


def make_for_train(
data, preprocess_fn, batch_size,
shuffle_buffer_size=None, cache_raw=False,
num_parallel_calls=DEFAULT_NUM_PARALLEL_CALLS, prefetch=2,
*,
pre_filter_fn=None, post_filter_fn=None,
pack=None, skip_errors=False,
):
"""Makes an input pipeline for training."""
# Use data filtering at your own risk: the actual split sizes won't be known
# in advance, so epoch-based things won't work correctly.
data = _add_tpu_host_options(data)
data = data.filter(pre_filter_fn) if pre_filter_fn else data
data = data.cache() if cache_raw else data
  # First shuffle and then repeat (with a different shuffle each iteration).
  # This way all data from one epoch is seen before the next one starts, which
  # matters for how often each example is seen when training for only a small
  # number of epochs.
if shuffle_buffer_size:
data = data.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)
data = data.repeat(None)
data = data.map(preprocess_fn, num_parallel_calls=num_parallel_calls)
data = data.filter(post_filter_fn) if post_filter_fn else data
data = data.ignore_errors(log_warning=True) if skip_errors else data
data = sequence_packing.pack_dataset(data, pack) if pack else data
# Drop remainder makes shape fully static, so we can later use it if needed.
if batch_size:
data = data.batch(batch_size // jax.process_count(), drop_remainder=True)
if prefetch: # None means autotune, but we never want that.
data = data.prefetch(prefetch)
return data
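
# A minimal usage sketch for `make_for_train` (the dataset name and pp string
# are placeholders; `ds_core.get` and `pp_builder.get_preprocess_fn` are the
# same helpers used by `training` below):
#
#   src = ds_core.get(name="imagenet2012", split="train")
#   ds = make_for_train(
#       data=src.get_tfdata(ordered=False),
#       preprocess_fn=pp_builder.get_preprocess_fn(
#           "decode|resize(256)|value_range(-1, 1)|keep('image', 'label')"),
#       batch_size=512,
#       shuffle_buffer_size=50_000)
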
def training(input_config):
"""Reads the data from a single dataset, or mixes it from multiple.
The data is read either from one or mixed from multiple datasets, depending
on the `input_config`.
Args:
input_config: Configures the input pipeline. See input_pipeline_test for
examples.
Returns:
A tuple containing (possibly mixed) tf.data.Dataset and a total number of
training examples.
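
  Example (illustrative sketch; the dataset name, split and pp string are
  placeholders, see input_pipeline_test for real configs):

    config = ml_collections.ConfigDict()
    config.data = dict(name="imagenet2012", split="train")
    config.pp = "decode|resize(256)|value_range(-1, 1)|keep('image', 'label')"
    config.batch_size = 512
    config.shuffle_buffer_size = 50_000
    train_ds, num_examples = training(config)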
"""
per_pipeline_configs = (
"shuffle_buffer_size", "cache_raw", "num_parallel_calls",
"pre_filter_fn", "post_filter_fn", "pack", "skip_errors")
def config_to_kw(config):
assert "filter_fn" not in config, "Deprecated; use `pre_filter_fn` instead."
return {k: config[k] for k in per_pipeline_configs if k in config}
batch_size = input_config.batch_size
# Handle separately the common case when no mixing happens.
if isinstance(input_config.data.get("name"), str):
train_data = ds_core.get(**input_config.data)
train_ds = make_for_train(
data=train_data.get_tfdata(ordered=False),
batch_size=batch_size,
preprocess_fn=pp_builder.get_preprocess_fn(input_config.get("pp")),
prefetch=input_config.get("prefetch", 2), # Default 2 for bwd compat.
**config_to_kw(input_config)
)
return train_ds, train_data.total_examples
# A helpful error instead of silent ignore:
for k in per_pipeline_configs:
assert k not in input_config, f"{k} is per-dataset in multi-input."
  # Parallelize the loading of datasets when mixing data: for larger mixes,
  # loading them sequentially can take more than 5 minutes.
  # NOTE: functools.cache is thread-safe.
def _make(name_and_weight):
name, weight = name_and_weight
dataset = input_config[name]
train_data = ds_core.get(**dataset.data)
dataset = make_for_train(
data=train_data.get_tfdata(ordered=False),
# Don't batch the data just yet, it will be done after
# mixing the different datasets below.
batch_size=None,
preprocess_fn=pp_builder.get_preprocess_fn(dataset.get("pp"), name),
prefetch=0, # Prefetching each pipeline leads to huge OOMs.
**config_to_kw(dataset)
)
if keys := input_config.get("keep_only"):
dataset = dataset.map(lambda d, keys=keys: {k: d[k] for k in keys})
return name, dataset, weight, train_data.total_examples
names, datasets, weights, totals = [], [], [], []
pool = multiprocessing.pool.ThreadPool(len(input_config.data))
for name, dataset, weight, total in pool.map(
# Skip weight=0 datasets as a convenient optimization in sweeps.
_make, ((name, w) for name, w in input_config.data.items() if w)):
names.append(name)
datasets.append(dataset)
weights.append(weight)
totals.append(total)
# Normalize the weights such that they sum up to 1.
weights = [x / sum(weights) for x in weights]
logging.info(
"NOTE: Total dataset mix size: %d\nContributions:\n%s", sum(totals),
"\n".join(f"{ds}: {n} ({w * 100:.1g}%)"
for ds, n, w in zip(names, totals, weights))
)
train_ds = tf.data.Dataset.sample_from_datasets(
datasets, weights, stop_on_empty_dataset=True)
if input_config.get("pack"):
train_ds = sequence_packing.pack_dataset(train_ds, input_config.get("pack"))
train_ds = train_ds.batch(
input_config["batch_size"] // jax.process_count(), drop_remainder=True)
if (pf := input_config.get("prefetch", 2)):
train_ds = train_ds.prefetch(pf)
return train_ds, sum(totals)
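
# A hedged sketch of a mixture config for `training` (dataset names, weights
# and pp strings below are placeholders): `config.data` maps per-dataset
# config names to mixing weights, and each named sub-config carries its own
# `data` source, `pp` string and per-pipeline options:
#
#   config = ml_collections.ConfigDict()
#   config.batch_size = 1024
#   config.data = dict(laion=0.7, cc12m=0.3)  # name -> weight (normalized).
#   config.laion = dict(
#       data=dict(name="laion400m", split="train"),
#       pp="decode|resize(224)|value_range(-1, 1)|keep('image', 'text')",
#       shuffle_buffer_size=250_000)
#   config.cc12m = dict(
#       data=dict(name="cc12m", split="train"),
#       pp="decode|resize(224)|value_range(-1, 1)|keep('image', 'text')",
#       shuffle_buffer_size=250_000)
#   train_ds, total_examples = training(config)
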
# The pipeline below is used for evals in multi-{G,T}PU and multi-host settings.
# As the total number of examples may not be evenly divisible across all
# devices, we use the `infinite tf.data padding` trick, which was suggested by
# Andreas Steiner and also implemented by him in the clu library:
# https://github.com/google/CommonLoopUtils/blob/84b777c42dfd3fb6685537138433bfeb5241a006/clu/deterministic_data.py#L304.
def make_for_inference(
data, preprocess_fn, batch_size, num_ex_per_process,
cache_raw=False, cache_final=False,
num_parallel_calls=DEFAULT_NUM_PARALLEL_CALLS, prefetch=1,
):
"""Makes an input pipeline for inference."""
data = _add_tpu_host_options(data)
data = data.cache() if cache_raw else data
data = data.map(_add_internal_fields(preprocess_fn),
num_parallel_calls=num_parallel_calls)
data = data.concatenate(_get_pad_data(data))
local_batch_size = batch_size // jax.process_count()
# This is just like `batch`, but allows batching elements of different shapes
# into a tf.RaggedTensor. Elements of the same fixed shape remain tf.Tensors.
# Since we do 'infinite' padding it is safe to drop the remainder.
data = data.ragged_batch(batch_size=local_batch_size, drop_remainder=True)
# We need to make sure that all hosts process all data and exactly the same
# number of batches. Below we take max per-host num examples and use it on all
# hosts to derive the number of batches.
num_batches = math.ceil(max(num_ex_per_process) / local_batch_size)
data = data.take(num_batches)
# Note we cache data after a finite number of batches is taken.
data = data.cache() if cache_final else data
data = data.repeat()
data = data.prefetch(prefetch) if prefetch else data
return data, num_batches
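
# A usage sketch for `make_for_inference` (names are placeholders; `pp_eval`
# stands for some preprocessing string, and `num_examples_per_process()` is
# assumed to be provided by the ds_core data object, as evaluators typically
# derive per-host counts from it):
#
#   src = ds_core.get(name="imagenet2012", split="validation")
#   ds, num_batches = make_for_inference(
#       data=src.get_tfdata(ordered=True),
#       preprocess_fn=pp_builder.get_preprocess_fn(pp_eval),
#       batch_size=1024,
#       num_ex_per_process=src.num_examples_per_process())
#   # Every host then iterates exactly `num_batches` batches; the extra
#   # entries are zero-padded examples with `_mask` set to False.
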
def _get_pad_data(data):
def zeros_like_spec(spec):
# For unknown/flexible dimensions (None), just use 0 instead.
return tf.zeros([x or 0 for x in spec.shape], spec.dtype)
zero = jax.tree.map(zeros_like_spec, data.element_spec)
  return tf.data.Dataset.from_tensors(zero).repeat()


def _add_internal_fields(pp_fn):
"""Wraps pp_fn to add _mask and _id keys."""
  # Adds internal keys that we, in this order of preference:
# 1. keep from result of pp_fn,
# 2. carry over from raw (not pp_fn'd) example, or
# 3. add, if that makes sense.
def _pp_fn(example):
result = pp_fn(example)
# _mask will be False on padded examples (see _get_pad_data).
result.setdefault("_mask", example.get("_mask", tf.constant(True)))
# Not all data-sources can provide an ID. Only carry-over if it can:
if "_id" in example and "_id" not in result:
result["_id"] = example["_id"]
return result
  return _pp_fn


def _add_tpu_host_options(data):
options = tf.data.Options()
options.threading.private_threadpool_size = 48
options.threading.max_intra_op_parallelism = 1
# Stop a whole bunch of magic stuff that eats up all RAM:
options.experimental_optimization.inject_prefetch = False
  return data.with_options(options)


def prefetch_iterator(it, n):
"""Runs iterator `it` ahead for `n` steps. Adapted from flax."""
if not n:
yield from it
return
queue = collections.deque()
  def enqueue(n_steps):  # Enqueues *up to* `n_steps` elements from `it`.
for data in itertools.islice(it, n_steps):
# Prefetching will parallelize any processing that happens in a different
# thread (like `jax.device_put()`), but it will be of no use for
# processing that happens in the same thread.
queue.append(data)
enqueue(n) # Fill up the buffer.
while queue:
yield queue.popleft()
enqueue(1)
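
# Example (sketch; `train_batches` is any iterator of batches): keeping `n`
# elements in flight lets asynchronous work they trigger (e.g. device
# transfers started in other threads) overlap with the consumer's own work:
#
#   it = prefetch_iterator(iter(train_batches), n=2)
#   for batch in it:
#     ...  # up to 2 further batches have already been requested.
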
def threadstart_iterator(it):
"""Starts an iterator right away in a background thread."""
# We already want to "start" the iterator in order to start the underlying
# dataset prefetch mechanisms, so here we get the first element. But we don't
# want to lose it from training, so we yield that one afterwards.
# (internal link)
pool = multiprocessing.pool.ThreadPool(processes=1)
first_ex_promise = pool.apply_async(lambda: next(it))
yield first_ex_promise.get()
  yield from it


def tf_to_numpy(x):
"""Convert any TF types to numpy."""
if isinstance(x, tf.Tensor):
if x.dtype != tf.string: # Dense, non-string tensor? Easy!
return x.numpy()
else: # A dense string tensor? Turn into actual strings, not bytes.
return np.vectorize(bytes.decode, otypes=[str])(x.numpy())
# The rest deals with RaggedTensors, for two main reasons:
# - For strings, recursively apply the above conversion
# - For common cases (eg batch of images), return more reasonable shapes.
# Replace all None's in the shape by a fixed number, in the (somewhat common)
# case that they are marked ragged, but really all have the same shape.
real_shape = list(x.shape)
for i, s in enumerate(real_shape[1:]):
if s is not None: continue
rowlens = np.diff(x.nested_row_splits[i])
if len(set(rowlens)) == 1:
real_shape[i + 1] = rowlens[0]
if None not in real_shape:
return tf_to_numpy(x.flat_values).reshape(real_shape)
# It's actually ragged, reconstruct the array from the variable length pieces.
splits = x.row_splits.numpy()
rows = [tf_to_numpy(x.values[splits[i]:splits[i + 1]])
for i in range(len(splits) - 1)]
return np.fromiter(rows, dtype=object)
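
# Behaviour sketch (illustrative): dense tensors become numpy arrays, string
# tensors become arrays of python str (not bytes), and genuinely ragged
# tensors become object arrays of per-row numpy arrays:
#
#   tf_to_numpy(tf.constant([[1, 2], [3, 4]]))         # dense int array
#   tf_to_numpy(tf.constant(["a", "b"]))               # array of str
#   tf_to_numpy(tf.ragged.constant([[1, 2, 3], [4]]))  # object array of arrays
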
# Note that the order of global devices for sharding data is important and
# should be compatible with the device order used for model params, state, etc.
def start_global(
data, global_devices, n_prefetch=1, keep_on_cpu=frozenset(), warmup=False):
"""Starts the global input pipeline."""
def maybe_shard(name, x):
if name in keep_on_cpu:
return tf_to_numpy(x)
return u.make_fsarray_from_local_slice(x, global_devices)
it = iter(data)
if warmup: # actually pre-fill shuffle buffers etc.
it = threadstart_iterator(it)
it = (u.tree_map_with_names(maybe_shard, elem) for elem in it)
return prefetch_iterator(it, n_prefetch)
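
# Typical use (sketch; `train_step`, `total_steps`, `params` and `opt_state`
# are placeholders for the caller's own training state and jitted update fn):
#
#   train_iter = start_global(train_ds, jax.devices(), n_prefetch=2)
#   for step, batch in zip(range(total_steps), train_iter):
#     params, opt_state = train_step(params, opt_state, batch)
#
# Entries whose names appear in `keep_on_cpu` stay as host numpy arrays (e.g.
# string ids), everything else becomes a global array sharded across
# `global_devices`.
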
##########################################################################
# The code below is pmap-specific and is deprecated, please switch to jit.
##########################################################################
def shard_and_put(x, shard=True, put=True):
x = np.asarray(memoryview(x)) # No-copy conversion: http://(internal link)
if shard:
x = einops.rearrange(x, "(d l) ... -> d l ...", d=jax.local_device_count())
if shard and put: # Only works for pmap (for now).
x = jax.device_put_sharded(list(x), jax.local_devices())
  return x


def start_input_pipeline(data, n_prefetch=1, shard=True):
fn = functools.partial(shard_and_put, shard=shard, put=n_prefetch)
it = (jax.tree.map(fn, elem) for elem in iter(data))
  return prefetch_iterator(it, n_prefetch)


def start_ragged_input_pipeline(data, n_prefetch=1, shard=True, ragged=None):
def maybe_shard_and_put(name, x):
return x if name in (ragged or {}) else shard_and_put(x, shard)
it = (u.tree_map_with_names(maybe_shard_and_put, elem) for elem in iter(data))
return prefetch_iterator(it, n_prefetch)