# Copyright 2022 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation library for frame interpolation."""
from typing import Dict, Mapping, Text
from absl import logging
import tensorflow as tf


def _collect_tensors(tensors: tf.Tensor) -> tf.Tensor:
  """Collect tensors of the different replicas into a list."""
  return tf.nest.flatten(tensors, expand_composites=True)


@tf.function
def _distributed_eval_step(strategy: tf.distribute.Strategy,
                           batch: Dict[Text, tf.Tensor], model: tf.keras.Model,
                           metrics: Dict[Text, tf.keras.metrics.Metric],
                           checkpoint_step: int) -> Dict[Text, tf.Tensor]:
  """Distributed eval step.

  Args:
    strategy: A Tensorflow distribution strategy.
    batch: A batch of evaluation examples.
    model: The Keras model to evaluate.
    metrics: The Keras metrics used for evaluation (a dictionary).
    checkpoint_step: The iteration number at which the checkpoint is restored.

  Returns:
    The per-replica prediction dictionaries, as returned by `strategy.run`.
  """

  def _eval_step(
      batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
    """Eval for one step."""
    predictions = model(batch, training=False)
    # Note: these metrics expect batch and prediction dictionaries rather than
    # tensors like standard TF metrics do. This allows our losses and metrics
    # to use a richer set of inputs than just the predicted final image.
    for metric in metrics.values():
      metric.update_state(batch, predictions, checkpoint_step=checkpoint_step)
    return predictions

  return strategy.run(_eval_step, args=(batch,))
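

# A minimal sketch (not used by this library) of a metric that follows the
# dictionary-based convention `_eval_step` relies on: `update_state` receives
# the full batch and prediction dictionaries rather than bare tensors. The
# 'y' (ground-truth frame) and 'image' keys are assumptions for illustration.
class ExamplePSNRMetric(tf.keras.metrics.Metric):
  """Accumulates the mean PSNR between ground-truth and predicted frames."""

  def __init__(self, name: str = 'psnr', **kwargs):
    super().__init__(name=name, **kwargs)
    self._mean = tf.keras.metrics.Mean()

  def update_state(self, batch, predictions, checkpoint_step=None):
    # Compare the ground-truth frame to the interpolated frame.
    psnr = tf.image.psnr(batch['y'], predictions['image'], max_val=1.0)
    self._mean.update_state(psnr)

  def result(self):
    return self._mean.result()

  def reset_states(self):
    self._mean.reset_states()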


def _summarize_image_tensors(combined, prefix, step):
  """Writes image summaries for every 4-D, 1- or 3-channel tensor found."""
  for name in combined:
    image = combined[name]
    if isinstance(image, tf.Tensor):
      if len(image.shape) == 4 and (image.shape[-1] == 1 or
                                    image.shape[-1] == 3):
        tf.summary.image(prefix + '/' + name, image, step=step)


def eval_loop(strategy: tf.distribute.Strategy,
              eval_base_folder: str,
              model: tf.keras.Model,
              metrics: Dict[str, tf.keras.metrics.Metric],
              datasets: Mapping[str, tf.data.Dataset],
              summary_writer: tf.summary.SummaryWriter,
              checkpoint_step: int):
  """Eval function that is strategy agnostic.

  Args:
    strategy: A Tensorflow distribution strategy.
    eval_base_folder: A path to where the summary event files will be saved.
    model: The Keras model to evaluate.
    metrics: A dictionary of Keras metrics to compute on each dataset.
    datasets: A dict of tf.data.Dataset to evaluate on.
    summary_writer: Eval summary writer.
    checkpoint_step: The number of training iterations completed at the
      restored checkpoint.
  """
  logging.info('Saving eval summaries to: %s...', eval_base_folder)
  summary_writer.set_as_default()

  for dataset_name, dataset in datasets.items():
    for metric in metrics.values():
      metric.reset_states()

    logging.info('Loading %s testing data ...', dataset_name)
    dataset = strategy.experimental_distribute_dataset(dataset)

    logging.info('Evaluating %s ...', dataset_name)
    batch_idx = 0
    max_batches_to_summarize = 10
    for batch in dataset:
      predictions = _distributed_eval_step(strategy, batch, model, metrics,
                                           checkpoint_step)
      # Clip the interpolator output to [0, 1]. Clipping is done only in the
      # eval loop to get better metrics; it is skipped in the training loop
      # so gradients are not killed.
      if strategy.num_replicas_in_sync > 1:
        # Gather the per-replica predictions into a single tensor.
        predictions = {
            'image': tf.concat(predictions['image'].values, axis=0)
        }
      predictions['image'] = tf.clip_by_value(predictions['image'], 0., 1.)
      if batch_idx % 10 == 0:
        logging.info('Evaluating batch %s', batch_idx)
      batch_idx = batch_idx + 1
      if batch_idx < max_batches_to_summarize:
        prefix = f'{dataset_name}/eval_{batch_idx}'
        # Find all tensors in the batch and predictions that look like
        # images, and summarize them:
        combined = {**batch, **predictions}
        _summarize_image_tensors(combined, prefix, step=checkpoint_step)
      elif batch_idx == max_batches_to_summarize:
        tf.summary.flush()

    for name, metric in metrics.items():
      tf.summary.scalar(
          f'{dataset_name}/{name}', metric.result(), step=checkpoint_step)
      tf.summary.flush()
      logging.info('Step {:2}, {} {}'.format(checkpoint_step,
                                             f'{dataset_name}/{name}',
                                             metric.result().numpy()))
      metric.reset_states()
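

# A minimal usage sketch, assuming the model, metrics, and datasets have been
# built elsewhere; `my_model` and `my_eval_dataset` are illustrative names,
# not part of this library:
#
#   strategy = tf.distribute.MirroredStrategy()
#   summary_writer = tf.summary.create_file_writer('/tmp/eval')
#   eval_loop(
#       strategy=strategy,
#       eval_base_folder='/tmp/eval',
#       model=my_model,
#       metrics={'psnr': ExamplePSNRMetric()},
#       datasets={'middlebury': my_eval_dataset},
#       summary_writer=summary_writer,
#       checkpoint_step=100000)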