def _test_streaming_sparse_recall_at_k(self,
                                       predictions,
                                       labels,
                                       k,
                                       expected,
                                       class_id=None,
                                       weights=None):
    """Check `metrics.streaming_sparse_recall_at_k` against an expected value.

    Builds the metric in a fresh graph/session. It then checks that both the
    metric tensor and its update op raise `tf.OpError` while local variables
    are still uninitialized. Next it initializes the local variables and runs
    the update op once. Finally it compares the update op's result and the
    metric's value to `expected`.

    Args:
      predictions: Values converted to a float32 constant and passed as
        `predictions`.
      labels: Passed through unchanged as `labels`.
      k: Passed through unchanged as `k`.
      expected: Value both the update result and the metric must equal.
        If it is NaN, both are instead checked with `_assert_nan`.
      class_id: Optional; passed through unchanged as `class_id`.
      weights: Optional; when not None, converted to a float32 constant
        before being passed as `weights`.
    """
    with tf.Graph().as_default() as graph:
        with self.test_session(graph):
            # The weights constant is created before the predictions constant
            # so that graph ops are added in the same order as before.
            weight_tensor = (None if weights is None
                             else tf.constant(weights, tf.float32))
            prediction_tensor = tf.constant(predictions, tf.float32)
            metric, update = metrics.streaming_sparse_recall_at_k(
                predictions=prediction_tensor,
                labels=labels,
                k=k,
                class_id=class_id,
                weights=weight_tensor)

            # Reading either tensor before its local variables exist must fail.
            for uninitialized in (metric, update):
                self.assertRaises(tf.OpError, uninitialized.eval)
            tf.initialize_variables(tf.local_variables()).run()

            # Evaluate the update op first (one streaming step), then the
            # metric, and check each against the expectation.
            expect_nan = math.isnan(expected)
            for tensor in (update, metric):
                observed = tensor.eval()
                if expect_nan:
                    _assert_nan(self, observed)
                else:
                    self.assertEqual(expected, observed)
评论列表
文章目录