# Adapted from https://gist.github.com/qqaatw/82b47c2b3da602fa1df604167bfcb9b0
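"""Inspect, compare, and rename variables in a TensorFlow v1 checkpoint.

Three modes are supported (see the usage strings below): renaming variables
with --replace_from/--replace_to/--add_prefix, searching variable names with
--find_str, and comparing the variable lists of two checkpoints.
"""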
import getopt
import sys
import re
import tensorflow.compat.v1 as tf

usage_str = ('python tensorflow_rename_variables.py '
             '--checkpoint_dir=path/to/dir/ --replace_from=substr '
             '--replace_to=substr --add_prefix=abc --dry_run')
find_usage_str = ('python tensorflow_rename_variables.py '
                  '--checkpoint_dir=path/to/dir/ --find_str=[\'!\']substr')
comp_usage_str = ('python tensorflow_rename_variables.py '
                  '--checkpoint_dir=path/to/dir/ '
                  '--checkpoint_dir2=path/to/dir/')
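
# Example invocation (the substrings below are illustrative; --replace_from is
# interpreted as a regular expression by re.sub). Run once with --dry_run to
# preview, then without it to write renamed-model.ckpt:
#   python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ \
#       --replace_from=old_scope --replace_to=new_scope --dry_run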


def print_usage_str():
    print('Please specify a checkpoint_dir. Usage:')
    print('%s\nor\n%s\nor\n%s' % (usage_str, find_usage_str, comp_usage_str))
    print('Note: checkpoint_dir should be a *DIR*, not a file')


def compare(checkpoint_dir, checkpoint_dir2):
    """Print variables present in the first checkpoint but missing from the second."""
    import difflib
    with tf.Session():
        list1 = [name for name, _ in tf.train.list_variables(checkpoint_dir)]
        list2 = [name for name, _ in tf.train.list_variables(checkpoint_dir2)]
        for k1 in list1:
            if k1 in list2:
                continue
            print('{} close matches: {}'.format(
                k1, difflib.get_close_matches(k1, list2)))


def find(checkpoint_dir, find_str):
    """Search variable names for find_str; prefix find_str with '!' to invert the match."""
    with tf.Session():
        negate = find_str.startswith('!')
        if negate:
            find_str = find_str[1:]
        for var_name, _ in tf.train.list_variables(checkpoint_dir):
            if negate and find_str not in var_name:
                print('%s missing from %s.' % (find_str, var_name))
            if not negate and find_str in var_name:
                print('Found %s in %s.' % (find_str, var_name))


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
    """Rename matching variables and save the result as renamed-model.ckpt."""
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    print('Checkpoint state:', checkpoint)
    with tf.Session() as sess:
        for var_name, _ in tf.train.list_variables(checkpoint_dir):
            # Load the variable
            var = tf.train.load_variable(checkpoint_dir, var_name)

            # Set the new name; replace_from is treated as a regular expression
            new_name = var_name
            if None not in [replace_from, replace_to]:
                new_name = re.sub(replace_from, replace_to, new_name)
            if add_prefix:
                new_name = add_prefix + new_name

            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
            else:
                if var_name != new_name:
                    print('Renaming %s to %s.' % (var_name, new_name))
                # Create the variable, potentially renaming it
                var = tf.Variable(var, name=new_name)

        if not dry_run:
            # Save the variables under their new names
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            # saver.save(sess, checkpoint.model_checkpoint_path)
            saver.save(sess, 'renamed-model.ckpt')


def main(argv):
    checkpoint_dir = None
    checkpoint_dir2 = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    find_str = None

    try:
        # '--help' and '--dry_run' are flags; the remaining long options take a value
        opts, args = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=',
                                               'replace_from=', 'replace_to=',
                                               'add_prefix=', 'dry_run',
                                               'find_str=',
                                               'checkpoint_dir2='])
    except getopt.GetoptError as e:
        print(e)
        print_usage_str()
        sys.exit(2)

    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--checkpoint_dir2':
            checkpoint_dir2 = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
        elif opt == '--find_str':
            find_str = arg

    if not checkpoint_dir:
        print_usage_str()
        sys.exit(2)

    if checkpoint_dir2:
        compare(checkpoint_dir, checkpoint_dir2)
    elif find_str:
        find(checkpoint_dir, find_str)
    else:
        rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)


if __name__ == '__main__':
    main(sys.argv[1:])