utils_tf.py source code


Project: FeatureSqueezing  Author: QData
import math

import keras
import numpy as np
from tensorflow.python.platform import flags

FLAGS = flags.FLAGS  # a `batch_size` flag must be defined elsewhere


def tf_model_eval_dist_tri_input(sess, x, model, X_test1, X_test2, X_test3, mode='max'):
    """
    Compute, for every example, the L1 distance between the model's
    predictions on three versions of the same test inputs
    :param sess: TF session used to evaluate the graph
    :param x: input placeholder
    :param model: model output predictions (tensor)
    :param X_test1, X_test2, X_test3: numpy arrays with testing inputs,
        aligned example by example
    :param mode: 'max' or 'mean'; how the three pairwise L1 distances
        are combined into a single value per example
    :return: numpy array of shape (len(X_test1),), one distance per example
    """
    if mode not in ('max', 'mean'):
        # Any other value would silently leave the output filled with zeros
        raise ValueError("mode must be 'max' or 'mean', got %r" % (mode,))

    l1_dist_vec = np.zeros(len(X_test1))

    with sess.as_default():
        # Compute number of batches
        nb_batches = int(math.ceil(float(len(X_test1)) / FLAGS.batch_size))
        assert nb_batches * FLAGS.batch_size >= len(X_test1)

        for batch in range(nb_batches):
            if batch % 100 == 0 and batch > 0:
                print("Batch " + str(batch))

            # Must not use the `batch_indices` function here, because it
            # repeats some examples. Repetition is acceptable during
            # training, but not during evaluation.
            start = batch * FLAGS.batch_size
            end = min(len(X_test1), start + FLAGS.batch_size)

            # learning_phase = 0 runs Keras layers (e.g. dropout) in test mode
            pred_1 = model.eval(feed_dict={x: X_test1[start:end], keras.backend.learning_phase(): 0})
            pred_2 = model.eval(feed_dict={x: X_test2[start:end], keras.backend.learning_phase(): 0})
            pred_3 = model.eval(feed_dict={x: X_test3[start:end], keras.backend.learning_phase(): 0})

            # Pairwise L1 distances between prediction vectors, per example
            l11 = np.sum(np.abs(pred_1 - pred_2), axis=1)
            l12 = np.sum(np.abs(pred_1 - pred_3), axis=1)
            l13 = np.sum(np.abs(pred_2 - pred_3), axis=1)

            if mode == 'max':
                l1_dist_vec[start:end] = np.max(np.array([l11, l12, l13]), axis=0)
            else:  # mode == 'mean'
                l1_dist_vec[start:end] = np.mean(np.array([l11, l12, l13]), axis=0)
        assert end >= len(X_test1)

    return l1_dist_vec
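The batching and distance logic in this function does not depend on TensorFlow itself. The following is a minimal sketch that reproduces it in plain NumPy, assuming a hypothetical `predict` callable in place of the session-based `model.eval(...)` call and a hypothetical function name `l1_dist_tri`; the batch size is passed as an argument instead of being read from `FLAGS`.

```python
import math

import numpy as np


def l1_dist_tri(predict, X1, X2, X3, batch_size=128, mode="max"):
    """Per-example max (or mean) of the three pairwise L1 distances
    between the predictions on X1, X2 and X3.

    `predict` is a hypothetical stand-in for the session-based
    `model.eval(...)` call: any function mapping a batch of inputs to a
    (batch, n_classes) array of predictions.
    """
    if mode not in ("max", "mean"):
        raise ValueError("mode must be 'max' or 'mean'")
    combine = np.max if mode == "max" else np.mean

    n = len(X1)
    dist = np.zeros(n)
    nb_batches = int(math.ceil(float(n) / batch_size))
    for batch in range(nb_batches):
        # Contiguous, non-overlapping slices: no example is evaluated twice.
        start = batch * batch_size
        end = min(n, start + batch_size)
        p1, p2, p3 = (predict(X[start:end]) for X in (X1, X2, X3))
        pairwise = np.stack([
            np.abs(p1 - p2).sum(axis=1),
            np.abs(p1 - p3).sum(axis=1),
            np.abs(p2 - p3).sum(axis=1),
        ])  # shape (3, end - start)
        dist[start:end] = combine(pairwise, axis=0)
    return dist
```

For example, with an identity `predict` (so the "predictions" are the inputs themselves), `l1_dist_tri(lambda b: b, X1, X2, X3, batch_size=1)` returns, for each row, the largest of the three pairwise L1 distances between the corresponding rows of `X1`, `X2` and `X3`. Taking the maximum makes the score large as soon as any single pair of predictions disagrees, while the mean averages over all three pairs.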