"""Evaluation library for frame interpolation.""" |
|
from typing import Dict, Mapping, Text |
|
|
|
from absl import logging |
|
import tensorflow as tf |
|
|
|
|
|
def _collect_tensors(tensors: tf.Tensor) -> tf.Tensor: |
|
"""Collect tensors of the different replicas into a list.""" |
|
return tf.nest.flatten(tensors, expand_composites=True) |
|
|
|
|
|
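
# For instance, under a multi-replica strategy `_collect_tensors` flattens a
# per-replica composite into the list of its underlying tensors (a sketch;
# `per_replica_predictions` stands in for an output of `strategy.run`):
#
#   replica_tensors = _collect_tensors(per_replica_predictions)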
@tf.function
def _distributed_eval_step(strategy: tf.distribute.Strategy,
                           batch: Dict[Text, tf.Tensor], model: tf.keras.Model,
                           metrics: Dict[Text, tf.keras.metrics.Metric],
                           checkpoint_step: int) -> Dict[Text, tf.Tensor]:
  """Distributed eval step.

  Args:
    strategy: A Tensorflow distribution strategy.
    batch: A batch of evaluation examples.
    model: The Keras model to evaluate.
    metrics: The Keras metrics used for evaluation (a dictionary).
    checkpoint_step: The iteration number at which the checkpoint is restored.

  Returns:
    A dictionary of predictions; under a multi-replica strategy, each value is
    a per-replica composite holding the outputs of all replicas.
  """

  def _eval_step(batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
    """Runs the model and updates the metrics for one batch."""
    predictions = model(batch, training=False)
    # Update every metric with this replica's inputs and predictions.
    for metric in metrics.values():
      metric.update_state(batch, predictions, checkpoint_step=checkpoint_step)
    return predictions

  return strategy.run(_eval_step, args=(batch,))
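
# Under a multi-replica strategy, `strategy.run` returns per-replica
# composites rather than plain tensors. A minimal sketch of merging the
# 'image' output back into a single batch, mirroring what `eval_loop` below
# does (`step` stands in for the checkpoint step):
#
#   per_replica = _distributed_eval_step(strategy, batch, model, metrics, step)
#   merged = tf.concat(per_replica['image'].values, axis=0)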


def _summarize_image_tensors(combined, prefix, step):
  """Writes image summaries for the image-like tensors in `combined`."""
  for name in combined:
    image = combined[name]
    # Only summarize tensors shaped like image batches, i.e. rank-4 with a
    # grayscale or RGB channel dimension; everything else is skipped.
    if isinstance(image, tf.Tensor):
      if len(image.shape) == 4 and image.shape[-1] in (1, 3):
        tf.summary.image(prefix + '/' + name, image, step=step)
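
# For example, with a default summary writer set (as `eval_loop` does via
# `summary_writer.set_as_default()`), the hypothetical call
#
#   _summarize_image_tensors({'image': tf.zeros([8, 64, 64, 3])},
#                            'vimeo90k/eval_1', step=1000)
#
# writes an image summary tagged 'vimeo90k/eval_1/image'; the dataset name is
# illustrative only.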


def eval_loop(strategy: tf.distribute.Strategy,
              eval_base_folder: str,
              model: tf.keras.Model,
              metrics: Dict[str, tf.keras.metrics.Metric],
              datasets: Mapping[str, tf.data.Dataset],
              summary_writer: tf.summary.SummaryWriter,
              checkpoint_step: int):
  """Eval function that is strategy agnostic.

  Args:
    strategy: A Tensorflow distribution strategy.
    eval_base_folder: A path to where the summary event files will be saved.
    model: The Keras model to evaluate.
    metrics: A dictionary of the Keras metrics to compute, keyed by name.
    datasets: A dict of tf.data.Dataset to evaluate on, keyed by dataset name.
    summary_writer: Eval summary writer.
    checkpoint_step: The iteration number at which the checkpoint is restored.
  """
  logging.info('Saving eval summaries to: %s...', eval_base_folder)
  summary_writer.set_as_default()

  for dataset_name, dataset in datasets.items():
    # Start every dataset with fresh metric state.
    for metric in metrics.values():
      metric.reset_states()

    logging.info('Loading %s testing data ...', dataset_name)
    dataset = strategy.experimental_distribute_dataset(dataset)

    logging.info('Evaluating %s ...', dataset_name)
    batch_idx = 0
    max_batches_to_summarize = 10
    for batch in dataset:
      predictions = _distributed_eval_step(strategy, batch, model, metrics,
                                           checkpoint_step)
      # With more than one replica, gather the per-replica 'image' outputs
      # back into a single batch before summarizing.
      if strategy.num_replicas_in_sync > 1:
        predictions = {
            'image': tf.concat(predictions['image'].values, axis=0)
        }
      predictions['image'] = tf.clip_by_value(predictions['image'], 0., 1.)
      if batch_idx % 10 == 0:
        logging.info('Evaluating batch %s', batch_idx)
      batch_idx += 1
      # Write image summaries for the first few batches only.
      if batch_idx < max_batches_to_summarize:
        prefix = f'{dataset_name}/eval_{batch_idx}'
        combined = {**batch, **predictions}
        _summarize_image_tensors(combined, prefix, step=checkpoint_step)
      elif batch_idx == max_batches_to_summarize:
        tf.summary.flush()

    for name, metric in metrics.items():
      tf.summary.scalar(
          f'{dataset_name}/{name}', metric.result(), step=checkpoint_step)
      tf.summary.flush()
      logging.info('Step %2d, %s: %f', checkpoint_step,
                   f'{dataset_name}/{name}', metric.result().numpy())
      metric.reset_states()
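

# Example usage: a minimal sketch of driving `eval_loop`. The helpers
# `build_model()`, `build_metrics()` and `build_datasets()` are hypothetical
# stand-ins for however the surrounding code constructs these objects; they
# are not part of this module.
#
#   strategy = tf.distribute.MirroredStrategy()
#   with strategy.scope():
#     model = build_model()        # hypothetical model factory
#     metrics = build_metrics()    # hypothetical dict of Keras metrics
#   summary_writer = tf.summary.create_file_writer('/tmp/eval')
#   eval_loop(
#       strategy=strategy,
#       eval_base_folder='/tmp/eval',
#       model=model,
#       metrics=metrics,
#       datasets=build_datasets(),  # hypothetical {name: tf.data.Dataset}
#       summary_writer=summary_writer,
#       checkpoint_step=0)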