metric_ops_test.py source code

python

Project: lsdc  Author: febert
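# Excerpt from a tf.test.TestCase subclass. It assumes `import math`,
# `import tensorflow as tf`, a `metrics` alias for tf.contrib.metrics, and a
# module-level `_assert_nan` helper, none of which are shown here.
def _test_streaming_sparse_recall_at_k(self,
                                       predictions,
                                       labels,
                                       k,
                                       expected,
                                       class_id=None,
                                       weights=None):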
    with tf.Graph().as_default() as g, self.test_session(g):
      if weights is not None:
        weights = tf.constant(weights, tf.float32)
      metric, update = metrics.streaming_sparse_recall_at_k(
          predictions=tf.constant(predictions, tf.float32),
          labels=labels, k=k, class_id=class_id, weights=weights)

      # Fails without initialized vars.
      self.assertRaises(tf.OpError, metric.eval)
      self.assertRaises(tf.OpError, update.eval)
      tf.initialize_variables(tf.local_variables()).run()

      # Run per-step op and assert expected values.
      if math.isnan(expected):
        _assert_nan(self, update.eval())
        _assert_nan(self, metric.eval())
      else:
        self.assertEqual(expected, update.eval())
        self.assertEqual(expected, metric.eval())
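
The helper above compares the metric against an `expected` value supplied by the caller. As a rough illustration of where such a value might come from, here is a minimal plain-Python sketch of recall@k. The function name `recall_at_k` and the sample data are hypothetical and not part of the original file, and the sketch ignores `weights` entirely. It assumes recall is true positives divided by the number of relevant labels, and that the result is NaN when no labels are relevant, which is the case the `math.isnan(expected)` branch handles.

```python
import math


def recall_at_k(predictions, labels, k, class_id=None):
    """Reference recall@k over a batch (hypothetical helper, unweighted).

    predictions: list of per-example score lists, one score per class.
    labels: list of per-example collections of true class ids.
    """
    true_positives = 0
    relevant = 0
    for scores, true_ids in zip(predictions, labels):
        # Indices of the k highest-scoring classes for this example.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        top_k = set(ranked[:k])
        true_ids = set(true_ids)
        if class_id is not None:
            # Restrict both sides to the single class of interest.
            true_ids &= {class_id}
            top_k &= {class_id}
        true_positives += len(true_ids & top_k)
        relevant += len(true_ids)
    # With no relevant labels the ratio is undefined, so return NaN.
    return true_positives / relevant if relevant else math.nan


# Illustrative data (an assumption, not taken from the original test file).
sample_predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
sample_labels = [[3], [2]]
```

A value computed this way could be passed as `expected` for an unweighted case; weighted cases would need the per-example weights folded into both counts.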