text_cnn_eval.py — file source code

python

Project: conll16st-hd-sdp    Author: tbmihailov
import numpy as np
import tensorflow as tf

import data_helpers


def text_cnn_load_model_and_eval_v2(x_test_s1,
                                    x_test_s2,
                                    checkpoint_file,
                                    allow_soft_placement,
                                    log_device_placement,
                                    embeddings):
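    # Build in a fresh graph so the restored checkpoint does not
    # collide with anything already present in the default graph.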
    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=allow_soft_placement,
            log_device_placement=log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x_s1 = graph.get_operation_by_name("input_x_s1").outputs[0]
            input_x_s2 = graph.get_operation_by_name("input_x_s2").outputs[0]
            # input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]

            # Tensors we want to evaluate
            predictions = graph.get_operation_by_name("output/predictions").outputs[0]

            # Generate batches for one epoch
            batch_size = 50
            batches = data_helpers.batch_iter(list(zip(x_test_s1, x_test_s2)), batch_size, 1, shuffle=False)
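            # shuffle=False keeps batch order stable so predictions stay aligned with the inputs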

            # Collect the predictions here
            all_predictions = []

            # Look up the embeddings placeholder in the restored graph
            embedding_size = embeddings.shape[1]
            embeddings_number = embeddings.shape[0]
            print('embedding_size: %s, embeddings_number: %s' % (embedding_size, embeddings_number))
            # with tf.name_scope("embedding"):
            #     embeddings_placeholder = tf.placeholder(tf.float32, shape=[embeddings_number, embedding_size])
            embeddings_placeholder = graph.get_operation_by_name("embedding/Placeholder").outputs[0]

            for batch in batches:
                x_test_batch_s1, x_test_batch_s2 = zip(*batch)
                batch_predictions = sess.run(predictions, {input_x_s1: x_test_batch_s1,
                                                           input_x_s2: x_test_batch_s2,
                                                           dropout_keep_prob: 1.0,
                                                           embeddings_placeholder: embeddings})
                all_predictions = np.concatenate([all_predictions, batch_predictions])

    return all_predictions
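
A minimal usage sketch, assuming a checkpoint saved by the matching training script; the shapes, paths, and arrays below are hypothetical placeholders, not values from the project:

import numpy as np

# Padded word-id matrices for the two sentence arguments (hypothetical shapes)
x_test_s1 = np.zeros((100, 50), dtype=np.int32)
x_test_s2 = np.zeros((100, 50), dtype=np.int32)
# Pre-trained embedding matrix, one row per vocabulary entry (hypothetical)
embeddings = np.random.rand(10000, 300).astype(np.float32)

predictions = text_cnn_load_model_and_eval_v2(
    x_test_s1,
    x_test_s2,
    checkpoint_file="runs/1234567890/checkpoints/model-1000",  # hypothetical path
    allow_soft_placement=True,
    log_device_placement=False,
    embeddings=embeddings)
print('predicted labels:', predictions)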