validation_check.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:neurobind 作者: Kyubyong 项目源码 文件源码
def validation_check():
    """Evaluate the latest checkpoint on the validation split and log its score.

    Restores the most recent checkpoint in ``hp.logdir`` into an inference
    graph and runs it over the validation data (``load_data(mode="val")``) in
    batches of ``hp.batch_size``. It then computes the Spearman rank
    correlation between targets and predictions and appends one line,
    ``"<model name>\\t<score>\\n"``, to
    ``<hp.results>/validation_results.txt``. The ``hp.results`` directory is
    created if it does not exist.

    Raises:
        ValueError: if ``hp.logdir`` contains no checkpoint to restore.
    """
    # Fail fast, before building the graph or loading data.
    # latest_checkpoint returns None when there is nothing to restore, and
    # passing None to Saver.restore gives an unhelpful error.
    ckpt_path = tf.train.latest_checkpoint(hp.logdir)
    if ckpt_path is None:
        raise ValueError("No checkpoint found in {!r}".format(hp.logdir))

    # Load graph
    g = Graph(is_training=False)
    print("Graph loaded")

    # Load data
    X, Y = load_data(mode="val")

    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            # Restore parameters
            sv.saver.restore(sess, ckpt_path)
            print("Restored!")

            # Model name: the first double-quoted value in the checkpoint
            # index file. Use `with` so the handle is closed promptly.
            with open(os.path.join(hp.logdir, 'checkpoint'), 'r') as ckpt_file:
                mname = ckpt_file.read().split('"')[1]

            # Inference. Only full batches are evaluated; the trailing
            # len(X) % hp.batch_size samples are skipped.
            # NOTE(review): presumably g.x has a fixed batch dimension, which
            # would make a partial final batch infeasible -- confirm.
            expected, predicted = [], []
            for step in range(len(X) // hp.batch_size):
                x = X[step * hp.batch_size: (step + 1) * hp.batch_size]
                y = Y[step * hp.batch_size: (step + 1) * hp.batch_size]

                # predict intensities
                logits = sess.run(g.logits, {g.x: x})

                expected.extend(list(y))
                predicted.extend(list(logits))

            # Spearman rank correlation between targets and predictions;
            # the p-value is discarded.
            score, _ = spearmanr(expected, predicted)

            # Open the results file only for the write, so it is not held
            # open (or created) while inference runs.
            if not os.path.exists(hp.results):
                os.mkdir(hp.results)
            with open(os.path.join(hp.results, "validation_results.txt"), 'a') as fout:
                fout.write("{}\t{}\n".format(mname, score))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号