|
import datetime |
|
import pickle |
|
import tensorflow as tf |
|
|
|
|
|
def save_checkpoint(model, current_step, epoch, output_path, **kwargs):
    """Save a TF vocoder model checkpoint to ``output_path`` using pickle.

    The pickled dict contains:
      - ``'model'``: ``model.weights`` stored as-is (the weight objects
        themselves, not plain arrays; this relies on them being picklable
        in the TF version in use -- TODO confirm),
      - ``'step'``: ``current_step``,
      - ``'epoch'``: ``epoch``,
      - ``'date'``: today's date formatted with ``"%B %d, %Y"``
        (e.g. ``"January 05, 2024"``; month name is locale-dependent),
      - any extra ``kwargs``. These are merged last, so a kwarg with the same
        name as one of the keys above overrides it.

    Args:
        model: object exposing a ``weights`` attribute (presumably a Keras
            model).
        current_step: training step to record.
        epoch: epoch to record.
        output_path: destination file path; an existing file is overwritten.
        **kwargs: additional entries to store in the checkpoint dict.
    """
    state = {
        'model': model.weights,
        'step': current_step,
        'epoch': epoch,
        'date': datetime.date.today().strftime("%B %d, %Y"),
    }
    state.update(kwargs)
    # A context manager guarantees the file is flushed and closed, even if
    # pickling raises part-way through.
    with open(output_path, 'wb') as f:
        pickle.dump(state, f)
|
|
|
|
|
def load_checkpoint(model, checkpoint_path):
    """Load TF vocoder weights from a pickled checkpoint into ``model``.

    Each variable in ``model.weights`` is matched by its ``name`` against the
    variables stored under the checkpoint's ``'model'`` key. It is then set
    with ``tf.keras.backend.set_value``. Checkpoint variables that have no
    counterpart in the model are ignored.

    Args:
        model: object exposing ``weights`` (presumably a Keras model).
        checkpoint_path: path to a pickled checkpoint dict whose ``'model'``
            entry is an iterable of objects with ``name`` and ``numpy()``.

    Returns:
        The same ``model`` object, updated in place.

    Raises:
        KeyError: if any model variable is missing from the checkpoint. All
            missing names are reported, and no variable is modified.

    Warning:
        ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    # NOTE(security): unpickling is unsafe on untrusted input.
    with open(checkpoint_path, 'rb') as f:
        checkpoint = pickle.load(f)
    chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']}
    tf_vars = list(model.weights)
    # Validate every name up front so a mismatched checkpoint cannot leave the
    # model partially overwritten.
    missing = [tf_var.name for tf_var in tf_vars
               if tf_var.name not in chkp_var_dict]
    if missing:
        raise KeyError(
            f"Variables missing from checkpoint {checkpoint_path!r}: {missing}")
    for tf_var in tf_vars:
        tf.keras.backend.set_value(tf_var, chkp_var_dict[tf_var.name])
    return model
|
|