metric_ops_test.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def test_top_k_rank_invalid(self):
    with self.test_session():
      # top_k_predictions has rank < 2.
      top_k_predictions = [9, 4, 6, 2, 0]
      sp_labels = tf.SparseTensorValue(
          indices=np.array([[0,], [1,], [2,]], np.int64),
          values=np.array([2, 7, 8], np.int64),
          shape=np.array([10,], np.int64))

      with self.assertRaises(ValueError):
        precision, _ = metrics.streaming_sparse_precision_at_top_k(
            top_k_predictions=tf.constant(top_k_predictions, tf.int64),
            labels=sp_labels)
        tf.initialize_variables(tf.local_variables()).run()
        precision.eval()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号