graph_saver.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:han 作者: croath 项目源码 文件源码
def _fix_batch_norm_nodes(graph_def):
    """Patch batch-norm related ops in *graph_def* in place so it can be frozen.

    Works around a batch-norm bug hit when freezing graphs:
      Ref 1 Solution: https://github.com/davidsandberg/facenet/issues/161
      Ref 2 Official Issue: https://github.com/tensorflow/tensorflow/issues/3628

    * ``RefSwitch`` nodes become ``Switch``, and any input whose name contains
      ``moving_`` is redirected to ``<input>/read``.
    * ``AssignSub`` / ``AssignAdd`` nodes become ``Sub`` / ``Add`` and have their
      ``use_locking`` attribute removed (presumably because the plain ops do not
      define it -- see the refs above).

    Nothing is returned; the nodes of *graph_def* are mutated directly.
    """
    assign_to_plain = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Snapshot the inputs so entries can be rewritten while iterating.
            for index, input_name in enumerate(list(node.input)):
                if 'moving_' in input_name:
                    node.input[index] = input_name + '/read'
        elif node.op in assign_to_plain:
            node.op = assign_to_plain[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']


def main(_):
    """Freeze the latest checkpoint in FLAGS.checkpoint_dir into model.pb.

    Imports the meta graph of the most recent checkpoint, restores its
    variables, patches batch-norm nodes, converts variables to constants
    (keeping only the nodes needed for ``output_prob``) and writes the
    serialized GraphDef to ``FLAGS.model_dir/model.pb``.

    Raises:
        ValueError: if no checkpoint can be found in FLAGS.checkpoint_dir.
    """
    output_node_names = "output_prob"

    session_config = tf.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_fraction

    # Resolve the checkpoint before building the .meta path: when no checkpoint
    # exists this is falsy, and ``ckpt + '.meta'`` would otherwise fail with an
    # unhelpful TypeError before any guard could run.
    ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    if not ckpt:
        raise ValueError(
            "No checkpoint found in %r; nothing to freeze." % FLAGS.checkpoint_dir)

    with tf.Session(config=session_config) as sess:
        saver = tf.train.import_meta_graph(ckpt + '.meta')
        saver.restore(sess, ckpt)

        # Retrieve the protobuf graph definition and fix the batch norm nodes.
        gd = sess.graph.as_graph_def()
        _fix_batch_norm_nodes(gd)

        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            gd,  # The graph_def is used to retrieve the nodes
            output_node_names.split(","),  # The output node names select the useful nodes
        )

        with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'model.pb'), "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号