1 / training /train_lib.py
CHEN11102's picture
Upload 47 files
708d62c verified
# Copyright 2022 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Training library for frame interpolation using distributed strategy."""
import functools
from typing import Any, Callable, Dict, Text, Tuple
from absl import logging
import tensorflow as tf
def _concat_tensors(tensors: tf.Tensor) -> tf.Tensor:
"""Concat tensors of the different replicas."""
return tf.concat(tf.nest.flatten(tensors, expand_composites=True), axis=0)
@tf.function
def _distributed_train_step(strategy: tf.distribute.Strategy,
batch: Dict[Text, tf.Tensor], model: tf.keras.Model,
loss_functions: Dict[Text,
Tuple[Callable[..., tf.Tensor],
Callable[...,
tf.Tensor]]],
optimizer: tf.keras.optimizers.Optimizer,
iterations: int) -> Dict[Text, Any]:
"""Distributed training step.
Args:
strategy: A Tensorflow distribution strategy.
batch: A batch of training examples.
model: The Keras model to train.
loss_functions: The list of Keras losses used to train the model.
optimizer: The Keras optimizer used to train the model.
iterations: Iteration number used to sample weights to each loss.
Returns:
A dictionary of train step outputs.
"""
def _train_step(batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
"""Train for one step."""
with tf.GradientTape() as tape:
predictions = model(batch, training=True)
losses = []
for (loss_value, loss_weight) in loss_functions.values():
losses.append(loss_value(batch, predictions) * loss_weight(iterations))
loss = tf.add_n(losses)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
# post process for visualization
all_data = {'loss': loss}
all_data.update(batch)
all_data.update(predictions)
return all_data
step_outputs = strategy.run(_train_step, args=(batch,))
loss = strategy.reduce(
tf.distribute.ReduceOp.MEAN, step_outputs['loss'], axis=None)
x0 = _concat_tensors(step_outputs['x0'])
x1 = _concat_tensors(step_outputs['x1'])
y = _concat_tensors(step_outputs['y'])
pred_y = _concat_tensors(step_outputs['image'])
scalar_summaries = {'training_loss': loss}
image_summaries = {
'x0': x0,
'x1': x1,
'y': y,
'pred_y': pred_y
}
extra_images = {
'importance0', 'importance1', 'x0_warped', 'x1_warped', 'fg_image',
'bg_image', 'fg_alpha', 'x1_unfiltered_warped'
}
for image in extra_images:
if image in step_outputs:
image_summaries[image] = _concat_tensors(step_outputs[image])
return {
'loss': loss,
'scalar_summaries': scalar_summaries,
'image_summaries': {
f'training/{name}': value for name, value in image_summaries.items()
}
}
def _summary_writer(summaries_dict: Dict[Text, Any]) -> None:
"""Adds scalar and image summaries."""
# Adds scalar summaries.
for key, scalars in summaries_dict['scalar_summaries'].items():
tf.summary.scalar(key, scalars)
# Adds image summaries.
for key, images in summaries_dict['image_summaries'].items():
tf.summary.image(key, tf.clip_by_value(images, 0.0, 1.0))
tf.summary.histogram(key + '_h', images)
def train_loop(
strategy: tf.distribute.Strategy,
train_set: tf.data.Dataset,
create_model_fn: Callable[..., tf.keras.Model],
create_losses_fn: Callable[..., Dict[str, Tuple[Callable[..., tf.Tensor],
Callable[..., tf.Tensor]]]],
create_optimizer_fn: Callable[..., tf.keras.optimizers.Optimizer],
distributed_train_step_fn: Callable[[
tf.distribute.Strategy, Dict[str, tf.Tensor], tf.keras.Model, Dict[
str,
Tuple[Callable[..., tf.Tensor],
Callable[..., tf.Tensor]]], tf.keras.optimizers.Optimizer, int
], Dict[str, Any]],
eval_loop_fn: Callable[..., None],
create_metrics_fn: Callable[..., Dict[str, tf.keras.metrics.Metric]],
eval_folder: Dict[str, Any],
eval_datasets: Dict[str, tf.data.Dataset],
summary_writer_fn: Callable[[Dict[str, Any]], None],
train_folder: str,
saved_model_folder: str,
num_iterations: int,
save_summaries_frequency: int = 500,
save_checkpoint_frequency: int = 500,
checkpoint_max_to_keep: int = 10,
checkpoint_save_every_n_hours: float = 2.,
timing_frequency: int = 100,
logging_frequency: int = 10):
"""A Tensorflow 2 eager mode training loop.
Args:
strategy: A Tensorflow distributed strategy.
train_set: A tf.data.Dataset to loop through for training.
create_model_fn: A callable that returns a tf.keras.Model.
create_losses_fn: A callable that returns a tf.keras.losses.Loss.
create_optimizer_fn: A callable that returns a
tf.keras.optimizers.Optimizer.
distributed_train_step_fn: A callable that takes a distribution strategy, a
Dict[Text, tf.Tensor] holding the batch of training data, a
tf.keras.Model, a tf.keras.losses.Loss, a tf.keras.optimizers.Optimizer,
iteartion number to sample a weight value to loos functions,
and returns a dictionary to be passed to the summary_writer_fn.
eval_loop_fn: Eval loop function.
create_metrics_fn: create_metric_fn.
eval_folder: A path to where the summary event files and checkpoints will be
saved.
eval_datasets: A dictionary of evalution tf.data.Dataset to loop through for
evaluation.
summary_writer_fn: A callable that takes the output of
distributed_train_step_fn and writes summaries to be visualized in
TensorBoard.
train_folder: A path to where the summaries event files and checkpoints
will be saved.
saved_model_folder: A path to where the saved models are stored.
num_iterations: An integer, the number of iterations to train for.
save_summaries_frequency: The iteration frequency with which summaries are
saved.
save_checkpoint_frequency: The iteration frequency with which model
checkpoints are saved.
checkpoint_max_to_keep: The maximum number of checkpoints to keep.
checkpoint_save_every_n_hours: The frequency in hours to keep checkpoints.
timing_frequency: The iteration frequency with which to log timing.
logging_frequency: How often to output with logging.info().
"""
logging.info('Creating training tensorboard summaries ...')
summary_writer = tf.summary.create_file_writer(train_folder)
if eval_datasets is not None:
logging.info('Creating eval tensorboard summaries ...')
eval_summary_writer = tf.summary.create_file_writer(eval_folder)
train_set = strategy.experimental_distribute_dataset(train_set)
with strategy.scope():
logging.info('Building model ...')
model = create_model_fn()
loss_functions = create_losses_fn()
optimizer = create_optimizer_fn()
if eval_datasets is not None:
metrics = create_metrics_fn()
logging.info('Creating checkpoint ...')
checkpoint = tf.train.Checkpoint(
model=model,
optimizer=optimizer,
step=optimizer.iterations,
epoch=tf.Variable(0, dtype=tf.int64, trainable=False),
training_finished=tf.Variable(False, dtype=tf.bool, trainable=False))
logging.info('Restoring old model (if exists) ...')
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=train_folder,
max_to_keep=checkpoint_max_to_keep,
keep_checkpoint_every_n_hours=checkpoint_save_every_n_hours)
with strategy.scope():
if checkpoint_manager.latest_checkpoint:
checkpoint.restore(checkpoint_manager.latest_checkpoint)
logging.info('Creating Timer ...')
timer = tf.estimator.SecondOrStepTimer(every_steps=timing_frequency)
timer.update_last_triggered_step(optimizer.iterations.numpy())
logging.info('Training on devices: %s.', [
el.name.split('/physical_device:')[-1]
for el in tf.config.get_visible_devices()
])
# Re-assign training_finished=False, in case we restored a checkpoint.
checkpoint.training_finished.assign(False)
while optimizer.iterations.numpy() < num_iterations:
for i_batch, batch in enumerate(train_set):
summary_writer.set_as_default()
iterations = optimizer.iterations.numpy()
if iterations % logging_frequency == 0:
# Log epoch, total iterations and batch index.
logging.info('epoch %d; iterations %d; i_batch %d',
checkpoint.epoch.numpy(), iterations,
i_batch)
# Break if the number of iterations exceeds the max.
if iterations >= num_iterations:
break
# Compute distributed step outputs.
distributed_step_outputs = distributed_train_step_fn(
strategy, batch, model, loss_functions, optimizer, iterations)
# Save checkpoint, and optionally run the eval loops.
if iterations % save_checkpoint_frequency == 0:
checkpoint_manager.save(checkpoint_number=iterations)
if eval_datasets is not None:
eval_loop_fn(
strategy=strategy,
eval_base_folder=eval_folder,
model=model,
metrics=metrics,
datasets=eval_datasets,
summary_writer=eval_summary_writer,
checkpoint_step=iterations)
# Write summaries.
if iterations % save_summaries_frequency == 0:
tf.summary.experimental.set_step(step=iterations)
summary_writer_fn(distributed_step_outputs)
tf.summary.scalar('learning_rate',
optimizer.learning_rate(iterations).numpy())
# Log steps/sec.
if timer.should_trigger_for_step(iterations):
elapsed_time, elapsed_steps = timer.update_last_triggered_step(
iterations)
if elapsed_time is not None:
steps_per_second = elapsed_steps / elapsed_time
tf.summary.scalar(
'steps/sec', steps_per_second, step=optimizer.iterations)
# Increment epoch.
checkpoint.epoch.assign_add(1)
# Assign training_finished variable to True after training is finished and
# save the last checkpoint.
checkpoint.training_finished.assign(True)
checkpoint_manager.save(checkpoint_number=optimizer.iterations.numpy())
# Generate a saved model.
model.save(saved_model_folder)
def train(strategy: tf.distribute.Strategy, train_folder: str,
saved_model_folder: str, n_iterations: int,
create_model_fn: Callable[..., tf.keras.Model],
create_losses_fn: Callable[..., Dict[str,
Tuple[Callable[..., tf.Tensor],
Callable[...,
tf.Tensor]]]],
create_metrics_fn: Callable[..., Dict[str, tf.keras.metrics.Metric]],
dataset: tf.data.Dataset,
learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule,
eval_loop_fn: Callable[..., None],
eval_folder: str,
eval_datasets: Dict[str, tf.data.Dataset]):
"""Training function that is strategy agnostic.
Args:
strategy: A Tensorflow distributed strategy.
train_folder: A path to where the summaries event files and checkpoints
will be saved.
saved_model_folder: A path to where the saved models are stored.
n_iterations: An integer, the number of iterations to train for.
create_model_fn: A callable that returns tf.keras.Model.
create_losses_fn: A callable that returns the losses.
create_metrics_fn: A function that returns the metrics dictionary.
dataset: The tensorflow dataset object.
learning_rate: Keras learning rate schedule object.
eval_loop_fn: eval loop function.
eval_folder: A path to where eval summaries event files and checkpoints
will be saved.
eval_datasets: The tensorflow evaluation dataset objects.
"""
train_loop(
strategy=strategy,
train_set=dataset,
create_model_fn=create_model_fn,
create_losses_fn=create_losses_fn,
create_optimizer_fn=functools.partial(
tf.keras.optimizers.Adam, learning_rate=learning_rate),
distributed_train_step_fn=_distributed_train_step,
eval_loop_fn=eval_loop_fn,
create_metrics_fn=create_metrics_fn,
eval_folder=eval_folder,
eval_datasets=eval_datasets,
summary_writer_fn=_summary_writer,
train_folder=train_folder,
saved_model_folder=saved_model_folder,
num_iterations=n_iterations,
save_summaries_frequency=3000,
save_checkpoint_frequency=3000)
def get_strategy(mode) -> tf.distribute.Strategy:
"""Creates a distributed strategy."""
strategy = None
if mode == 'cpu':
strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
elif mode == 'gpu':
strategy = tf.distribute.MirroredStrategy()
else:
raise ValueError('Unsupported distributed mode.')
return strategy