metric_ops_test.py source code

python

Project: lsdc  Author: febert
def testSingleUpdateWithErrorAndWeights(self):
    with self.test_session() as sess:
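      # One batch of four (prediction, label) pairs; the weight of 0 excludes
      # the first pair from the statistic.
      predictions = np.array([2, 4, 6, 8])
      labels = np.array([1, 3, 2, 7])
      weights = np.array([0, 1, 3, 1])
      # Shape (1, 4) so a single update_op call consumes all pairs.
      predictions_t = tf.constant(predictions, shape=(1, 4), dtype=tf.float32)
      labels_t = tf.constant(labels, shape=(1, 4), dtype=tf.float32)
      weights_t = tf.constant(weights, shape=(1, 4), dtype=tf.float32)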

      pearson_r, update_op = metrics.streaming_pearson_correlation(
          predictions_t, labels_t, weights=weights_t)

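      # Expected value: expand each pair by its integer weight (_reweight is a
      # helper not shown in this excerpt), then take Pearson r from the 2x2
      # sample covariance matrix.
      p, l = _reweight(predictions, labels, weights)
      cmat = np.cov(p, l)
      expected_r = cmat[0, 1] / np.sqrt(cmat[0, 0] * cmat[1, 1])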
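      # Streaming-metric accumulators are local variables; initialize them first.
      sess.run(tf.local_variables_initializer())
      # update_op folds this batch into the running statistics and returns the
      # updated correlation; pearson_r then reads the same accumulated value.
      self.assertAlmostEqual(expected_r, sess.run(update_op))
      self.assertAlmostEqual(expected_r, pearson_r.eval())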
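The expected value in this test treats each weight as an integer repeat count. The following NumPy-only sketch reproduces that computation without TensorFlow, under stated assumptions: `reweight` is a hypothetical stand-in for the test module's `_reweight` helper (its real definition is not shown here), and `WeightedPearsonAccumulator` is a hypothetical, simplified model of how a streaming metric keeps running sums across `update_op` calls. It is not TensorFlow's actual implementation.

```python
import numpy as np


def reweight(predictions, labels, weights):
    """Hypothetical stand-in for _reweight: repeat each pair `weight` times."""
    weights = np.asarray(weights, dtype=int)
    return np.repeat(predictions, weights), np.repeat(labels, weights)


def pearson_from_cov(p, l):
    """Pearson r from the 2x2 sample covariance matrix, as in the test."""
    cmat = np.cov(p, l)
    return cmat[0, 1] / np.sqrt(cmat[0, 0] * cmat[1, 1])


class WeightedPearsonAccumulator:
    """Hypothetical streaming model: weighted running sums across batches.

    Uses plain sums for clarity; a production metric would prefer a
    numerically stabler co-moment update.
    """

    def __init__(self):
        self.w = self.sx = self.sy = self.sxx = self.syy = self.sxy = 0.0

    def update(self, x, y, w):
        # Fold one weighted batch into the running sums, return current r.
        x, y, w = (np.asarray(a, dtype=float) for a in (x, y, w))
        self.w += w.sum()
        self.sx += (w * x).sum()
        self.sy += (w * y).sum()
        self.sxx += (w * x * x).sum()
        self.syy += (w * y * y).sum()
        self.sxy += (w * x * y).sum()
        return self.result()

    def result(self):
        # Weighted (co)variances divided by W; the normalization cancels in r.
        mx, my = self.sx / self.w, self.sy / self.w
        cov = self.sxy / self.w - mx * my
        var_x = self.sxx / self.w - mx * mx
        var_y = self.syy / self.w - my * my
        return cov / np.sqrt(var_x * var_y)


predictions = np.array([2, 4, 6, 8])
labels = np.array([1, 3, 2, 7])
weights = np.array([0, 1, 3, 1])

p, l = reweight(predictions, labels, weights)
expected_r = pearson_from_cov(p, l)

# NumPy's frequency weights give the same covariance matrix without copying.
cmat_fw = np.cov(predictions, labels, fweights=weights)
r_fw = cmat_fw[0, 1] / np.sqrt(cmat_fw[0, 0] * cmat_fw[1, 1])

# One update with all pairs, and the same data split into two batches.
single = WeightedPearsonAccumulator().update(predictions, labels, weights)
acc = WeightedPearsonAccumulator()
acc.update(predictions[:3], labels[:3], weights[:3])
two_batches = acc.update(predictions[3:], labels[3:], weights[3:])

print(round(expected_r, 4))  # 0.6523
```

Repeating pairs makes the weights-as-counts interpretation explicit; `np.cov(..., fweights=...)` computes the same matrix without materializing copies, and both approaches require integer weights. Splitting the data into batches does not change the result, because the running sums are simply additive.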