"""Dataset creation for frame interpolation.""" |
|
from typing import Callable, Dict, List, Optional |
|
|
|
from absl import logging |
|
import gin.tf |
|
import tensorflow as tf |
|
|
|
|
|
def _create_feature_map() -> Dict[str, tf.io.FixedLenFeature]:
  """Creates the feature map for extracting the frame triplet."""
  feature_map = {
      'frame_0/encoded':
          tf.io.FixedLenFeature((), tf.string, default_value=''),
      'frame_0/format':
          tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
      'frame_0/height':
          tf.io.FixedLenFeature((), tf.int64, default_value=0),
      'frame_0/width':
          tf.io.FixedLenFeature((), tf.int64, default_value=0),
      'frame_1/encoded':
          tf.io.FixedLenFeature((), tf.string, default_value=''),
      'frame_1/format':
          tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
      'frame_1/height':
          tf.io.FixedLenFeature((), tf.int64, default_value=0),
      'frame_1/width':
          tf.io.FixedLenFeature((), tf.int64, default_value=0),
      'frame_2/encoded':
          tf.io.FixedLenFeature((), tf.string, default_value=''),
      'frame_2/format':
          tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
      'frame_2/height':
          tf.io.FixedLenFeature((), tf.int64, default_value=0),
      'frame_2/width':
          tf.io.FixedLenFeature((), tf.int64, default_value=0),
      'path':
          tf.io.FixedLenFeature((), tf.string, default_value=''),
  }
  return feature_map


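# For reference, a serialized example matching _create_feature_map's map could
# be assembled as sketched below (hypothetical writer-side code; the canonical
# writers are the create_*_tfrecord.py scripts):
#
#   def _bytes_feature(value):
#     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
#
#   def _int64_feature(value):
#     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
#
#   example = tf.train.Example(features=tf.train.Features(feature={
#       'frame_0/encoded': _bytes_feature(encoded_jpeg_bytes),
#       'frame_0/format': _bytes_feature(b'jpg'),
#       'frame_0/height': _int64_feature(height),
#       'frame_0/width': _int64_feature(width),
#       # ... likewise for frame_1 and frame_2, plus 'path'.
#   }))

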
def _parse_example(sample):
  """Parses a serialized sample.

  Args:
    sample: A serialized tf.Example to be parsed.

  Returns:
    A dictionary containing:
      'x0': The first input frame, decoded into a float32 tensor.
      'x1': The last input frame, decoded into a float32 tensor.
      'y': The ground-truth middle frame, decoded into a float32 tensor.
      'time': The fractional time of 'y' between 'x0' and 'x1', fixed at 0.5.
      'path': The path of the source example.
  """
  feature_map = _create_feature_map()
  features = tf.io.parse_single_example(sample, feature_map)
  output_dict = {
      'x0': tf.io.decode_image(features['frame_0/encoded'], dtype=tf.float32),
      'x1': tf.io.decode_image(features['frame_2/encoded'], dtype=tf.float32),
      'y': tf.io.decode_image(features['frame_1/encoded'], dtype=tf.float32),
      # The middle frame always sits at the fractional time 0.5 between the
      # 'x0' and 'x1' input frames.
      'time': 0.5,
      'path': features['path'],
  }
  return output_dict


def _random_crop_images(crop_size: int, images: tf.Tensor,
                        total_channel_size: int) -> tf.Tensor:
  """Crops the tensor with random offset to the given size."""
  if crop_size > 0:
    crop_shape = tf.constant([crop_size, crop_size, total_channel_size])
    images = tf.image.random_crop(images, crop_shape)
  return images


def crop_example(example: Dict[str, tf.Tensor], crop_size: int,
                 crop_keys: Optional[List[str]] = None) -> Dict[str, tf.Tensor]:
  """Random crops selected images in the example to the given size and keys.

  Args:
    example: Example dict with the images to be cropped.
    crop_size: The size to crop images to. This value is used for both height
      and width.
    crop_keys: The keys of the images in the input example to crop.

  Returns:
    Example with cropping applied to the selected images.
  """
  if crop_keys is None:
    crop_keys = ['x0', 'x1', 'y']
  # All images in the triplet are RGB, i.e. 3-channel.
  channels = [3] * len(crop_keys)

  # Stack the images along the channel axis so that a single random crop
  # takes the same window from every image.
  image_to_crop = [example[key] for key in crop_keys]
  stacked_images = tf.concat(image_to_crop, axis=-1)
  cropped_images = _random_crop_images(crop_size, stacked_images,
                                       sum(channels))
  cropped_images = tf.split(
      cropped_images, num_or_size_splits=channels, axis=-1)
  for key, cropped_image in zip(crop_keys, cropped_images):
    example[key] = cropped_image
  return example


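# Illustrative usage of crop_example (a sketch with made-up shapes): because
# the images are concatenated along the channel axis before a single
# tf.image.random_crop call, all selected images receive the same random
# crop window.
#
#   example = {
#       'x0': tf.random.uniform([256, 256, 3]),
#       'x1': tf.random.uniform([256, 256, 3]),
#       'y': tf.random.uniform([256, 256, 3]),
#   }
#   cropped = crop_example(example, crop_size=64)
#   # cropped['x0'].shape == (64, 64, 3); 'x1' and 'y' share the same window.

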
def apply_data_augmentation(
    augmentation_fns: Dict[str, Callable[..., tf.Tensor]],
    example: Dict[str, tf.Tensor],
    augmentation_keys: Optional[List[str]] = None) -> Dict[str, tf.Tensor]:
  """Applies random augmentations in succession to selected image keys.

  Args:
    augmentation_fns: A dict of callable data augmentation functions.
    example: Example dict with the images to be augmented.
    augmentation_keys: The keys of the images in the input example to augment.

  Returns:
    Example with augmentation applied to the selected images.
  """
  if augmentation_keys is None:
    augmentation_keys = ['x0', 'x1', 'y']

  # Apply each augmentation in sequence on the dict of selected images.
  augmented_images = {key: example[key] for key in augmentation_keys}
  for augmentation_function in augmentation_fns.values():
    augmented_images = augmentation_function(augmented_images)

  for key in augmentation_keys:
    example[key] = augmented_images[key]
  return example


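# A sketch of a compatible augmentation function (hypothetical, not part of
# this module): each callable receives the dict of selected images and must
# return a dict with the same keys, transformed consistently across frames.
#
#   def _flip_triplet(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
#     # A deterministic flip keeps 'x0', 'x1' and 'y' spatially aligned.
#     return {key: tf.image.flip_left_right(image)
#             for key, image in images.items()}
#
#   augmentation_fns = {'flip_left_right': _flip_triplet}

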
def _create_from_tfrecord(batch_size, file, augmentation_fns,
                          crop_size) -> tf.data.Dataset:
  """Creates a dataset from a TFRecord file."""
  dataset = tf.data.TFRecordDataset(file)
  dataset = dataset.map(
      _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Perform data augmentation before cropping and batching.
  if augmentation_fns is not None:
    dataset = dataset.map(
        lambda x: apply_data_augmentation(augmentation_fns, x),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

  if crop_size > 0:
    dataset = dataset.map(
        lambda x: crop_example(x, crop_size=crop_size),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  return dataset


def _generate_sharded_filenames(filename: str) -> List[str]:
  """Generates the filenames of each file in the sharded filepath.

  Based on github.com/google/revisiting-self-supervised/blob/master/datasets.py.

  Args:
    filename: The sharded filepath.

  Returns:
    A list of filepaths for each file in the shard.
  """
  base, count = filename.split('@')
  count = int(count)
  return ['{}-{:05d}-of-{:05d}'.format(base, i, count) for i in range(count)]


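# For example, _generate_sharded_filenames('train@3') returns:
#   ['train-00000-of-00003', 'train-00001-of-00003', 'train-00002-of-00003']

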
def _create_from_sharded_tfrecord(batch_size,
                                  train_mode,
                                  file,
                                  augmentation_fns,
                                  crop_size,
                                  max_examples=-1) -> tf.data.Dataset:
  """Creates a dataset from a sharded tfrecord."""
  dataset = tf.data.Dataset.from_tensor_slices(
      _generate_sharded_filenames(file))

  dataset = dataset.interleave(
      lambda x: _create_from_tfrecord(
          batch_size,
          file=x,
          augmentation_fns=augmentation_fns,
          crop_size=crop_size),
      num_parallel_calls=tf.data.AUTOTUNE,
      # Reading shards in non-deterministic order speeds up training; for
      # evaluation, a deterministic order keeps results reproducible.
      deterministic=not train_mode)

  dataset = dataset.prefetch(buffer_size=2)
  if max_examples > 0:
    return dataset.take(max_examples)
  return dataset


@gin.configurable('training_dataset')
def create_training_dataset(
    batch_size: int,
    file: Optional[str] = None,
    files: Optional[List[str]] = None,
    crop_size: int = -1,
    crop_sizes: Optional[List[int]] = None,
    augmentation_fns: Optional[Dict[str, Callable[..., tf.Tensor]]] = None
) -> tf.data.Dataset:
  """Creates the training dataset.

  The given tfrecord should contain data in a format produced by
  frame_interpolation/datasets/create_*_tfrecord.py.

  Args:
    batch_size: The number of images to batch per example.
    file: (deprecated) A path to a sharded tfrecord in <tfrecord>@N format.
      Use 'files' instead.
    files: A list of paths to sharded tfrecords in <tfrecord>@N format.
    crop_size: (deprecated) If > 0, images are cropped to
      crop_size x crop_size using tensorflow's random cropping. Use 'files'
      and 'crop_sizes' instead.
    crop_sizes: List of crop sizes, one per entry in 'files'. If > 0, images
      are cropped to crop_size x crop_size using tensorflow's random cropping.
    augmentation_fns: A dict of callable data augmentation functions.

  Returns:
    A tensorflow dataset for accessing examples that contain the input images
    'x0', 'x1', ground truth 'y' and time of the ground truth 'time'=[0,1] in
    a dictionary of tensors.
  """
  if file:
    logging.warning('gin-configurable training_dataset.file is deprecated. '
                    'Use training_dataset.files instead.')
    return _create_from_sharded_tfrecord(batch_size, True, file,
                                         augmentation_fns, crop_size)
  else:
    if not crop_sizes or len(crop_sizes) != len(files):
      raise ValueError('Please pass crop_sizes[] with training_dataset.files.')
    if crop_size > 0:
      raise ValueError(
          'crop_size should not be used with files[], use crop_sizes[] '
          'instead.')
    tables = []
    for file, crop_size in zip(files, crop_sizes):
      tables.append(
          _create_from_sharded_tfrecord(batch_size, True, file,
                                        augmentation_fns, crop_size))
    return tf.data.experimental.sample_from_datasets(tables)


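# A minimal gin binding for this configurable (illustrative values and paths
# only, assuming a training tfrecord sharded into 64 files):
#
#   training_dataset.batch_size = 8
#   training_dataset.files = ['/path/to/train@64']
#   training_dataset.crop_sizes = [256]

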
@gin.configurable('eval_datasets')
def create_eval_datasets(batch_size: int,
                         files: List[str],
                         names: List[str],
                         crop_size: int = -1,
                         max_examples: int = -1) -> Dict[str, tf.data.Dataset]:
  """Creates the evaluation datasets.

  As opposed to create_training_dataset, this function makes sure that the
  examples for each dataset are always read in a deterministic (same) order.

  Each given tfrecord should contain data in a format produced by
  frame_interpolation/datasets/create_*_tfrecord.py.

  The (batch_size, crop_size, max_examples) settings apply to all of the eval
  datasets.

  Args:
    batch_size: The number of images to batch per example.
    files: List of paths to sharded tfrecords in <tfrecord>@N format.
    names: List of names of the eval datasets, used as keys of the returned
      dict.
    crop_size: If > 0, images are cropped to crop_size x crop_size using
      tensorflow's random cropping.
    max_examples: If > 0, truncate each dataset to 'max_examples' in length.
      This can be useful for speeding up the evaluation loop in case the
      tfrecord for the evaluation set is very large.

  Returns:
    A dict of name to tensorflow dataset for accessing examples that contain
    the input images 'x0', 'x1', ground truth 'y' and time of the ground truth
    'time'=[0,1] in a dictionary of tensors.
  """
  return {
      name: _create_from_sharded_tfrecord(batch_size, False, file, None,
                                          crop_size, max_examples)
      for name, file in zip(names, files)
  }
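# A minimal gin binding for this configurable (illustrative values and paths
# only):
#
#   eval_datasets.batch_size = 1
#   eval_datasets.files = ['/path/to/eval@4']
#   eval_datasets.names = ['eval']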