|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r"""Converts TF2 training checkpoint to a saved model. |
|
|
|
The model must match the checkpoint, so the gin config must be given. |
|
|
|
Usage example: |
|
python3 -m frame_interpolation.training.build_saved_model_cli \ |
|
--gin_config <filepath of the gin config the training session was based on> \
|
--base_folder <base folder of training sessions> \ |
|
--label <the name of the run> |
|
|
|
This writes the saved model to: <base_folder>/<label>/saved_model
|
""" |
|
import os |
|
from typing import Sequence |
|
|
|
from . import model_lib |
|
from absl import app |
|
from absl import flags |
|
from absl import logging |
|
import gin.tf |
|
import tensorflow as tf |
|
# Only surface TensorFlow Python-logger messages at ERROR level or above.
tf.get_logger().setLevel('ERROR')

_GIN_CONFIG = flags.DEFINE_string(
    name='gin_config',
    default='config.gin',
    help='Gin config file, saved in the training session <root folder>.')
_LABEL = flags.DEFINE_string(
    name='label',
    default=None,
    required=True,
    help='Descriptive label for the training session.')
# Required: main() joins this path with the label and output subfolders, which
# fails with a TypeError when the value is None.
_BASE_FOLDER = flags.DEFINE_string(
    name='base_folder',
    default=None,
    required=True,
    help='Path to all training sessions.')
# NOTE(review): not read anywhere in this file; presumably kept so the same
# command line as the training binary is accepted -- confirm before removing.
_MODE = flags.DEFINE_enum(
    name='mode',
    default=None,
    enum_values=['cpu', 'gpu', 'tpu'],
    help='Distributed strategy approach.')
|
|
|
|
|
def _build_saved_model(checkpoint_path: str, config_files: Sequence[str],
                       output_model_path: str):
  """Builds a saved model from the latest checkpoint in a directory.

  Args:
    checkpoint_path: Directory holding the training checkpoints; the most
      recent one, as reported by `tf.train.latest_checkpoint`, is restored.
    config_files: Gin config files parsed before `model_lib.create_model()` is
      called, so the model built here matches the one that was trained.
    output_model_path: Destination passed to `model.save`.

  Raises:
    FileNotFoundError: If `checkpoint_path` contains no checkpoint.
    tf.errors.NotFoundError: If the checkpoint cannot be read.
    AssertionError: If the model's objects are not all matched by the
      checkpoint.
  """
  # skip_unknown=True tolerates bindings in the training config for
  # configurables that are not registered in this binary.
  gin.parse_config_files_and_bindings(
      config_files=config_files, bindings=None, skip_unknown=True)
  model = model_lib.create_model()
  checkpoint = tf.train.Checkpoint(model=model)

  checkpoint_file = tf.train.latest_checkpoint(checkpoint_path)
  if checkpoint_file is None:
    # latest_checkpoint() returns None when the directory holds no
    # checkpoint; fail with a clear message instead of restoring from None.
    logging.error('No checkpoint found in %s.', checkpoint_path)
    raise FileNotFoundError(f'No checkpoint found in {checkpoint_path}.')

  logging.info('Restoring from %s', checkpoint_file)
  try:
    status = checkpoint.restore(checkpoint_file)
    # Every object already created in the model must receive a value from the
    # checkpoint; otherwise the exported model would not match training.
    status.assert_existing_objects_matched()
  except (tf.errors.NotFoundError, AssertionError) as err:
    # Propagate the failure so the caller does not report a successful export
    # when nothing was written.
    logging.error('Failed to restore checkpoint from %s. Error:\n%s',
                  checkpoint_file, err)
    raise
  # Checkpoint values with no counterpart in the model are expected here;
  # silence the warnings about the partial restore.
  status.expect_partial()
  model.save(output_model_path)
|
|
|
|
|
def main(argv):
  """Exports the checkpoint under <base_folder>/<label>/train as a saved model.

  Args:
    argv: Command-line arguments as passed in by `app.run`; anything beyond
      argv[0] is rejected.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  session_dir = os.path.join(_BASE_FOLDER.value, _LABEL.value)
  checkpoint_path = os.path.join(session_dir, 'train')

  # A --gin_config that exists is used as given; otherwise it is looked up
  # inside the training session folder.
  gin_config = _GIN_CONFIG.value
  config_file = (gin_config if tf.io.gfile.exists(gin_config) else
                 os.path.join(session_dir, gin_config))

  output_model_path = os.path.join(session_dir, 'saved_model')
  _build_saved_model(
      checkpoint_path=checkpoint_path,
      config_files=[config_file],
      output_model_path=output_model_path)
  logging.info('The saved model stored into %s/.', output_model_path)
|
|
|
if __name__ == '__main__':
  # Script entry point: absl parses the flags, then calls main(argv).
  app.run(main)
|
|