1
File size: 5,281 Bytes
708d62c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation library for frame interpolation."""
from typing import Dict, Mapping, Text

from absl import logging
import tensorflow as tf


def _collect_tensors(tensors) -> list:
  """Flatten a (possibly nested / composite) structure of tensors into a list.

  Presumably used to gather the per-replica components of a distributed value
  (e.g. the output of `strategy.run`) into one list — verify at call sites.

  Args:
    tensors: A tensor or an arbitrarily nested structure of tensors. Composite
      values are expanded into their component tensors because
      `expand_composites=True` is passed to `tf.nest.flatten`.

  Returns:
    A flat Python list of the leaf tensors, in `tf.nest` flattening order.
  """
  return tf.nest.flatten(tensors, expand_composites=True)


@tf.function
def _distributed_eval_step(strategy: tf.distribute.Strategy,
                           batch: Dict[Text, tf.Tensor], model: tf.keras.Model,
                           metrics: Dict[Text, tf.keras.metrics.Metric],
                           checkpoint_step: int) -> Dict[Text, tf.Tensor]:
  """Runs one evaluation step on every replica of `strategy`.

  Each replica runs the model in inference mode (`training=False`) on its
  shard of `batch` and feeds the result into every metric in `metrics`.

  Args:
    strategy: The TensorFlow distribution strategy to run under.
    batch: A (distributed) batch of evaluation examples, keyed by name.
    model: The Keras model being evaluated.
    metrics: Dictionary of Keras metrics to update with this batch.
    checkpoint_step: Iteration number of the restored checkpoint; forwarded to
      each metric's `update_state`.

  Returns:
    The value returned by `strategy.run`: the model predictions of each
    replica (a per-replica structure when more than one replica is in sync).
  """

  def _replica_step(
      replica_batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
    """Predicts on one replica's shard and updates all metrics with it."""
    outputs = model(replica_batch, training=False)
    # These metrics take the whole batch and prediction dictionaries instead of
    # plain tensors, so losses/metrics can use inputs beyond the final image.
    for eval_metric in metrics.values():
      eval_metric.update_state(
          replica_batch, outputs, checkpoint_step=checkpoint_step)
    return outputs

  per_replica_outputs = strategy.run(_replica_step, args=(batch,))
  return per_replica_outputs


def _summarize_image_tensors(combined, prefix, step):
  """Writes an image summary for every image-like tensor in `combined`.

  A value is summarized only when it is a `tf.Tensor` of rank 4 whose last
  dimension is 1 or 3. Every other value is skipped without error. Each
  summary is tagged `prefix + '/' + key`.

  Args:
    combined: Dictionary mapping names to values (tensors or anything else).
    prefix: Prefix for the summary tags.
    step: Step at which the summaries are recorded.
  """
  for key, value in combined.items():
    if not isinstance(value, tf.Tensor):
      continue
    shape = value.shape
    if len(shape) != 4 or shape[-1] not in (1, 3):
      continue
    tf.summary.image(prefix + '/' + key, value, step=step)


def eval_loop(strategy: tf.distribute.Strategy,
              eval_base_folder: str,
              model: tf.keras.Model,
              metrics: Dict[str, tf.keras.metrics.Metric],
              datasets: Mapping[str, tf.data.Dataset],
              summary_writer: tf.summary.SummaryWriter,
              checkpoint_step: int):
  """Eval function that is strategy agnostic.

  Evaluates `model` on every dataset in `datasets`. For each dataset it writes
  image summaries for the first few batches and one scalar summary per
  metric. All summaries use `checkpoint_step` as their step.

  Args:
    strategy: A Tensorflow distributed strategy.
    eval_base_folder: A path to where the summaries event files and
      checkpoints will be saved. In this function it is only logged.
    model: The Keras model to evaluate.
    metrics: A dictionary of Keras metrics, keyed by the name used in the
      scalar summary tag. The metrics are reset before and after each dataset.
    datasets: A dict of tf.data.Dataset to evaluate on, keyed by dataset name.
    summary_writer: Eval summary writer. It is installed as the default writer
      and is left installed after this function returns.
    checkpoint_step: The number of iterations completed; used as the summary
      step and forwarded to the metrics.
  """
  logging.info('Saving eval summaries to: %s...', eval_base_folder)
  summary_writer.set_as_default()

  # Number of leading batches per dataset that get image summaries.
  max_batches_to_summarize = 10

  for dataset_name, dataset in datasets.items():
    for metric in metrics.values():
      metric.reset_states()

    logging.info('Loading %s testing data ...', dataset_name)
    distributed_dataset = strategy.experimental_distribute_dataset(dataset)

    logging.info('Evaluating %s ...', dataset_name)
    for batch_idx, batch in enumerate(distributed_dataset):
      predictions = _distributed_eval_step(strategy, batch, model, metrics,
                                           checkpoint_step)
      if strategy.num_replicas_in_sync > 1:
        # Merge the per-replica images into one global batch. With several
        # replicas, only the 'image' prediction is kept.
        predictions = {
            'image': tf.concat(predictions['image'].values, axis=0)
        }
      # Clip interpolator output to [0,1]. Clipping is done only
      # on the eval loop to get better metrics, but not on the training loop
      # so gradients are not killed.
      predictions['image'] = tf.clip_by_value(predictions['image'], 0., 1.)
      if batch_idx % 10 == 0:
        logging.info('Evaluating batch %s', batch_idx)

      if batch_idx < max_batches_to_summarize:
        # Tags are 1-based: eval_1 ... eval_<max_batches_to_summarize>.
        prefix = f'{dataset_name}/eval_{batch_idx + 1}'
        # Find all tensors that look like images, and summarize:
        combined = {**batch, **predictions}
        _summarize_image_tensors(combined, prefix, step=checkpoint_step)
        if batch_idx == max_batches_to_summarize - 1:
          # Last image summary for this dataset has been written.
          tf.summary.flush()

    for name, metric in metrics.items():
      result = metric.result()
      tag = f'{dataset_name}/{name}'
      tf.summary.scalar(tag, result, step=checkpoint_step)
      logging.info('Step %2d, %s %s', checkpoint_step, tag, result.numpy())
      metric.reset_states()
    tf.summary.flush()