def testSingleUpdateKIs2(self):
    """Dense and sparse recall@2 both report 0.5 after a single update.

    The same predictions/labels fixture feeds two metrics:
    ``metrics.streaming_recall_at_k`` receives the labels as a 1-D int64
    tensor of shape ``(batch_size,)``, and
    ``metrics.streaming_sparse_recall_at_k`` receives those labels reshaped
    to ``(batch_size, 1)``. After local variables are initialized, each
    metric's update op and then its value tensor must evaluate to 0.5.

    NOTE(review): ``_np_predictions``, ``_np_labels``, ``_batch_size`` and
    ``_num_classes`` are read from ``self`` but set up outside this view
    (presumably in ``setUp``) -- confirm there.
    """
    expected_recall = 0.5
    top_k = 2

    prediction_tensor = tf.constant(
        self._np_predictions,
        shape=(self._batch_size, self._num_classes),
        dtype=tf.float32)
    dense_labels = tf.constant(
        self._np_labels, shape=(self._batch_size,), dtype=tf.int64)

    dense_value, dense_update = metrics.streaming_recall_at_k(
        prediction_tensor, dense_labels, k=top_k)
    # The sparse metric is given the same labels as a (batch_size, 1) tensor.
    sparse_labels = tf.reshape(dense_labels, (self._batch_size, 1))
    sparse_value, sparse_update = metrics.streaming_sparse_recall_at_k(
        prediction_tensor, sparse_labels, k=top_k)

    with self.test_session() as session:
        session.run(tf.local_variables_initializer())
        # Dense metric first, then sparse; for each, run the update op and
        # then read back the metric's value tensor.
        metric_pairs = (
            (dense_update, dense_value),
            (sparse_update, sparse_value),
        )
        for update_op, value_tensor in metric_pairs:
            self.assertEqual(expected_recall, session.run(update_op))
            self.assertEqual(expected_recall, value_tensor.eval())
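# 评论列表 ("Comment list" -- residue from a scraped web page, not code)
# 文章目录 ("Table of contents" -- residue from a scraped web page, not code)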