def convert(model_dir, keras_model_file, tf_model_file, name_output='s1_output', num_output=1):
    """Freeze a saved Keras model into a serialized TensorFlow GraphDef file.

    Loads ``os.path.join(model_dir, keras_model_file)`` with the learning
    phase fixed to inference. It then names the first ``num_output`` entries
    of ``keras_model.outputs`` with ``tf.identity`` ops called
    ``name_output + str(i)`` (``'s1_output0'``, ``'s1_output1'``, ... with the
    default prefix). Next, it replaces variables with constants and writes
    the result to ``os.path.join(model_dir, tf_model_file)``.

    Uses TensorFlow 1.x graph APIs (``tf.graph_util``, ``tf.gfile``,
    ``keras.backend.get_session``).

    Args:
        model_dir: Directory that holds the Keras model file and receives the
            frozen graph.
        keras_model_file: File name of the saved Keras model inside
            ``model_dir``. Loaded with ``custom_objects`` mapping
            ``'custom_loss'`` to ``YoloNet.custom_loss``.
        tf_model_file: File name of the frozen graph to write inside
            ``model_dir``.
        name_output: Prefix for the exported output node names.
        num_output: How many of the model's outputs to export, taken in order
            from ``keras_model.outputs``.

    Raises:
        ValueError: If ``num_output`` is less than 1 or larger than the number
            of outputs of the loaded model, or if TensorFlow had to rename an
            output node because its name was already taken in the graph.
    """
    # Parameter False is for tf.keras in TF 1.4. For real Keras, use 0 as parameter.
    # False/0 selects test (inference) mode; it is set before loading so the
    # loaded model is built in that mode.
    keras.backend.set_learning_phase(False)
    keras_model = keras.models.load_model(
        os.path.join(model_dir, keras_model_file),
        custom_objects={'custom_loss': YoloNet.custom_loss},
    )

    available = len(keras_model.outputs)
    if not 1 <= num_output <= available:
        raise ValueError(
            'num_output must be between 1 and {} (number of model outputs), '
            'got {}'.format(available, num_output))

    out_node_names = [name_output + str(i) for i in range(num_output)]
    for node_name, tensor in zip(out_node_names, keras_model.outputs):
        identity = tf.identity(tensor, name=node_name)
        # TensorFlow silently uniquifies an op name that is already taken
        # (e.g. 's1_output0_1'); freezing by the requested name would then
        # export a different node than the one just created.
        if identity.op.name != node_name:
            raise ValueError(
                'output node name {!r} already exists in the graph '
                '(got {!r}); use a fresh graph or another name_output'.format(
                    node_name, identity.op.name))

    sess = keras.backend.get_session()
    constant_graph = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        out_node_names,  # Only these nodes and the ops they depend on are kept
    )

    output_file = os.path.join(model_dir, tf_model_file)
    with tf.gfile.GFile(output_file, "wb") as f:
        f.write(constant_graph.SerializeToString())
    print("Converted model was saved as {}.".format(tf_model_file))
评论列表
文章目录