# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset augmentation for frame interpolation."""
from typing import Callable, Dict, List

import gin.tf
import numpy as np
import tensorflow as tf
import tensorflow.math as tfm
import tensorflow_addons.image as tfa_image

_PI = 3.141592653589793


def _rotate_flow_vectors(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
  r"""Rotate the (u,v) vector of each pixel with angle in radians.

  Flow matrix system of coordinates.
  . . . . u (x)
  .
  .
  . v (-y)

  Rotation system of coordinates.
  . y
  .
  .
  . . . . x

  Args:
    flow: Flow map which has been image-rotated. The last axis holds the
      (u, v) components.
    angle_rad: The rotation angle in radians.

  Returns:
    A flow with the same map but each (u,v) vector rotated by angle_rad.
  """
  u, v = tf.split(flow, 2, axis=-1)
  cos_a = tfm.cos(angle_rad)
  sin_a = tfm.sin(angle_rad)
  # The flow's v axis points down (-y), so the standard rotation is applied
  # to (u, -v) and the resulting y component is negated back:
  #   rot_u = u * cos(angle) - (-v) * sin(angle)
  #   rot_v = -(u * sin(angle) + (-v) * cos(angle))
  rot_u = cos_a * u + sin_a * v
  rot_v = -sin_a * u + cos_a * v
  return tf.concat((rot_u, rot_v), axis=-1)


def flow_rot90(flow: tf.Tensor, k: int) -> tf.Tensor:
  """Rotates a flow by a multiple of 90 degrees.

  Args:
    flow: The flow image shaped (H, W, 2) to rotate by multiples of 90 degrees.
    k: The multiplier factor.

  Returns:
    A flow image of the same shape as the input rotated by multiples of 90
    degrees.
  """
  angle_rad = tf.cast(k, dtype=tf.float32) * 90. * (_PI / 180.)
  # Rotate the pixel grid first, then rotate every flow vector to match.
  flow = tf.image.rot90(flow, k)
  return _rotate_flow_vectors(flow, angle_rad)


def rotate_flow(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
  """Rotates a flow by the provided angle in radians.

  Args:
    flow: The flow image shaped (H, W, 2) to rotate.
    angle_rad: The angle to rotate the flow in radians.

  Returns:
    A flow image of the same shape as the input rotated by the provided angle
    in radians.
  """
  # Rotate the pixel grid first, then rotate every flow vector to match.
  flow = tfa_image.rotate(
      flow, angles=angle_rad, interpolation='bilinear', fill_mode='reflect')
  return _rotate_flow_vectors(flow, angle_rad)


def flow_flip(flow: tf.Tensor) -> tf.Tensor:
  """Flips a flow left to right.

  Args:
    flow: The flow image shaped (H, W, 2) to flip left to right.

  Returns:
    A flow image of the same shape as the input flipped left to right, with
    the horizontal (u) component negated.
  """
  flipped = tf.image.flip_left_right(flow)
  flow_u, flow_v = tf.split(flipped, 2, axis=-1)
  # Each split component keeps a trailing channel axis of size 1, so the
  # components are concatenated (not stacked) to restore the (H, W, 2) shape.
  return tf.concat([-flow_u, flow_v], axis=-1)


def random_image_rot90(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Rotates every image in a dictionary by a random multiple of 90 degrees.

  The same random multiple is applied to every entry. The input dictionary is
  updated in place and also returned.

  Args:
    images: A dictionary of tf.Tensors, each shaped (H, W, num_channels).

  Returns:
    The `images` dictionary, with every entry rotated counter-clockwise by the
    same random multiple (0 to 3) of 90 degrees.
  """
  random_k = tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32)
  for key in images:
    images[key] = tf.image.rot90(images[key], k=random_k)
  return images


def random_flip(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Randomly flips every image in a dictionary left to right.

  A single coin toss decides whether all entries are flipped or all are left
  untouched. The input dictionary is updated in place and also returned.

  Args:
    images: A dictionary of tf.Tensors, each shaped (H, W, num_channels).

  Returns:
    The `images` dictionary, with every entry either flipped left to right or
    unchanged.
  """
  should_flip = tf.cast(
      tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32), tf.bool)

  for key in images:
    # Bind the current tensor as a default argument so each branch closes over
    # its own value rather than the loop variable.
    images[key] = tf.cond(
        should_flip,
        lambda image=images[key]: tf.image.flip_left_right(image),
        lambda image=images[key]: image)
  return images


def random_reverse(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Randomly swaps the 'x0' and 'x1' entries of a dictionary of images.

  The input dictionary is updated in place and also returned. Entries other
  than 'x0' and 'x1' are left untouched.

  Args:
    images: A dictionary of tf.Tensors, each shaped (H, W, num_channels), with
      each tensor being a stack of images along the last channel axis. Must
      contain the keys 'x0' and 'x1'.

  Returns:
    The `images` dictionary, with 'x0' and 'x1' either swapped or unchanged.
  """
  should_reverse = tf.cast(
      tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32), tf.bool)

  # Capture both tensors up front and let tf.cond pick the ordering. The
  # branch functions must not mutate `images`: in graph mode tf.cond traces
  # both branches, so a Python-side swap performed while tracing one branch
  # would leak into the other and the frames would always come out reversed.
  first, second = images['x0'], images['x1']
  images['x0'], images['x1'] = tf.cond(
      should_reverse,
      lambda: (second, first),
      lambda: (first, second))
  return images


def random_rotate(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Randomly rotates every image in a dictionary by -45 to 45 degrees.

  A coin toss decides whether the random angle is applied or replaced by zero.
  The same angle is used for every entry. The input dictionary is updated in
  place and also returned.

  Args:
    images: A dictionary of tf.Tensors, each shaped (H, W, num_channels).

  Returns:
    The `images` dictionary, with every entry rotated by the same angle, which
    is either zero or drawn uniformly from [-45, 45) degrees.
  """
  # 0.0 or 1.0: multiplying the angle by it disables the rotation half of the
  # time while keeping the op in the graph.
  apply_rotation = tf.cast(
      tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32), tf.float32)
  random_angle = tf.random.uniform(
      (), minval=-0.25 * np.pi, maxval=0.25 * np.pi, dtype=tf.float32)

  for key in images:
    images[key] = tfa_image.rotate(
        images[key],
        angles=random_angle * apply_rotation,
        interpolation='bilinear',
        fill_mode='constant')
  return images


# Maps each supported augmentation name to its implementation.
_AUGMENTATION_FNS = {
    'random_image_rot90': random_image_rot90,
    'random_rotate': random_rotate,
    'random_flip': random_flip,
    'random_reverse': random_reverse,
}


@gin.configurable('data_augmentation')
def data_augmentations(
    names: List[str]) -> Dict[str, Callable[..., Dict[str, tf.Tensor]]]:
  """Creates the data augmentation functions.

  Args:
    names: The list of augmentation function names.

  Returns:
    A dictionary of Callables to the augmentation functions, keyed by their
    names.

  Raises:
    AttributeError: If a name does not match a known augmentation function.
  """
  augmentations = dict()
  for name in names:
    if name not in _AUGMENTATION_FNS:
      raise AttributeError('Invalid augmentation function %s' % name)
    augmentations[name] = _AUGMENTATION_FNS[name]
  return augmentations