# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Packed Sequence Op."""

# Forked from
# https://github.com/google/maxtext/blob/main/MaxText/sequence_packing.py.

from typing import Dict, Optional, List, Union

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE


def pack_dataset(dataset: tf.data.Dataset,
                 key2length: Union[int, Dict[str, int]],
                 keys: Optional[List[str]] = None) -> tf.data.Dataset:
    """Creates a 'packed' version of a dataset on-the-fly.

    Adapted from the mesh-tf implementation.

    This is meant to replace the irritation of having to create a separate
    "packed" version of a dataset to train efficiently on TPU.

    Each example in the output dataset represents several examples in the
    input dataset.

    For each key in the input dataset, two additional keys are created:
      <key>_seg: an int32 tensor identifying the parts representing the
        original example.
      <key>_pos: an int32 tensor identifying the position within the original
        example.

    Example:
      Two input examples get combined to form an output example (here every
      packed feature has length 10).

      The input examples are:
        {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]}
        {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]}

      The output example is:
        {
            "inputs":      [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
            "inputs_seg":  [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
            "inputs_pos":  [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
            "targets":     [4, 1, 5, 6, 1, 0, 0, 0, 0, 0]
            "targets_seg": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
            "targets_pos": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
        }

    0 represents padding in both the inputs and the outputs.

    Sequences in the incoming examples are truncated to the length configured
    for their key, and the sequences in the output examples all have that
    fixed (padded) length. Packed features are cast to int32, and features
    that are not listed in `keys` are dropped from the output.

    Args:
      dataset: a tf.data.Dataset whose elements are dicts of 1-D tensors.
      key2length: an integer (one length shared by every key), or a dict from
        feature-key to integer. A dict argument is copied, never mutated.
      keys: a list of strings (e.g. ["inputs", "targets"]). Defaults to all
        keys of the dataset.

    Returns:
      a tf.data.Dataset of packed examples.

    Raises:
      ValueError: if there is no key to pack, if a key is missing from the
        dataset, or if a feature to be packed is not one-dimensional.
      KeyError: if `key2length` is a dict that has no entry for one of `keys`.
    """
    shapes = tf.nest.map_structure(
        lambda spec: spec.shape, dataset.element_spec)
    if keys is None:
        keys = list(shapes.keys())
    # Fail early with a clear message: without this check an empty key list
    # only surfaces later as "max() arg is an empty sequence" or as an
    # IndexError while tracing the packing function.
    if not keys:
        raise ValueError(
            'At least one key is required for packing, but `keys` is empty.')
    for k in keys:
        if k not in shapes:
            raise ValueError(
                f'Key {k} not found in dataset. '
                f'Available keys are {shapes.keys()}')
        if not shapes[k].is_compatible_with(tf.TensorShape([None])):
            raise ValueError('Tensors to be packed must be one-dimensional.')

    # Make sure that the length dictionary contains all keys as well as the
    # keys suffixed by "_seg" and "_pos".
    if isinstance(key2length, int):
        key2length = {k: key2length for k in keys}
    else:
        key2length = dict(key2length)  # Make new dict, we'll edit in-place.
    for k in keys:
        for suffix in ['_seg', '_pos']:
            key2length[k + suffix] = key2length[k]

    # Trim to length. This also drops every feature that is not in `keys`.
    dataset = dataset.map(
        lambda x: {k: x[k][:key2length[k]] for k in keys},
        num_parallel_calls=AUTOTUNE)

    # Setting batch_size=length ensures that the concatenated sequences (if
    # they have length >=1) are sufficient to fill at least one packed example.
    batch_size = max(key2length.values())
    dataset = dataset.padded_batch(
        batch_size, padded_shapes={k: [-1] for k in keys})
    dataset = _pack_with_tf_ops(dataset, keys, key2length)

    # Set the Tensor shapes correctly since they get lost in the process.
    def my_fn(x):
        return {k: tf.reshape(v, [key2length[k]]) for k, v in x.items()}

    return dataset.map(my_fn, num_parallel_calls=AUTOTUNE)


def _pack_with_tf_ops(dataset: tf.data.Dataset,
                      keys: List[str],
                      key2length: Dict[str, int]) -> tf.data.Dataset:
    """Helper-function for packing a dataset which has already been batched.

    Helper for pack_dataset(). Uses tf.while_loop.

    Every batch is walked example by example. Examples are appended to a
    "partial" packed example for as long as they fit for every key; once the
    next example does not fit, the partial example is zero-padded, emitted,
    and a new one is started.

    Args:
      dataset: a dataset containing padded batches of examples.
      keys: a list of strings.
      key2length: a dict from feature-key to integer. It is read for every
        key in `keys` and for every key suffixed with "_pos".

    Returns:
      a dataset of individual (unbatched) packed examples. Each one holds,
      for every key, the int32 features <key>, <key>_pos and <key>_seg.
    """
    empty_example = {}
    for k in keys:
        empty_example[k] = tf.zeros([0], dtype=tf.int32)
        empty_example[k + '_pos'] = tf.zeros([0], dtype=tf.int32)
    keys_etc = empty_example.keys()

    def write_packed_example(partial, outputs):
        """Zero-pads `partial`, appends it to `outputs`, resets `partial`.

        Args:
          partial: dictionary of Tensor (partially-constructed example)
          outputs: dictionary of TensorArray

        Returns:
          A pair of (fresh empty partial example, updated outputs).
        """
        new_partial = empty_example.copy()
        new_outputs = {}
        for k in keys_etc:
            new_outputs[k] = outputs[k].write(
                outputs[k].size(),
                tf.pad(partial[k],
                       [[0, key2length[k] - tf.size(partial[k])]]))
        return new_partial, new_outputs

    def map_fn(x):
        """Internal function to map over the batched dataset.

        Consumes a batch of input examples and produces a variable number of
        output examples.

        Args:
          x: a padded batch of examples (dict of 2-D tensors, one row per
            example).

        Returns:
          a dict of 2-D int32 tensors with one row per packed example.
        """
        partial = empty_example.copy()
        i = tf.zeros([], dtype=tf.int32)
        dynamic_batch_size = tf.shape(x[keys[0]])[0]
        outputs = {}
        for k in keys:
            outputs[k] = tf.TensorArray(
                tf.int32, size=0, dynamic_size=True,
                element_shape=[key2length[k]])
            outputs[k + '_pos'] = tf.TensorArray(
                tf.int32, size=0, dynamic_size=True,
                element_shape=[key2length[k]])

        def body_fn(i, partial, outputs):
            """Body function for while_loop.

            Args:
              i: integer scalar
              partial: dictionary of Tensor (partially-constructed example)
              outputs: dictionary of TensorArray

            Returns:
              A triple containing the new values of the inputs.
            """
            can_append = True
            one_example = {}
            for k in keys:
                val = tf.cast(x[k][i], tf.int32)
                # Strip the padding: keep the first N entries, where N is the
                # number of non-zero entries in the row.
                val = val[:tf.reduce_sum(
                    tf.cast(tf.not_equal(val, 0), tf.int32))]
                one_example[k] = val
            # The example may only be appended if it fits for every key.
            for k in keys:
                can_append = tf.logical_and(
                    can_append,
                    tf.less_equal(
                        tf.size(partial[k]) + tf.size(one_example[k]),
                        key2length[k]))

            # Does not fit: flush the partial example and start a new one.
            def false_fn():
                return write_packed_example(partial, outputs)

            # Fits: keep accumulating into the current partial example.
            def true_fn():
                return partial, outputs

            partial, outputs = tf.cond(can_append, true_fn, false_fn)
            new_partial = {}
            for k in keys:
                new_seq = one_example[k][:key2length[k]]
                new_seq_len = tf.size(new_seq)
                new_partial[k] = tf.concat([partial[k], new_seq], 0)
                # Positions restart at 0 for every appended example.
                new_partial[k + '_pos'] = tf.concat(
                    [partial[k + '_pos'], tf.range(new_seq_len)], 0)
            partial = new_partial
            return i + 1, partial, outputs

        # For loop over all examples in the batch.
        i, partial, outputs = tf.while_loop(
            cond=lambda *_: True,
            body=body_fn,
            loop_vars=(i, partial, outputs),
            shape_invariants=(
                tf.TensorShape([]),
                {k: tf.TensorShape([None]) for k in keys_etc},
                {k: tf.TensorShape(None) for k in keys_etc},
            ),
            maximum_iterations=dynamic_batch_size)
        # Flush whatever is left in the last partial example.
        _, outputs = write_packed_example(partial, outputs)
        packed = {k: outputs[k].stack() for k in keys_etc}
        for k in keys:
            # A new segment starts wherever the position is 0; the running
            # count of those starts is the segment id. Multiplying by the
            # non-zero mask sets the id to 0 wherever the packed value is 0.
            packed[k + '_seg'] = (
                tf.cumsum(
                    tf.cast(tf.equal(packed[k + '_pos'], 0), tf.int32),
                    axis=1)
                * tf.cast(tf.not_equal(packed[k], 0), tf.int32))
        return packed

    dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE)
    return dataset.unbatch()