lanpip / models /freeze_graph.py
Parechan's picture
Upload 58 files
f915f2e verified
# EDITED on 10. 9. 2017 for meta graph freezing
#
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts checkpoint variables into Const ops in a standalone GraphDef file.
This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
variable values stored in a checkpoint file, and output a GraphDef with all of
the variable ops converted into const ops containing the values of the
variables.
It's useful to do this when we need to load a single file in C++, especially in
environments like mobile or embedded where we may not have access to the
RestoreTensor ops and file loading calls that they rely on.
An example of command-line usage is:
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \
--input_checkpoint=model.ckpt-8361242 \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax
You can also look at freeze_graph_test.py for an example of how to use it.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from google.protobuf import text_format
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver as saver_lib
FLAGS = None
def freeze_graph_with_def_protos(input_graph_def,
input_saver_def,
input_checkpoint,
output_node_names,
restore_op_name,
filename_tensor_name,
output_graph,
clear_devices,
initializer_nodes,
variable_names_blacklist="",
input_meta_graph_def=None,
input_saved_model_dir=None,
saved_model_tags=None):
"""Converts all variables in a graph and checkpoint into constants."""
del restore_op_name, filename_tensor_name # Unused by updated loading code.
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
if (not input_saved_model_dir and
not saver_lib.checkpoint_exists(input_checkpoint)):
print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
return -1
if not output_node_names:
print("You need to supply the name of a node to --output_node_names.")
return -1
# Remove all the explicit device specifications for this node. This helps to
# make the graph more portable.
if clear_devices:
if input_meta_graph_def:
for node in input_meta_graph_def.graph_def.node:
node.device = ""
elif input_graph_def:
for node in input_graph_def.node:
node.device = ""
if input_graph_def:
_ = importer.import_graph_def(input_graph_def, name="")
with session.Session() as sess:
if input_saver_def:
saver = saver_lib.Saver(saver_def=input_saver_def)
saver.restore(sess, input_checkpoint)
elif input_meta_graph_def:
restorer = saver_lib.import_meta_graph(
input_meta_graph_def, clear_devices=True)
restorer.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes.split(","))
elif input_saved_model_dir:
if saved_model_tags is None:
saved_model_tags = []
loader.load(sess, saved_model_tags, input_saved_model_dir)
else:
var_list = {}
reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
try:
tensor = sess.graph.get_tensor_by_name(key + ":0")
except KeyError:
# This tensor doesn't exist in the graph (for example it's
# 'global_step' or a similar housekeeping element) so skip it.
continue
var_list[key] = tensor
saver = saver_lib.Saver(var_list=var_list)
saver.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes.split(","))
variable_names_blacklist = (variable_names_blacklist.split(",")
if variable_names_blacklist else None)
if input_meta_graph_def:
output_graph_def = graph_util.convert_variables_to_constants(
sess,
input_meta_graph_def.graph_def,
output_node_names.split(","),
variable_names_blacklist=variable_names_blacklist)
else:
output_graph_def = graph_util.convert_variables_to_constants(
sess,
input_graph_def,
output_node_names.split(","),
variable_names_blacklist=variable_names_blacklist)
# Write GraphDef to file if output path has been given.
if output_graph:
with gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
return output_graph_def
def _parse_input_graph_proto(input_graph, input_binary):
"""Parser input tensorflow graph into GraphDef proto."""
if not gfile.Exists(input_graph):
print("Input graph file '" + input_graph + "' does not exist!")
return -1
input_graph_def = graph_pb2.GraphDef()
mode = "rb" if input_binary else "r"
with gfile.FastGFile(input_graph, mode) as f:
if input_binary:
input_graph_def.ParseFromString(f.read())
else:
text_format.Merge(f.read(), input_graph_def)
return input_graph_def
def _parse_input_meta_graph_proto(input_graph, input_binary):
"""Parser input tensorflow graph into MetaGraphDef proto."""
if not gfile.Exists(input_graph):
print("Input meta graph file '" + input_graph + "' does not exist!")
return -1
input_meta_graph_def = MetaGraphDef()
mode = "rb" if input_binary else "r"
with gfile.FastGFile(input_graph, mode) as f:
if input_binary:
input_meta_graph_def.ParseFromString(f.read())
else:
text_format.Merge(f.read(), input_meta_graph_def)
print("Loaded meta graph file '" + input_graph)
return input_meta_graph_def
def _parse_input_saver_proto(input_saver, input_binary):
"""Parser input tensorflow Saver into SaverDef proto."""
if not gfile.Exists(input_saver):
print("Input saver file '" + input_saver + "' does not exist!")
return -1
mode = "rb" if input_binary else "r"
with gfile.FastGFile(input_saver, mode) as f:
saver_def = saver_pb2.SaverDef()
if input_binary:
saver_def.ParseFromString(f.read())
else:
text_format.Merge(f.read(), saver_def)
return saver_def
def get_meta_graph_def(saved_model_dir, tag_set):
"""Gets MetaGraphDef from SavedModel.
Returns the MetaGraphDef for the given tag-set and SavedModel directory.
Args:
saved_model_dir: Directory containing the SavedModel to inspect or execute.
tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
separated by ','. For tag-set contains multiple tags, all tags must be
passed in.
Raises:
RuntimeError: An error when the given tag-set does not exist in the
SavedModel.
Returns:
A MetaGraphDef corresponding to the tag-set.
"""
saved_model = reader.read_saved_model(saved_model_dir)
set_of_tags = set(tag_set.split(','))
for meta_graph_def in saved_model.meta_graphs:
if set(meta_graph_def.meta_info_def.tags) == set_of_tags:
return meta_graph_def
raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set +
' could not be found in SavedModel')
def freeze_graph(input_graph,
input_saver,
input_binary,
input_checkpoint,
output_node_names,
restore_op_name,
filename_tensor_name,
output_graph,
clear_devices,
initializer_nodes,
variable_names_blacklist="",
input_meta_graph=None,
input_saved_model_dir=None,
saved_model_tags=tag_constants.SERVING):
"""Converts all variables in a graph and checkpoint into constants."""
input_graph_def = None
if input_saved_model_dir:
input_graph_def = get_meta_graph_def(
input_saved_model_dir, saved_model_tags).graph_def
elif input_graph:
input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
input_meta_graph_def = None
if input_meta_graph:
input_meta_graph_def = _parse_input_meta_graph_proto(
input_meta_graph, input_binary)
input_saver_def = None
if input_saver:
input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
freeze_graph_with_def_protos(
input_graph_def, input_saver_def, input_checkpoint, output_node_names,
restore_op_name, filename_tensor_name, output_graph, clear_devices,
initializer_nodes, variable_names_blacklist, input_meta_graph_def,
input_saved_model_dir, saved_model_tags.split(","))