File size: 13,782 Bytes
708d62c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 |
# Copyright 2022 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Training library for frame interpolation using distributed strategy."""
import functools
from typing import Any, Callable, Dict, Text, Tuple
from absl import logging
import tensorflow as tf
def _concat_tensors(tensors: tf.Tensor) -> tf.Tensor:
  """Merges the values held by the different replicas into a single tensor.

  Args:
    tensors: A tensor or (possibly nested) structure of tensors, such as the
      per-replica value produced by `strategy.run`.

  Returns:
    One tensor obtained by concatenating every flattened component along
    axis 0.
  """
  # `expand_composites=True` also unpacks composite values into their
  # component tensors, so that each component is concatenated.
  components = tf.nest.flatten(tensors, expand_composites=True)
  return tf.concat(components, axis=0)
@tf.function
def _distributed_train_step(strategy: tf.distribute.Strategy,
                            batch: Dict[Text, tf.Tensor], model: tf.keras.Model,
                            loss_functions: Dict[Text,
                                                 Tuple[Callable[..., tf.Tensor],
                                                       Callable[...,
                                                                tf.Tensor]]],
                            optimizer: tf.keras.optimizers.Optimizer,
                            iterations: int) -> Dict[Text, Any]:
  """Runs one optimization step on every replica and gathers the outputs.

  Args:
    strategy: A Tensorflow distribution strategy.
    batch: A batch of training examples, passed as-is to the model and to every
      loss callable.
    model: The Keras model to train. It is called as
      `model(batch, training=True)` and its output is merged into the step
      outputs with `dict.update`.
    loss_functions: A dictionary whose values are `(loss, weight)` pairs of
      callables. `loss(batch, predictions)` gives the loss value and
      `weight(iterations)` gives the factor it is multiplied with.
    optimizer: The Keras optimizer used to train the model.
    iterations: Iteration number handed to every loss-weight callable.

  Returns:
    A dictionary with:
      'loss': the weighted loss sum, mean-reduced over the replicas.
      'scalar_summaries': `{'training_loss': loss}`.
      'image_summaries': tensors keyed by 'training/<name>', concatenated over
        the replicas. 'x0', 'x1', 'y' and 'pred_y' are always present; the
        optional visualization outputs are added only when the step produced
        them.
  """

  def _train_step(batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
    """Trains for one step on a single replica."""
    with tf.GradientTape() as tape:
      predictions = model(batch, training=True)
      weighted_losses = [
          loss_fn(batch, predictions) * weight_fn(iterations)
          for loss_fn, weight_fn in loss_functions.values()
      ]
      loss = tf.add_n(weighted_losses)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Bundle the loss with the inputs and the predictions for visualization.
    # Later sources overwrite earlier entries that share a key.
    outputs = {'loss': loss}
    for extra in (batch, predictions):
      outputs.update(extra)
    return outputs

  per_replica = strategy.run(_train_step, args=(batch,))

  mean_loss = strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_replica['loss'], axis=None)

  # Mandatory images as (summary name, step-output key) pairs. These entries
  # must exist in the step outputs; presumably 'x0', 'x1' and 'y' come from the
  # batch and 'image' from the model predictions -- confirm against the model.
  required_images = (('x0', 'x0'), ('x1', 'x1'), ('y', 'y'),
                     ('pred_y', 'image'))
  image_summaries = {
      summary_name: _concat_tensors(per_replica[output_key])
      for summary_name, output_key in required_images
  }

  # Optional images are summarized only when present in the step outputs.
  optional_images = {
      'importance0', 'importance1', 'x0_warped', 'x1_warped', 'fg_image',
      'bg_image', 'fg_alpha', 'x1_unfiltered_warped'
  }
  image_summaries.update(
      (name, _concat_tensors(per_replica[name]))
      for name in optional_images
      if name in per_replica)

  return {
      'loss': mean_loss,
      'scalar_summaries': {'training_loss': mean_loss},
      'image_summaries': {
          f'training/{name}': images
          for name, images in image_summaries.items()
      },
  }
def _summary_writer(summaries_dict: Dict[Text, Any]) -> None:
  """Writes the scalar and image summaries of one training step.

  No explicit step is passed to the summary ops.

  Args:
    summaries_dict: A dictionary with a 'scalar_summaries' entry mapping
      summary names to scalar values, and an 'image_summaries' entry mapping
      summary names to image tensors.
  """
  # Scalars are written unchanged under their own name.
  for name, value in summaries_dict['scalar_summaries'].items():
    tf.summary.scalar(name, value)

  # Each image tensor yields two summaries: the image itself, clipped to
  # [0, 1], and a histogram of the unclipped values under '<name>_h'.
  for name, images in summaries_dict['image_summaries'].items():
    displayable = tf.clip_by_value(images, 0.0, 1.0)
    tf.summary.image(name, displayable)
    tf.summary.histogram(name + '_h', images)
def train_loop(
    strategy: tf.distribute.Strategy,
    train_set: tf.data.Dataset,
    create_model_fn: Callable[..., tf.keras.Model],
    create_losses_fn: Callable[..., Dict[str, Tuple[Callable[..., tf.Tensor],
                                                    Callable[..., tf.Tensor]]]],
    create_optimizer_fn: Callable[..., tf.keras.optimizers.Optimizer],
    distributed_train_step_fn: Callable[[
        tf.distribute.Strategy, Dict[str, tf.Tensor], tf.keras.Model, Dict[
            str,
            Tuple[Callable[..., tf.Tensor],
                  Callable[..., tf.Tensor]]], tf.keras.optimizers.Optimizer, int
    ], Dict[str, Any]],
    eval_loop_fn: Callable[..., None],
    create_metrics_fn: Callable[..., Dict[str, tf.keras.metrics.Metric]],
    eval_folder: str,
    eval_datasets: Dict[str, tf.data.Dataset],
    summary_writer_fn: Callable[[Dict[str, Any]], None],
    train_folder: str,
    saved_model_folder: str,
    num_iterations: int,
    save_summaries_frequency: int = 500,
    save_checkpoint_frequency: int = 500,
    checkpoint_max_to_keep: int = 10,
    checkpoint_save_every_n_hours: float = 2.,
    timing_frequency: int = 100,
    logging_frequency: int = 10):
  """A Tensorflow 2 eager mode training loop.

  Builds the model, losses and optimizer under the strategy scope, restores the
  latest checkpoint found in `train_folder` (if any), and then iterates over
  `train_set` until `optimizer.iterations` reaches `num_iterations`. A final
  checkpoint and a saved model are written once training is finished.

  Args:
    strategy: A Tensorflow distributed strategy.
    train_set: A tf.data.Dataset to loop through for training.
    create_model_fn: A callable that returns a tf.keras.Model.
    create_losses_fn: A callable that returns a dictionary mapping a loss name
      to a tuple of callables (loss function, loss weight function).
    create_optimizer_fn: A callable that returns a
      tf.keras.optimizers.Optimizer.
    distributed_train_step_fn: A callable that takes a distribution strategy, a
      Dict[Text, tf.Tensor] holding the batch of training data, a
      tf.keras.Model, the dictionary of losses, a tf.keras.optimizers.Optimizer
      and the iteration number used to sample a weight value for the loss
      functions, and returns a dictionary to be passed to the
      summary_writer_fn.
    eval_loop_fn: Eval loop function. Called with the keyword arguments
      strategy, eval_base_folder, model, metrics, datasets, summary_writer and
      checkpoint_step each time a checkpoint is saved inside the loop.
    create_metrics_fn: A callable that returns the dictionary of metrics passed
      to eval_loop_fn. Only called when eval_datasets is not None.
    eval_folder: A path to where the eval summary event files will be saved.
      Only used when eval_datasets is not None.
    eval_datasets: A dictionary of evaluation tf.data.Dataset to loop through
      for evaluation, or None to skip evaluation.
    summary_writer_fn: A callable that takes the output of
      distributed_train_step_fn and writes summaries to be visualized in
      TensorBoard.
    train_folder: A path to where the summaries event files and checkpoints
      will be saved.
    saved_model_folder: A path to where the saved models are stored.
    num_iterations: An integer, the number of iterations to train for.
    save_summaries_frequency: The iteration frequency with which summaries are
      saved.
    save_checkpoint_frequency: The iteration frequency with which model
      checkpoints are saved (and evaluation is run, if enabled).
    checkpoint_max_to_keep: The maximum number of checkpoints to keep.
    checkpoint_save_every_n_hours: The frequency in hours to keep checkpoints;
      forwarded as `keep_checkpoint_every_n_hours` to the checkpoint manager.
    timing_frequency: The iteration frequency with which to log timing.
    logging_frequency: How often to output with logging.info().

  Raises:
    ValueError: If `train_set` yields no batches before `num_iterations` is
      reached, since training could then never make progress.
  """
  logging.info('Creating training tensorboard summaries ...')
  summary_writer = tf.summary.create_file_writer(train_folder)

  if eval_datasets is not None:
    logging.info('Creating eval tensorboard summaries ...')
    eval_summary_writer = tf.summary.create_file_writer(eval_folder)

  train_set = strategy.experimental_distribute_dataset(train_set)
  with strategy.scope():
    logging.info('Building model ...')
    model = create_model_fn()
    loss_functions = create_losses_fn()
    optimizer = create_optimizer_fn()
    if eval_datasets is not None:
      metrics = create_metrics_fn()

  logging.info('Creating checkpoint ...')
  checkpoint = tf.train.Checkpoint(
      model=model,
      optimizer=optimizer,
      step=optimizer.iterations,
      epoch=tf.Variable(0, dtype=tf.int64, trainable=False),
      training_finished=tf.Variable(False, dtype=tf.bool, trainable=False))

  logging.info('Restoring old model (if exists) ...')
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=train_folder,
      max_to_keep=checkpoint_max_to_keep,
      keep_checkpoint_every_n_hours=checkpoint_save_every_n_hours)

  with strategy.scope():
    if checkpoint_manager.latest_checkpoint:
      checkpoint.restore(checkpoint_manager.latest_checkpoint)

  logging.info('Creating Timer ...')
  timer = tf.estimator.SecondOrStepTimer(every_steps=timing_frequency)
  timer.update_last_triggered_step(optimizer.iterations.numpy())

  logging.info('Training on devices: %s.', [
      el.name.split('/physical_device:')[-1]
      for el in tf.config.get_visible_devices()
  ])

  # Re-assign training_finished=False, in case we restored a checkpoint.
  checkpoint.training_finished.assign(False)
  while optimizer.iterations.numpy() < num_iterations:
    saw_batch = False
    for i_batch, batch in enumerate(train_set):
      saw_batch = True
      summary_writer.set_as_default()
      # Iteration count *before* this step; it is used for logging, for the
      # loss weights and for the checkpoint / summary bookkeeping below.
      iterations = optimizer.iterations.numpy()

      if iterations % logging_frequency == 0:
        # Log epoch, total iterations and batch index.
        logging.info('epoch %d; iterations %d; i_batch %d',
                     checkpoint.epoch.numpy(), iterations,
                     i_batch)

      # Break if the number of iterations exceeds the max.
      if iterations >= num_iterations:
        break

      # Compute distributed step outputs.
      distributed_step_outputs = distributed_train_step_fn(
          strategy, batch, model, loss_functions, optimizer, iterations)

      # Save checkpoint, and optionally run the eval loops.
      if iterations % save_checkpoint_frequency == 0:
        checkpoint_manager.save(checkpoint_number=iterations)
        if eval_datasets is not None:
          eval_loop_fn(
              strategy=strategy,
              eval_base_folder=eval_folder,
              model=model,
              metrics=metrics,
              datasets=eval_datasets,
              summary_writer=eval_summary_writer,
              checkpoint_step=iterations)

      # Write summaries.
      if iterations % save_summaries_frequency == 0:
        tf.summary.experimental.set_step(step=iterations)
        summary_writer_fn(distributed_step_outputs)
        # `optimizer.learning_rate` is a callable schedule for some optimizer
        # implementations and an already-evaluated value for others (constant
        # learning rates, newer Keras optimizers). Only call it when it is
        # callable, so both cases are logged instead of raising a TypeError.
        learning_rate = optimizer.learning_rate
        if callable(learning_rate):
          learning_rate = learning_rate(iterations)
        tf.summary.scalar('learning_rate', learning_rate)

      # Log steps/sec.
      if timer.should_trigger_for_step(iterations):
        elapsed_time, elapsed_steps = timer.update_last_triggered_step(
            iterations)
        if elapsed_time is not None:
          steps_per_second = elapsed_steps / elapsed_time
          tf.summary.scalar(
              'steps/sec', steps_per_second, step=optimizer.iterations)

    # An epoch without a single batch cannot advance `optimizer.iterations`,
    # so the enclosing while loop would never terminate. Fail loudly instead.
    if not saw_batch:
      raise ValueError(
          'train_set yielded no batches; cannot train for '
          f'{num_iterations} iterations.')

    # Increment epoch.
    checkpoint.epoch.assign_add(1)

  # Assign training_finished variable to True after training is finished and
  # save the last checkpoint.
  checkpoint.training_finished.assign(True)
  checkpoint_manager.save(checkpoint_number=optimizer.iterations.numpy())

  # Generate a saved model.
  model.save(saved_model_folder)
def train(strategy: tf.distribute.Strategy, train_folder: str,
          saved_model_folder: str, n_iterations: int,
          create_model_fn: Callable[..., tf.keras.Model],
          create_losses_fn: Callable[..., Dict[str,
                                               Tuple[Callable[..., tf.Tensor],
                                                     Callable[...,
                                                              tf.Tensor]]]],
          create_metrics_fn: Callable[..., Dict[str, tf.keras.metrics.Metric]],
          dataset: tf.data.Dataset,
          learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule,
          eval_loop_fn: Callable[..., None],
          eval_folder: str,
          eval_datasets: Dict[str, tf.data.Dataset],
          save_summaries_frequency: int = 3000,
          save_checkpoint_frequency: int = 3000):
  """Training function that is strategy agnostic.

  Thin wrapper around `train_loop` that plugs in an Adam optimizer built from
  `learning_rate`, the module's `_distributed_train_step` as the train step and
  `_summary_writer` as the summary writer.

  Args:
    strategy: A Tensorflow distributed strategy.
    train_folder: A path to where the summaries event files and checkpoints
      will be saved.
    saved_model_folder: A path to where the saved models are stored.
    n_iterations: An integer, the number of iterations to train for.
    create_model_fn: A callable that returns tf.keras.Model.
    create_losses_fn: A callable that returns the losses.
    create_metrics_fn: A function that returns the metrics dictionary.
    dataset: The tensorflow dataset object.
    learning_rate: Keras learning rate schedule object.
    eval_loop_fn: eval loop function.
    eval_folder: A path to where eval summaries event files and checkpoints
      will be saved.
    eval_datasets: The tensorflow evaluation dataset objects.
    save_summaries_frequency: The iteration frequency with which summaries are
      saved. Defaults to 3000, the value that used to be hard-coded here.
    save_checkpoint_frequency: The iteration frequency with which model
      checkpoints are saved. Defaults to 3000, the value that used to be
      hard-coded here.
  """
  # The optimizer is created lazily so that `train_loop` can instantiate it
  # inside the strategy scope.
  create_optimizer_fn = functools.partial(
      tf.keras.optimizers.Adam, learning_rate=learning_rate)

  train_loop(
      strategy=strategy,
      train_set=dataset,
      create_model_fn=create_model_fn,
      create_losses_fn=create_losses_fn,
      create_optimizer_fn=create_optimizer_fn,
      distributed_train_step_fn=_distributed_train_step,
      eval_loop_fn=eval_loop_fn,
      create_metrics_fn=create_metrics_fn,
      eval_folder=eval_folder,
      eval_datasets=eval_datasets,
      summary_writer_fn=_summary_writer,
      train_folder=train_folder,
      saved_model_folder=saved_model_folder,
      num_iterations=n_iterations,
      save_summaries_frequency=save_summaries_frequency,
      save_checkpoint_frequency=save_checkpoint_frequency)
def get_strategy(mode) -> tf.distribute.Strategy:
  """Creates the distribution strategy matching the requested mode.

  Args:
    mode: 'cpu' for a one-device strategy placed on '/cpu:0', or 'gpu' for a
      mirrored strategy.

  Returns:
    The newly created tf.distribute.Strategy.

  Raises:
    ValueError: If `mode` is neither 'cpu' nor 'gpu'.
  """
  if mode == 'cpu':
    return tf.distribute.OneDeviceStrategy('/cpu:0')
  if mode == 'gpu':
    return tf.distribute.MirroredStrategy()
  raise ValueError('Unsupported distributed mode.')
|