|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r"""The training loop for frame interpolation. |
|
|
|
gin_config: The gin configuration file containing model, losses and datasets. |
|
|
|
To run on GPUs: |
|
python3 -m frame_interpolation.training.train \ |
|
--gin_config <path to network.gin> \ |
|
--base_folder <base folder for all training runs> \ |
|
--label <descriptive label for the run> |
|
|
|
To debug the training loop on CPU: |
|
python3 -m frame_interpolation.training.train \ |
|
--gin_config <path to config.gin> \ |
|
--base_folder /tmp \
|
--label test_run \ |
|
--mode cpu |
|
|
|
The training output directory will be created at <base_folder>/<label>. |
|
""" |
|
import os |
|
|
|
from . import augmentation_lib |
|
from . import data_lib |
|
from . import eval_lib |
|
from . import metrics_lib |
|
from . import model_lib |
|
from . import train_lib |
|
from absl import app |
|
from absl import flags |
|
from absl import logging |
|
import gin.tf |
|
from ..losses import losses |
|
|
|
|
|
# Reduce TensorFlow's native (C++) log verbosity. The level is set via an
# environment variable before the `import tensorflow` below so the runtime
# can pick it up (in TF's convention '2' presumably hides INFO and WARNING
# output — confirm against the TF version in use).
# NOTE(review): the package / gin.tf imports above may already import
# TensorFlow transitively; verify this setting still takes effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf

# Only let ERROR-level messages through TensorFlow's Python logger.
tf.get_logger().setLevel('ERROR')
|
|
|
|
|
# Gin file that is copied into the run folder and then parsed by main().
_GIN_CONFIG = flags.DEFINE_string(
    name='gin_config', default=None, help='Gin config file.')
# Run name; main() writes all outputs beneath <base_folder>/<label>.
_LABEL = flags.DEFINE_string(
    name='label', default='run0', help='Descriptive label for this run.')
# Root directory under which each labelled run gets its own subdirectory.
_BASE_FOLDER = flags.DEFINE_string(
    name='base_folder', default=None, help='Path to checkpoints/summaries.')
# Forwarded to train_lib.get_strategy() to pick the distribution strategy.
_MODE = flags.DEFINE_enum(
    name='mode',
    default='gpu',
    enum_values=['cpu', 'gpu'],
    help='Distributed strategy approach.')
|
|
|
|
|
@gin.configurable('training')
class TrainingOptions:
  """Training-related options, filled in from the gin config.

  The class is registered with gin under the name 'training', so it can be
  constructed without arguments (as ``main()`` does) and have every field
  bound from the parsed gin file.

  Attributes:
    learning_rate: Initial learning rate; passed as the first argument of
      ``ExponentialDecay`` in ``main()``.
    learning_rate_decay_steps: Passed as ``ExponentialDecay``'s
      ``decay_steps``.
    learning_rate_decay_rate: Passed as ``ExponentialDecay``'s
      ``decay_rate``; a fractional factor, hence ``float``.
    learning_rate_staircase: Passed as ``ExponentialDecay``'s ``staircase``
      switch, hence ``bool``.
    num_steps: Passed as ``n_iterations`` to ``train_lib.train``.
  """

  def __init__(self, learning_rate: float, learning_rate_decay_steps: int,
               learning_rate_decay_rate: float, learning_rate_staircase: bool,
               num_steps: int):
    self.learning_rate = learning_rate
    self.learning_rate_decay_steps = learning_rate_decay_steps
    self.learning_rate_decay_rate = learning_rate_decay_rate
    self.learning_rate_staircase = learning_rate_staircase
    self.num_steps = num_steps
|
|
|
|
|
def main(argv):
  """Sets up the run directory, parses the gin config and starts training.

  Creates ``<base_folder>/<label>``, copies the gin config into it as
  ``config.gin``, builds an exponential-decay learning-rate schedule from the
  gin-configured ``TrainingOptions`` and hands everything to
  ``train_lib.train``. Checkpoints/summaries go to the ``train``, ``eval`` and
  ``saved_model`` subfolders of the run directory.

  Args:
    argv: Positional command-line arguments remaining after absl flag
      parsing; only the program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  output_dir = os.path.join(_BASE_FOLDER.value, _LABEL.value)
  logging.info('Creating output_dir @ %s ...', output_dir)

  # Keep a copy of the exact config next to the run's outputs.
  tf.io.gfile.makedirs(output_dir)
  tf.io.gfile.copy(
      _GIN_CONFIG.value, os.path.join(output_dir, 'config.gin'), overwrite=True)

  # Registered before parsing so gin files can reference this schedule.
  gin.external_configurable(
      tf.keras.optimizers.schedules.PiecewiseConstantDecay,
      module='tf.keras.optimizers.schedules')

  gin.parse_config_files_and_bindings(
      config_files=[_GIN_CONFIG.value], bindings=None, skip_unknown=True)

  # No explicit arguments: every field comes from the gin 'training' bindings.
  training_options = TrainingOptions()

  learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
      training_options.learning_rate,
      training_options.learning_rate_decay_steps,
      training_options.learning_rate_decay_rate,
      training_options.learning_rate_staircase,
      name='learning_rate')

  augmentation_fns = augmentation_lib.data_augmentations()

  # Derive every artifact folder from output_dir so they cannot drift apart
  # from the directory created (and logged) above.
  saved_model_folder = os.path.join(output_dir, 'saved_model')
  train_folder = os.path.join(output_dir, 'train')
  eval_folder = os.path.join(output_dir, 'eval')

  train_lib.train(
      strategy=train_lib.get_strategy(_MODE.value),
      train_folder=train_folder,
      saved_model_folder=saved_model_folder,
      n_iterations=training_options.num_steps,
      create_model_fn=model_lib.create_model,
      create_losses_fn=losses.training_losses,
      create_metrics_fn=metrics_lib.create_metrics_fn,
      dataset=data_lib.create_training_dataset(
          augmentation_fns=augmentation_fns),
      learning_rate=learning_rate,
      eval_loop_fn=eval_lib.eval_loop,
      eval_folder=eval_folder,
      # An empty collection of eval datasets is normalized to None.
      eval_datasets=data_lib.create_eval_datasets() or None)
|
|
|
|
|
if __name__ == '__main__':
  # --gin_config and --base_folder default to None, yet main() uses both as
  # filesystem paths. Requiring them makes absl reject a missing value at
  # flag-parsing time with a clear message instead of a TypeError in main().
  flags.mark_flags_as_required(['gin_config', 'base_folder'])
  app.run(main)
|
|