metric_ops_test.py source code

python

Project: lsdc  Author: febert
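def testSingleUpdateSomeMissingKIs2(self):
    # Method of a test class; `tf`, `metrics` (presumably tf.contrib.metrics,
    # TF 1.x) and the self._* fixtures are defined elsewhere in the file.
    predictions = tf.constant(self._np_predictions,
                              shape=(self._batch_size, self._num_classes),
                              dtype=tf.float32)
    labels = tf.constant(
        self._np_labels, shape=(self._batch_size,), dtype=tf.int64)
    # Zero weights mask out examples 0 and 2; only examples 1 and 3 count.
    weights = tf.constant([0, 1, 0, 1], shape=(self._batch_size,),
                          dtype=tf.float32)
    recall, update_op = metrics.streaming_recall_at_k(
        predictions, labels, k=2, weights=weights)
    # The sparse variant takes labels of shape (batch_size, 1).
    sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
        predictions, tf.reshape(labels, (self._batch_size, 1)), k=2,
        weights=weights)

    with self.test_session() as sess:
        # Streaming metrics keep their accumulators in local variables.
        sess.run(tf.local_variables_initializer())
        # Both variants should report a recall of 1.0 after a single update.
        self.assertEqual(1.0, sess.run(update_op))
        self.assertEqual(1.0, recall.eval())
        self.assertEqual(1.0, sess.run(sp_update_op))
        self.assertEqual(1.0, sp_recall.eval())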
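
The test above checks weighted recall at k=2. A minimal NumPy sketch of that computation may help in reading it. The function `weighted_recall_at_k` and the sample data are hypothetical illustrations, not the TensorFlow implementation and not the fixture values used by the test class; returning 0.0 when all weights are zero is likewise an assumption.

```python
import numpy as np


def weighted_recall_at_k(predictions, labels, k, weights=None):
    """Weighted fraction of examples whose label is among the top-k scores.

    Hypothetical helper for illustration only.
    """
    predictions = np.asarray(predictions, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64)
    if weights is None:
        weights = np.ones(len(labels), dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)

    # Indices of the k highest-scoring classes for each example.
    top_k = np.argsort(-predictions, axis=1, kind="stable")[:, :k]
    # hits[i] is True when the label of example i appears in its top k.
    hits = (top_k == labels[:, None]).any(axis=1)

    total = weights.sum()
    if total == 0:
        return 0.0  # assumption: recall is defined as 0 when every example is masked
    return float((weights * hits).sum() / total)


# Made-up data: 4 examples, 3 classes, every label is class 0.
predictions = [[0.1, 0.2, 0.7],   # top 2 = {2, 1}: miss
               [0.6, 0.3, 0.1],   # top 2 = {0, 1}: hit
               [0.0, 0.9, 0.1],   # top 2 = {1, 2}: miss
               [0.2, 0.0, 0.8]]   # top 2 = {2, 0}: hit
labels = [0, 0, 0, 0]

unweighted = weighted_recall_at_k(predictions, labels, k=2)
masked = weighted_recall_at_k(predictions, labels, k=2, weights=[0, 1, 0, 1])
```

With this data the unweighted recall is 0.5, while the weights `[0, 1, 0, 1]` drop the two misses and give 1.0, which mirrors the situation the test name describes: some examples miss at k=2, but those examples carry zero weight.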