freezemodel.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:DmsMsgRcg 作者: bshao001 项目源码 文件源码
def freeze(model_scope, model_dir, model_file):
    """Freeze a saved TensorFlow checkpoint into a single ``.pb`` graph file.

    The meta graph ``<model_dir>/<model_file>.meta`` is imported into the
    default graph, the checkpoint weights are restored into a new session,
    the variables are converted to constants, and the serialized result is
    written to ``<model_dir>/<model_file>.pb``. The name of every operation
    in the graph is printed along the way.

    Args:
        model_scope: The prefix of all variables in the model. It is used to
            build the name of the extra output node,
            ``<model_scope>/readout/logits``.
        model_dir: The full path to the folder that holds the training
            results; the frozen ``.pb`` file is written to the same folder.
        model_file: The name of the file that saves the training results,
            without file suffix / extension.
    """
    # Every file involved (.meta, checkpoint, .pb) shares this path prefix.
    checkpoint_prefix = os.path.join(model_dir, model_file)

    restorer = tf.train.import_meta_graph(checkpoint_prefix + ".meta")
    default_graph = tf.get_default_graph()
    unfrozen_graph_def = default_graph.as_graph_def()

    with tf.Session() as session:
        restorer.restore(session, checkpoint_prefix)

        print("# All operations:")
        for operation in default_graph.get_operations():
            print(operation.name)

        # Keep every trainable variable (tensor name minus its ":<index>"
        # suffix) plus the logits node as outputs of the frozen graph.
        kept_node_names = []
        for variable in tf.trainable_variables():
            kept_node_names.append(variable.name.split(":")[0])
        kept_node_names.append("{}/readout/logits".format(model_scope))

        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            session, unfrozen_graph_def, kept_node_names
        )

        frozen_path = checkpoint_prefix + ".pb"
        with tf.gfile.GFile(frozen_path, "wb") as pb_file:
            pb_file.write(frozen_graph_def.SerializeToString())

        print("Freezed model was saved as {}.pb.".format(model_file))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号