# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset creation for frame interpolation."""
from typing import Callable, Dict, List, Optional

from absl import logging
import gin.tf
import tensorflow as tf


def _create_feature_map() -> Dict[str, tf.io.FixedLenFeature]:
  """Creates the feature map for extracting the frame triplet.

  Returns:
    A feature map with per-frame entries (encoded image bytes, format,
    height, width) for frames 0, 1 and 2, plus the original example 'path'.
  """
  feature_map = {
      'path': tf.io.FixedLenFeature((), tf.string, default_value=''),
  }
  # The three frames of the triplet share an identical per-frame schema.
  for i in range(3):
    prefix = f'frame_{i}'
    feature_map.update({
        f'{prefix}/encoded':
            tf.io.FixedLenFeature((), tf.string, default_value=''),
        f'{prefix}/format':
            tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
        f'{prefix}/height':
            tf.io.FixedLenFeature((), tf.int64, default_value=0),
        f'{prefix}/width':
            tf.io.FixedLenFeature((), tf.int64, default_value=0),
    })
  return feature_map


def _parse_example(sample):
  """Parses a serialized sample.

  Args:
    sample: A serialized tf.Example to be parsed.

  Returns:
    A dictionary containing the following:
      'x0': The first input frame, decoded as float32.
      'x1': The last input frame, decoded as float32.
      'y': The golden (middle) frame, decoded as float32.
      'time': Fractional position of 'y' between 'x0' and 'x1'; always 0.5.
      'path': The original mid-frame filepath, for identifying examples.
  """
  feature_map = _create_feature_map()
  features = tf.io.parse_single_example(sample, feature_map)
  output_dict = {
      'x0': tf.io.decode_image(features['frame_0/encoded'], dtype=tf.float32),
      'x1': tf.io.decode_image(features['frame_2/encoded'], dtype=tf.float32),
      'y': tf.io.decode_image(features['frame_1/encoded'], dtype=tf.float32),
      # The fractional time value of frame_1 is not included in our tfrecords,
      # but is always at 0.5. The model will expect this to be specified, so
      # we insert it here.
      'time': 0.5,
      # Store the original mid frame filepath for identifying examples.
      'path': features['path'],
  }
  return output_dict


def _random_crop_images(crop_size: int, images: tf.Tensor,
                        total_channel_size: int) -> tf.Tensor:
  """Crops the tensor with random offset to the given size.

  Args:
    crop_size: Target height and width of the crop; no-op if <= 0.
    images: An HxWxC image tensor (frames stacked along the channel axis).
    total_channel_size: Total channel count C of `images`.

  Returns:
    The randomly cropped tensor, or `images` unchanged if crop_size <= 0.
  """
  if crop_size > 0:
    crop_shape = tf.constant([crop_size, crop_size, total_channel_size])
    images = tf.image.random_crop(images, crop_shape)
  return images


def crop_example(example: tf.Tensor, crop_size: int,
                 crop_keys: Optional[List[str]] = None):
  """Random crops selected images in the example to given size and keys.

  Args:
    example: Input tensor representing images to be cropped.
    crop_size: The size to crop images to. This value is used for both
      height and width.
    crop_keys: The images in the input example to crop.

  Returns:
    Example with cropping applied to selected images.
  """
  if crop_keys is None:
    crop_keys = ['x0', 'x1', 'y']
  channels = [3, 3, 3]
  # Stack images along channel axis, and perform a random crop once, so that
  # all selected images share the same crop offset.
  image_to_crop = [example[key] for key in crop_keys]
  stacked_images = tf.concat(image_to_crop, axis=-1)
  cropped_images = _random_crop_images(crop_size, stacked_images, sum(channels))
  cropped_images = tf.split(
      cropped_images, num_or_size_splits=channels, axis=-1)
  for key, cropped_image in zip(crop_keys, cropped_images):
    example[key] = cropped_image
  return example


def apply_data_augmentation(
    augmentation_fns: Dict[str, Callable[..., tf.Tensor]],
    example: tf.Tensor,
    augmentation_keys: Optional[List[str]] = None) -> tf.Tensor:
  """Applies random augmentation in succession to selected image keys.

  Args:
    augmentation_fns: A Dict of Callables to data augmentation functions.
    example: Input tensor representing images to be augmented.
    augmentation_keys: The images in the input example to augment.

  Returns:
    Example with augmentation applied to selected images.
  """
  if augmentation_keys is None:
    augmentation_keys = ['x0', 'x1', 'y']

  # Apply each augmentation in sequence to the same image dict, so that all
  # selected images receive identical random transforms.
  augmented_images = {key: example[key] for key in augmentation_keys}
  for augmentation_function in augmentation_fns.values():
    augmented_images = augmentation_function(augmented_images)

  for key in augmentation_keys:
    example[key] = augmented_images[key]
  return example


def _create_from_tfrecord(batch_size, file, augmentation_fns,
                          crop_size) -> tf.data.Dataset:
  """Creates a dataset from TFRecord.

  Args:
    batch_size: The number of examples per batch; incomplete final batches
      are dropped.
    file: Path to a single TFRecord file.
    augmentation_fns: Optional Dict of augmentation Callables, or None.
    crop_size: If > 0, random-crop images to crop_size x crop_size.

  Returns:
    A batched tf.data.Dataset of parsed examples.
  """
  dataset = tf.data.TFRecordDataset(file)
  dataset = dataset.map(
      _parse_example, num_parallel_calls=tf.data.AUTOTUNE)
  # Perform data augmentation before cropping and batching.
  if augmentation_fns is not None:
    dataset = dataset.map(
        lambda x: apply_data_augmentation(augmentation_fns, x),
        num_parallel_calls=tf.data.AUTOTUNE)

  if crop_size > 0:
    dataset = dataset.map(
        lambda x: crop_example(x, crop_size=crop_size),
        num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  return dataset


def _generate_sharded_filenames(filename: str) -> List[str]:
  """Generates filenames of the each file in the sharded filepath.

  Based on github.com/google/revisiting-self-supervised/blob/master/datasets.py.

  Args:
    filename: The sharded filepath.

  Returns:
    A list of filepaths for each file in the shard.
  """
  base, count = filename.split('@')
  count = int(count)
  return ['{}-{:05d}-of-{:05d}'.format(base, i, count) for i in range(count)]


def _create_from_sharded_tfrecord(batch_size,
                                  train_mode,
                                  file,
                                  augmentation_fns,
                                  crop_size,
                                  max_examples=-1) -> tf.data.Dataset:
  """Creates a dataset from a sharded tfrecord.

  Args:
    batch_size: The number of examples per batch.
    train_mode: If True, shards may be interleaved non-deterministically for
      throughput; if False, example order is deterministic.
    file: A sharded tfrecord path in `basename@N` format.
    augmentation_fns: Optional Dict of augmentation Callables, or None.
    crop_size: If > 0, random-crop images to crop_size x crop_size.
    max_examples: If > 0, truncate the dataset to this many batches.

  Returns:
    A tf.data.Dataset reading all shards of `file`.
  """
  dataset = tf.data.Dataset.from_tensor_slices(
      _generate_sharded_filenames(file))
  # pylint: disable=g-long-lambda
  dataset = dataset.interleave(
      lambda x: _create_from_tfrecord(
          batch_size,
          file=x,
          augmentation_fns=augmentation_fns,
          crop_size=crop_size),
      num_parallel_calls=tf.data.AUTOTUNE,
      deterministic=not train_mode)
  # pylint: enable=g-long-lambda
  dataset = dataset.prefetch(buffer_size=2)
  if max_examples > 0:
    return dataset.take(max_examples)
  return dataset


@gin.configurable('training_dataset')
def create_training_dataset(
    batch_size: int,
    file: Optional[str] = None,
    files: Optional[List[str]] = None,
    crop_size: int = -1,
    crop_sizes: Optional[List[int]] = None,
    augmentation_fns: Optional[Dict[str, Callable[..., tf.Tensor]]] = None
) -> tf.data.Dataset:
  """Creates the training dataset.

  The given tfrecord should contain data in a format produced by
  frame_interpolation/datasets/create_*_tfrecord.py

  Args:
    batch_size: The number of images to batch per example.
    file: (deprecated) A path to a sharded tfrecord in @N format. Deprecated.
      Use 'files' instead.
    files: A list of paths to sharded tfrecords in @N format.
    crop_size: (deprecated) If > 0, images are cropped to
      crop_size x crop_size using tensorflow's random cropping. Deprecated:
      use 'files' and 'crop_sizes' instead.
    crop_sizes: List of crop sizes. If > 0, images are cropped to
      crop_size x crop_size using tensorflow's random cropping.
    augmentation_fns: A Dict of Callables to data augmentation functions.

  Returns:
    A tensorflow dataset for accessing examples that contain the input images
    'x0', 'x1', ground truth 'y' and time of the ground truth 'time'=[0,1] in
    a dictionary of tensors.
  """
  if file:
    logging.warning('gin-configurable training_dataset.file is deprecated. '
                    'Use training_dataset.files instead.')
    return _create_from_sharded_tfrecord(batch_size, True, file,
                                         augmentation_fns, crop_size)
  else:
    # Fail with a clear message rather than `len(None)` raising a TypeError
    # when neither `file` nor `files` was configured.
    if not files:
      raise ValueError(
          'Either training_dataset.file or training_dataset.files must be '
          'provided.')
    if not crop_sizes or len(crop_sizes) != len(files):
      raise ValueError('Please pass crop_sizes[] with training_dataset.files.')
    if crop_size > 0:
      raise ValueError(
          'crop_size should not be used with files[], use crop_sizes[] instead.'
      )
    tables = []
    for file, crop_size in zip(files, crop_sizes):
      tables.append(
          _create_from_sharded_tfrecord(batch_size, True, file,
                                        augmentation_fns, crop_size))
    return tf.data.experimental.sample_from_datasets(tables)


@gin.configurable('eval_datasets')
def create_eval_datasets(batch_size: int,
                         files: List[str],
                         names: List[str],
                         crop_size: int = -1,
                         max_examples: int = -1) -> Dict[str, tf.data.Dataset]:
  """Creates the evaluation datasets.

  As opposed to create_training_dataset this function makes sure that the
  examples for each dataset are always read in a deterministic (same) order.

  Each given tfrecord should contain data in a format produced by
  frame_interpolation/datasets/create_*_tfrecord.py

  The (batch_size, crop_size, max_examples) are specified for all eval
  datasets.

  Args:
    batch_size: The number of images to batch per example.
    files: List of paths to a sharded tfrecord in @N format.
    names: List of names of eval datasets.
    crop_size: If > 0, images are cropped to crop_size x crop_size using
      tensorflow's random cropping.
    max_examples: If > 0, truncate the dataset to 'max_examples' in length.
      This can be useful for speeding up evaluation loop in case the tfrecord
      for the evaluation set is very large.

  Returns:
    A dict of name to tensorflow dataset for accessing examples that contain
    the input images 'x0', 'x1', ground truth 'y' and time of the ground truth
    'time'=[0,1] in a dictionary of tensors.
  """
  return {
      name: _create_from_sharded_tfrecord(batch_size, False, file, None,
                                          crop_size, max_examples)
      for name, file in zip(names, files)
  }