def test_top_k_rank_invalid(self):
    """Check that rank-1 `top_k_predictions` is rejected with ValueError.

    The predictions are a flat list of five class ids (rank 1, i.e. rank < 2)
    and the labels are a rank-1 sparse value with three entries. The test
    passes only if a ValueError is raised somewhere between building the
    metric and evaluating it.
    """
    # Fixtures are plain Python / NumPy values, so they can be prepared
    # before the session context is entered.
    flat_predictions = [9, 4, 6, 2, 0]  # rank 1: no leading batch dimension
    label_indices = np.array([[0], [1], [2]], np.int64)
    label_values = np.array([2, 7, 8], np.int64)
    label_dense_shape = np.array([10], np.int64)
    sparse_labels = tf.SparseTensorValue(
        indices=label_indices,
        values=label_values,
        shape=label_dense_shape)

    with self.test_session():
        with self.assertRaises(ValueError):
            predictions_tensor = tf.constant(flat_predictions, tf.int64)
            precision, _ = metrics.streaming_sparse_precision_at_top_k(
                top_k_predictions=predictions_tensor,
                labels=sparse_labels)
            # Initialization and evaluation stay inside the assertRaises
            # block: `precision` is only bound if the call above returns, and
            # an error raised at either construction or evaluation time
            # satisfies the assertion.
            tf.initialize_variables(tf.local_variables()).run()
            precision.eval()
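# NOTE(review): web-page residue from extraction ("Comment list" / "Table of
# contents"); not part of the test code.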