# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation library for frame interpolation."""
from typing import Dict, List, Mapping, Text

from absl import logging
import tensorflow as tf


def _collect_tensors(tensors: tf.Tensor) -> List[tf.Tensor]:
  """Collects the tensors of the different replicas into a flat list.

  Args:
    tensors: The structure to flatten; it is handed to `tf.nest.flatten` with
      `expand_composites=True`.

  Returns:
    The flat list produced by `tf.nest.flatten`.
  """
  return tf.nest.flatten(tensors, expand_composites=True)


@tf.function
def _distributed_eval_step(strategy: tf.distribute.Strategy,
                           batch: Dict[Text, tf.Tensor],
                           model: tf.keras.Model,
                           metrics: Dict[Text, tf.keras.metrics.Metric],
                           checkpoint_step: int) -> Dict[Text, tf.Tensor]:
  """Distributed eval step.

  Args:
    strategy: A Tensorflow distribution strategy.
    batch: A batch of evaluation examples.
    model: The Keras model to evaluate.
    metrics: The Keras metrics used for evaluation (a dictionary).
    checkpoint_step: The iteration number at which the checkpoint is restored.

  Returns:
    The value returned by `strategy.run` for the per-replica step function,
    i.e. the model predictions. `eval_loop` reads `.values` on its entries
    when the strategy has more than one replica.
  """

  def _eval_step(batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
    """Runs the model on one batch and updates every metric."""
    predictions = model(batch, training=False)
    # Note: these metrics expect batch and prediction dictionaries rather than
    # tensors like standard TF metrics do. This allows our losses and metrics
    # to use a richer set of inputs than just the predicted final image.
    for metric in metrics.values():
      metric.update_state(batch, predictions, checkpoint_step=checkpoint_step)
    return predictions

  return strategy.run(_eval_step, args=(batch,))


def _summarize_image_tensors(combined: Mapping[str, object],
                             prefix: str,
                             step: int) -> None:
  """Writes an image summary for every image-like tensor in `combined`.

  An entry is summarized only when it is a `tf.Tensor` of rank 4 whose last
  dimension is 1 or 3. Every other entry is skipped silently.

  Args:
    combined: Mapping from name to value; non-tensor values are ignored.
    prefix: Summary tag prefix; each tag is written as `<prefix>/<name>`.
    step: The step passed to `tf.summary.image`.
  """
  for name, image in combined.items():
    if not isinstance(image, tf.Tensor):
      continue
    if len(image.shape) == 4 and image.shape[-1] in (1, 3):
      tf.summary.image(f'{prefix}/{name}', image, step=step)


def eval_loop(strategy: tf.distribute.Strategy,
              eval_base_folder: str,
              model: tf.keras.Model,
              metrics: Dict[str, tf.keras.metrics.Metric],
              datasets: Mapping[str, tf.data.Dataset],
              summary_writer: tf.summary.SummaryWriter,
              checkpoint_step: int):
  """Eval function that is strategy agnostic.

  For every dataset this resets the metrics, runs the model over all batches,
  writes image summaries for the first few batches, and finally writes and
  logs one scalar summary per metric.

  Args:
    strategy: A Tensorflow distributed strategy.
    eval_base_folder: The folder name reported in the log message as the
      destination of the eval summaries. It is only logged here; the actual
      destination is determined by `summary_writer`.
    model: The Keras model to evaluate.
    metrics: A dictionary mapping metric names to the metrics to compute.
    datasets: A dict of tf.data.Dataset to evaluate on, keyed by dataset name.
    summary_writer: Eval summary writer; it is set as the default writer.
    checkpoint_step: The number of iterations completed. Used as the step of
      every summary that is written.
  """
  logging.info('Saving eval summaries to: %s...', eval_base_folder)
  summary_writer.set_as_default()

  max_batches_to_summarize = 10

  for dataset_name, dataset in datasets.items():
    for metric in metrics.values():
      metric.reset_states()

    logging.info('Loading %s testing data ...', dataset_name)
    dataset = strategy.experimental_distribute_dataset(dataset)

    logging.info('Evaluating %s ...', dataset_name)
    batch_idx = 0
    for batch in dataset:
      predictions = _distributed_eval_step(strategy, batch, model, metrics,
                                           checkpoint_step)
      if strategy.num_replicas_in_sync > 1:
        # Merge the per-replica outputs into one global batch. Only the
        # 'image' entry is kept.
        predictions = {
            'image': tf.concat(predictions['image'].values, axis=0)
        }
      # Clip the interpolator output to [0, 1]. Clipping is done only in the
      # eval loop and not in the training loop, so gradients are not killed.
      # The metrics were already updated inside `_distributed_eval_step` with
      # the unclipped predictions, so the clipping only affects the image
      # summaries written below.
      predictions['image'] = tf.clip_by_value(predictions['image'], 0., 1.)

      if batch_idx % 10 == 0:
        logging.info('Evaluating batch %s', batch_idx)
      batch_idx = batch_idx + 1

      # NOTE(review): `batch_idx` is incremented before this check, so image
      # summaries are tagged eval_1 .. eval_{max_batches_to_summarize - 1}.
      # Kept as is, because changing it would rename the existing summary tags.
      if batch_idx < max_batches_to_summarize:
        prefix = f'{dataset_name}/eval_{batch_idx}'
        # Find all tensors that look like images, and summarize them.
        # NOTE(review): with more than one replica the `batch` entries are
        # presumably per-replica values rather than `tf.Tensor`s, in which
        # case only the merged prediction image is summarized -- confirm.
        combined = {**batch, **predictions}
        _summarize_image_tensors(combined, prefix, step=checkpoint_step)
      elif batch_idx == max_batches_to_summarize:
        tf.summary.flush()

    for name, metric in metrics.items():
      tag = f'{dataset_name}/{name}'
      # Evaluate the metric once and reuse it for the summary and the log.
      value = metric.result()
      tf.summary.scalar(tag, value, step=checkpoint_step)
      logging.info('Step %2d, %s %s', checkpoint_step, tag, value.numpy())
      metric.reset_states()
    # A single flush per dataset persists the scalars written above.
    tf.summary.flush()