convertmodel.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:DmsMsgRcg 作者: bshao001 项目源码 文件源码
def convert(model_dir, keras_model_file, tf_model_file, name_output='s1_output', num_output=1):
    """Freeze a saved Keras model into a serialized TensorFlow graph file.

    Loads the Keras model stored at ``model_dir/keras_model_file``, attaches a
    named identity op to each of its first ``num_output`` outputs, converts the
    session's variables to constants and writes the resulting graph definition
    to ``model_dir/tf_model_file``.

    Args:
        model_dir: Directory that holds the Keras model file and receives the
            converted file.
        keras_model_file: File name of the saved Keras model inside
            ``model_dir``.
        tf_model_file: File name of the converted graph inside ``model_dir``.
        name_output: Prefix of the output node names; output ``i`` is
            requested as ``name_output + str(i)``.
        num_output: Number of model outputs to export, taken in order from
            ``keras_model.outputs``.

    Raises:
        IndexError: If ``num_output`` exceeds the number of model outputs.
    """
    # Parameter False is for tf.keras in TF 1.4. For real Keras, use 0 as parameter
    keras.backend.set_learning_phase(False)
    keras_model = keras.models.load_model(
        os.path.join(model_dir, keras_model_file),
        custom_objects={'custom_loss': YoloNet.custom_loss},
    )

    # Record the name each identity op actually received. TensorFlow makes op
    # names unique, so if the requested name is already taken in the graph the
    # new op gets a suffixed name; passing the requested name to the freeze
    # step would then select the older, unrelated node.
    out_node_names = [
        tf.identity(keras_model.outputs[i], name=name_output + str(i)).op.name
        for i in range(num_output)
    ]

    sess = keras.backend.get_session()
    constant_graph = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        out_node_names,  # All other operations relying on this will also be saved
    )

    output_file = os.path.join(model_dir, tf_model_file)
    with tf.gfile.GFile(output_file, "wb") as f:
        f.write(constant_graph.SerializeToString())

    print("Converted model was saved as {}.".format(tf_model_file))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号