impex.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:vae-style-transfer 作者: sunsided 项目源码 文件源码
def _fix_batch_norm_nodes(graph_def):
    """Rewrite ref-typed ops in *graph_def* in place so the graph can be frozen.

    Workaround for https://github.com/tensorflow/tensorflow/issues/3628:
    after freezing, variables become constants, so ops that expect a
    variable reference (``RefSwitch``, ``AssignSub``, ``AssignAdd``) are
    rewritten to their value-typed counterparts (``Switch``, ``Sub``,
    ``Add``).

    Args:
        graph_def: The ``tf.GraphDef`` to modify; it is mutated in place.
    """
    # Ref-typed update ops and the plain arithmetic op that replaces each.
    assign_replacements = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}

    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Iterate over a snapshot because elements are reassigned below.
            for index, input_name in enumerate(list(node.input)):
                # Presumably the batch-norm moving_mean / moving_variance
                # variables; feeding their '/read' output gives the Switch a
                # value instead of a reference.
                if 'moving_' in input_name:
                    node.input[index] = input_name + '/read'
        elif node.op in assign_replacements:
            node.op = assign_replacements[node.op]
            # 'use_locking' belongs to the Assign* ops; the replacement
            # arithmetic op does not define it.
            if 'use_locking' in node.attr:
                del node.attr['use_locking']


def export_graph(input_path, output_path, output_nodes, debug=False):
    """Freeze the latest checkpoint in *input_path* into a binary GraphDef file.

    The meta graph of the newest checkpoint is imported into the default
    graph with device placements cleared. Ref-typed batch-norm ops are
    rewritten, variables are restored from the checkpoint and folded into
    constants, and the result is written to *output_path* as a binary
    protobuf.

    Args:
        input_path: Directory passed to ``tf.train.latest_checkpoint``.
        output_path: File path of the frozen graph to write.
        output_nodes: Name of the output node, or an iterable of names, to
            keep in the frozen graph.
        debug: If true, print all op names of the imported graph and a
            progress message.

    Raises:
        ValueError: If no checkpoint is found in *input_path*.
    """
    # TODO: might want to look at http://stackoverflow.com/a/39578062/195651

    checkpoint = tf.train.latest_checkpoint(input_path)
    if checkpoint is None:
        # Otherwise the failure is an opaque TypeError on `None + '.meta'`.
        raise ValueError('No checkpoint found in {!r}'.format(input_path))

    # convert_variables_to_constants rejects a bare string, so wrap a single
    # name; materialise other iterables once so generators are not consumed
    # prematurely.
    if isinstance(output_nodes, str):
        output_nodes = [output_nodes]
    else:
        output_nodes = list(output_nodes)

    # Returns None when the meta graph contains no variables to restore.
    saver = tf.train.import_meta_graph(checkpoint + '.meta', clear_devices=True)

    graph = tf.get_default_graph()  # type: tf.Graph
    graph_def = graph.as_graph_def()  # type: tf.GraphDef

    if debug:
        print([op.name for op in graph.get_operations()])

    _fix_batch_norm_nodes(graph_def)

    if debug:
        print('Freezing the graph ...')
    with tf.Session() as sess:
        if saver is not None:
            saver.restore(sess, checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, graph_def, output_nodes)

    # A bare file name has an empty dirname; write into the current directory.
    output_dir = path.dirname(output_path) or '.'
    tf.train.write_graph(output_graph_def, output_dir,
                         path.basename(output_path), as_text=False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号