|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Dataset augmentation for frame interpolation.""" |
|
from typing import Callable, Dict, List |
|
|
|
import gin.tf |
|
import numpy as np |
|
import tensorflow as tf |
|
import tensorflow.math as tfm |
|
import tensorflow_addons.image as tfa_image |
|
|
|
_PI = 3.141592653589793 |
|
|
|
|
|
def _rotate_flow_vectors(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
    r"""Rotates the (u, v) vector stored at every pixel by `angle_rad`.

    Only the vector values are changed; the pixel grid itself is not moved.
    Apply this to a flow map whose image content has already been rotated by
    the same angle.

    Coordinate conventions:

      Flow map:              Rotation:
        . . . . u (x)          . y
        .                      .
        .                      .
        . v (-y)               . . . . x

    u points right and v points down (along -y), while the rotation angle is
    measured in a frame where y points up. Written in terms of (u, v):

      u' =  cos(a) * u + sin(a) * v
      v' = -sin(a) * u + cos(a) * v

    Args:
      flow: Flow map, already image-rotated, whose last axis holds the (u, v)
        components. It is split in half along that axis.
      angle_rad: The rotation angle in radians.

    Returns:
      A tensor shaped like `flow` with each (u, v) vector rotated by
      `angle_rad`.
    """
    cos_a = tfm.cos(angle_rad)
    sin_a = tfm.sin(angle_rad)
    u, v = tf.split(flow, 2, axis=-1)
    rotated_u = cos_a * u + sin_a * v
    rotated_v = -sin_a * u + cos_a * v
    return tf.concat([rotated_u, rotated_v], axis=-1)
|
|
|
|
|
def flow_rot90(flow: tf.Tensor, k: int) -> tf.Tensor:
    """Rotates a flow field by `k` quarter turns.

    The pixel grid is rotated with `tf.image.rot90`. Every (u, v) vector is
    then rotated by the same k * 90 degrees, so the motion stays consistent
    with the rotated image.

    Args:
      flow: The flow image shaped (H, W, 2).
      k: Number of 90-degree rotations to apply.

    Returns:
      The rotated flow image. Its spatial dimensions are swapped, i.e.
      (W, H, 2), when `k` is odd.
    """
    quarter_turns = tf.cast(k, dtype=tf.float32)
    angle_rad = quarter_turns * 90. * (_PI / 180.)
    rotated_grid = tf.image.rot90(flow, k)
    return _rotate_flow_vectors(rotated_grid, angle_rad)
|
|
|
|
|
def rotate_flow(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
    """Rotates a flow field by an arbitrary angle in radians.

    The flow map is first resampled onto the rotated grid. This uses bilinear
    interpolation, and out-of-image samples are filled by reflection. Each
    (u, v) vector is then rotated by the same angle.

    Args:
      flow: The flow image shaped (H, W, 2).
      angle_rad: The angle to rotate the flow by, in radians.

    Returns:
      A flow image of the same shape as the input, rotated by `angle_rad`.
    """
    resampled = tfa_image.rotate(
        flow, angles=angle_rad, interpolation='bilinear', fill_mode='reflect')
    return _rotate_flow_vectors(resampled, angle_rad)
|
|
|
|
|
def flow_flip(flow: tf.Tensor) -> tf.Tensor:
    """Flips a flow field left to right.

    Mirroring reverses the horizontal axis. So besides moving the pixels, the
    horizontal component u must change sign; the vertical component v is kept.

    Args:
      flow: The flow image shaped (H, W, 2). `tf.image.flip_left_right` also
        accepts a batch shaped (B, H, W, 2).

    Returns:
      A flow image of the same shape as the input, flipped left to right.
    """
    flipped = tf.image.flip_left_right(flow)
    flow_u, flow_v = tf.split(flipped, 2, axis=-1)
    # tf.split keeps a size-1 channel axis on each half, so the halves must be
    # re-joined along that existing axis. tf.stack would add a new axis and
    # return (H, W, 1, 2) instead of (H, W, 2).
    return tf.concat([-flow_u, flow_v], axis=-1)
|
|
|
|
|
def random_image_rot90(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
    """Rotates every image in the dict by the same random multiple of 90 degrees.

    A single k in {0, 1, 2, 3} is drawn and applied to every entry, so all
    tensors stay spatially aligned. The dict is updated in place and also
    returned.

    Args:
      images: A dictionary of tf.Tensors, each shaped (H, W, num_channels),
        i.e. images stacked along the channel axis.

    Returns:
      The same dictionary, with each tensor rotated counter-clockwise by
      k * 90 degrees.
    """
    num_quarter_turns = tf.random.uniform(
        (), minval=0, maxval=4, dtype=tf.int32)
    for name, tensor in list(images.items()):
        images[name] = tf.image.rot90(tensor, k=num_quarter_turns)
    return images
|
|
|
|
|
def random_flip(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
    """Flips every image in the dict left to right with probability 1/2.

    A single coin flip is shared by all entries, so either every tensor is
    flipped or none is. The dict is updated in place and also returned.

    NOTE(review): tensors are flipped as plain images. If a flow field were
    ever passed here, its u component would not be negated. Confirm that
    callers only pass images.

    Args:
      images: A dictionary of tf.Tensors, each shaped (H, W, num_channels),
        i.e. images stacked along the channel axis.

    Returns:
      The same dictionary, with all tensors flipped left to right or all left
      unchanged.
    """
    coin = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
    do_flip = tf.cast(coin, tf.bool)

    for name in images:
        image = images[name]
        # Bind the current tensor as a default argument so each branch refers
        # to this iteration's value.
        images[name] = tf.cond(
            do_flip,
            lambda image=image: tf.image.flip_left_right(image),
            lambda image=image: image)
    return images
|
|
|
|
|
def random_reverse(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
    """Randomly reverses the order of the two input frames.

    With probability 1/2, the tensors stored under 'x0' and 'x1' are swapped.
    Every other entry is passed through untouched.

    Args:
      images: A dictionary of tf.Tensors, each shaped (H, W, num_channels),
        with each tensor being a stack of images along the last channel axis.
        Must contain the keys 'x0' and 'x1'.

    Returns:
      A new dictionary with the same keys as `images`. The input dictionary
      is not modified.
    """
    coin = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
    do_reverse = tf.cast(coin, tf.bool)

    # Both branches build fresh dicts and never mutate `images`. In graph mode
    # (e.g. inside tf.data.Dataset.map), tf.cond traces BOTH branches. An
    # in-place swap in the true branch would therefore leak into the false
    # branch, reversing the frames on every call instead of half the time.
    def _reversed():
        swapped = dict(images)
        swapped['x0'], swapped['x1'] = images['x1'], images['x0']
        return swapped

    def _unchanged():
        return dict(images)

    # NOTE(review): only 'x0'/'x1' are swapped. Any other time-dependent
    # entries in the dict are not adjusted -- confirm that this is intended.
    return tf.cond(do_reverse, _reversed, _unchanged)
|
|
|
|
|
def random_rotate(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
    """Randomly rotates all images in the dict by a shared angle.

    With probability 1/2, an angle drawn uniformly from [-pi/4, pi/4) radians
    (i.e. -45 to 45 degrees) is applied to every entry. Otherwise the angle is
    zero. Rotation uses bilinear interpolation, and uncovered pixels are
    filled with a constant. The dict is updated in place and also returned.

    Args:
      images: A dictionary of tf.Tensors, each shaped (H, W, num_channels),
        i.e. images stacked along the channel axis.

    Returns:
      The same dictionary, with every tensor rotated by the same random angle
      (or by zero).
    """
    enabled = tf.cast(
        tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32), tf.float32)
    max_angle = 0.25 * np.pi
    angle = tf.random.uniform((),
                              minval=-max_angle,
                              maxval=max_angle,
                              dtype=tf.float32)

    for key in images:
        # Multiplying by 0/1 turns the rotation off/on without branching.
        images[key] = tfa_image.rotate(
            images[key],
            angles=angle * enabled,
            interpolation='bilinear',
            fill_mode='constant')
    return images
|
|
|
|
|
@gin.configurable('data_augmentation')
def data_augmentations(
        names: List[str]) -> Dict[str, Callable[..., tf.Tensor]]:
    """Looks up the augmentation functions requested by name.

    Args:
      names: The list of augmentation function names. Supported names are
        'random_image_rot90', 'random_rotate', 'random_flip' and
        'random_reverse'.

    Returns:
      A dictionary mapping each requested name to its augmentation function,
      in the order the names were given.

    Raises:
      AttributeError: If any name is not a supported augmentation.
    """
    registry = {
        'random_image_rot90': random_image_rot90,
        'random_rotate': random_rotate,
        'random_flip': random_flip,
        'random_reverse': random_reverse,
    }
    selected = {}
    for name in names:
        if name not in registry:
            raise AttributeError('Invalid augmentation function %s' % name)
        selected[name] = registry[name]
    return selected
|
|