# Copyright 2022 Google LLC # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # https://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== r"""Training library for frame interpolation using distributed strategy.""" import functools from typing import Any, Callable, Dict, Text, Tuple from absl import logging import tensorflow as tf def _concat_tensors(tensors: tf.Tensor) -> tf.Tensor: """Concat tensors of the different replicas.""" return tf.concat(tf.nest.flatten(tensors, expand_composites=True), axis=0) @tf.function def _distributed_train_step(strategy: tf.distribute.Strategy, batch: Dict[Text, tf.Tensor], model: tf.keras.Model, loss_functions: Dict[Text, Tuple[Callable[..., tf.Tensor], Callable[..., tf.Tensor]]], optimizer: tf.keras.optimizers.Optimizer, iterations: int) -> Dict[Text, Any]: """Distributed training step. Args: strategy: A Tensorflow distribution strategy. batch: A batch of training examples. model: The Keras model to train. loss_functions: The list of Keras losses used to train the model. optimizer: The Keras optimizer used to train the model. iterations: Iteration number used to sample weights to each loss. Returns: A dictionary of train step outputs. """ def _train_step(batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]: """Train for one step.""" with tf.GradientTape() as tape: predictions = model(batch, training=True) losses = [] for (loss_value, loss_weight) in loss_functions.values(): losses.append(loss_value(batch, predictions) * loss_weight(iterations)) loss = tf.add_n(losses) grads = tape.gradient(loss, model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables)) # post process for visualization all_data = {'loss': loss} all_data.update(batch) all_data.update(predictions) return all_data step_outputs = strategy.run(_train_step, args=(batch,)) loss = strategy.reduce( tf.distribute.ReduceOp.MEAN, step_outputs['loss'], axis=None) x0 = _concat_tensors(step_outputs['x0']) x1 = _concat_tensors(step_outputs['x1']) y = _concat_tensors(step_outputs['y']) pred_y = _concat_tensors(step_outputs['image']) scalar_summaries = {'training_loss': loss} image_summaries = { 'x0': x0, 'x1': x1, 'y': y, 'pred_y': pred_y } extra_images = { 'importance0', 'importance1', 'x0_warped', 'x1_warped', 'fg_image', 'bg_image', 'fg_alpha', 'x1_unfiltered_warped' } for image in extra_images: if image in step_outputs: image_summaries[image] = _concat_tensors(step_outputs[image]) return { 'loss': loss, 'scalar_summaries': scalar_summaries, 'image_summaries': { f'training/{name}': value for name, value in image_summaries.items() } } def _summary_writer(summaries_dict: Dict[Text, Any]) -> None: """Adds scalar and image summaries.""" # Adds scalar summaries. for key, scalars in summaries_dict['scalar_summaries'].items(): tf.summary.scalar(key, scalars) # Adds image summaries. for key, images in summaries_dict['image_summaries'].items(): tf.summary.image(key, tf.clip_by_value(images, 0.0, 1.0)) tf.summary.histogram(key + '_h', images) def train_loop( strategy: tf.distribute.Strategy, train_set: tf.data.Dataset, create_model_fn: Callable[..., tf.keras.Model], create_losses_fn: Callable[..., Dict[str, Tuple[Callable[..., tf.Tensor], Callable[..., tf.Tensor]]]], create_optimizer_fn: Callable[..., tf.keras.optimizers.Optimizer], distributed_train_step_fn: Callable[[ tf.distribute.Strategy, Dict[str, tf.Tensor], tf.keras.Model, Dict[ str, Tuple[Callable[..., tf.Tensor], Callable[..., tf.Tensor]]], tf.keras.optimizers.Optimizer, int ], Dict[str, Any]], eval_loop_fn: Callable[..., None], create_metrics_fn: Callable[..., Dict[str, tf.keras.metrics.Metric]], eval_folder: Dict[str, Any], eval_datasets: Dict[str, tf.data.Dataset], summary_writer_fn: Callable[[Dict[str, Any]], None], train_folder: str, saved_model_folder: str, num_iterations: int, save_summaries_frequency: int = 500, save_checkpoint_frequency: int = 500, checkpoint_max_to_keep: int = 10, checkpoint_save_every_n_hours: float = 2., timing_frequency: int = 100, logging_frequency: int = 10): """A Tensorflow 2 eager mode training loop. Args: strategy: A Tensorflow distributed strategy. train_set: A tf.data.Dataset to loop through for training. create_model_fn: A callable that returns a tf.keras.Model. create_losses_fn: A callable that returns a tf.keras.losses.Loss. create_optimizer_fn: A callable that returns a tf.keras.optimizers.Optimizer. distributed_train_step_fn: A callable that takes a distribution strategy, a Dict[Text, tf.Tensor] holding the batch of training data, a tf.keras.Model, a tf.keras.losses.Loss, a tf.keras.optimizers.Optimizer, iteartion number to sample a weight value to loos functions, and returns a dictionary to be passed to the summary_writer_fn. eval_loop_fn: Eval loop function. create_metrics_fn: create_metric_fn. eval_folder: A path to where the summary event files and checkpoints will be saved. eval_datasets: A dictionary of evalution tf.data.Dataset to loop through for evaluation. summary_writer_fn: A callable that takes the output of distributed_train_step_fn and writes summaries to be visualized in TensorBoard. train_folder: A path to where the summaries event files and checkpoints will be saved. saved_model_folder: A path to where the saved models are stored. num_iterations: An integer, the number of iterations to train for. save_summaries_frequency: The iteration frequency with which summaries are saved. save_checkpoint_frequency: The iteration frequency with which model checkpoints are saved. checkpoint_max_to_keep: The maximum number of checkpoints to keep. checkpoint_save_every_n_hours: The frequency in hours to keep checkpoints. timing_frequency: The iteration frequency with which to log timing. logging_frequency: How often to output with logging.info(). """ logging.info('Creating training tensorboard summaries ...') summary_writer = tf.summary.create_file_writer(train_folder) if eval_datasets is not None: logging.info('Creating eval tensorboard summaries ...') eval_summary_writer = tf.summary.create_file_writer(eval_folder) train_set = strategy.experimental_distribute_dataset(train_set) with strategy.scope(): logging.info('Building model ...') model = create_model_fn() loss_functions = create_losses_fn() optimizer = create_optimizer_fn() if eval_datasets is not None: metrics = create_metrics_fn() logging.info('Creating checkpoint ...') checkpoint = tf.train.Checkpoint( model=model, optimizer=optimizer, step=optimizer.iterations, epoch=tf.Variable(0, dtype=tf.int64, trainable=False), training_finished=tf.Variable(False, dtype=tf.bool, trainable=False)) logging.info('Restoring old model (if exists) ...') checkpoint_manager = tf.train.CheckpointManager( checkpoint, directory=train_folder, max_to_keep=checkpoint_max_to_keep, keep_checkpoint_every_n_hours=checkpoint_save_every_n_hours) with strategy.scope(): if checkpoint_manager.latest_checkpoint: checkpoint.restore(checkpoint_manager.latest_checkpoint) logging.info('Creating Timer ...') timer = tf.estimator.SecondOrStepTimer(every_steps=timing_frequency) timer.update_last_triggered_step(optimizer.iterations.numpy()) logging.info('Training on devices: %s.', [ el.name.split('/physical_device:')[-1] for el in tf.config.get_visible_devices() ]) # Re-assign training_finished=False, in case we restored a checkpoint. checkpoint.training_finished.assign(False) while optimizer.iterations.numpy() < num_iterations: for i_batch, batch in enumerate(train_set): summary_writer.set_as_default() iterations = optimizer.iterations.numpy() if iterations % logging_frequency == 0: # Log epoch, total iterations and batch index. logging.info('epoch %d; iterations %d; i_batch %d', checkpoint.epoch.numpy(), iterations, i_batch) # Break if the number of iterations exceeds the max. if iterations >= num_iterations: break # Compute distributed step outputs. distributed_step_outputs = distributed_train_step_fn( strategy, batch, model, loss_functions, optimizer, iterations) # Save checkpoint, and optionally run the eval loops. if iterations % save_checkpoint_frequency == 0: checkpoint_manager.save(checkpoint_number=iterations) if eval_datasets is not None: eval_loop_fn( strategy=strategy, eval_base_folder=eval_folder, model=model, metrics=metrics, datasets=eval_datasets, summary_writer=eval_summary_writer, checkpoint_step=iterations) # Write summaries. if iterations % save_summaries_frequency == 0: tf.summary.experimental.set_step(step=iterations) summary_writer_fn(distributed_step_outputs) tf.summary.scalar('learning_rate', optimizer.learning_rate(iterations).numpy()) # Log steps/sec. if timer.should_trigger_for_step(iterations): elapsed_time, elapsed_steps = timer.update_last_triggered_step( iterations) if elapsed_time is not None: steps_per_second = elapsed_steps / elapsed_time tf.summary.scalar( 'steps/sec', steps_per_second, step=optimizer.iterations) # Increment epoch. checkpoint.epoch.assign_add(1) # Assign training_finished variable to True after training is finished and # save the last checkpoint. checkpoint.training_finished.assign(True) checkpoint_manager.save(checkpoint_number=optimizer.iterations.numpy()) # Generate a saved model. model.save(saved_model_folder) def train(strategy: tf.distribute.Strategy, train_folder: str, saved_model_folder: str, n_iterations: int, create_model_fn: Callable[..., tf.keras.Model], create_losses_fn: Callable[..., Dict[str, Tuple[Callable[..., tf.Tensor], Callable[..., tf.Tensor]]]], create_metrics_fn: Callable[..., Dict[str, tf.keras.metrics.Metric]], dataset: tf.data.Dataset, learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule, eval_loop_fn: Callable[..., None], eval_folder: str, eval_datasets: Dict[str, tf.data.Dataset]): """Training function that is strategy agnostic. Args: strategy: A Tensorflow distributed strategy. train_folder: A path to where the summaries event files and checkpoints will be saved. saved_model_folder: A path to where the saved models are stored. n_iterations: An integer, the number of iterations to train for. create_model_fn: A callable that returns tf.keras.Model. create_losses_fn: A callable that returns the losses. create_metrics_fn: A function that returns the metrics dictionary. dataset: The tensorflow dataset object. learning_rate: Keras learning rate schedule object. eval_loop_fn: eval loop function. eval_folder: A path to where eval summaries event files and checkpoints will be saved. eval_datasets: The tensorflow evaluation dataset objects. """ train_loop( strategy=strategy, train_set=dataset, create_model_fn=create_model_fn, create_losses_fn=create_losses_fn, create_optimizer_fn=functools.partial( tf.keras.optimizers.Adam, learning_rate=learning_rate), distributed_train_step_fn=_distributed_train_step, eval_loop_fn=eval_loop_fn, create_metrics_fn=create_metrics_fn, eval_folder=eval_folder, eval_datasets=eval_datasets, summary_writer_fn=_summary_writer, train_folder=train_folder, saved_model_folder=saved_model_folder, num_iterations=n_iterations, save_summaries_frequency=3000, save_checkpoint_frequency=3000) def get_strategy(mode) -> tf.distribute.Strategy: """Creates a distributed strategy.""" strategy = None if mode == 'cpu': strategy = tf.distribute.OneDeviceStrategy('/cpu:0') elif mode == 'gpu': strategy = tf.distribute.MirroredStrategy() else: raise ValueError('Unsupported distributed mode.') return strategy