Source code of metric_ops_test.py

python

Project: lsdc  Author: febert
def testEffectivelyEquivalentSizesWithDynamicallyShapedWeight(self):
    predictions = tf.convert_to_tensor([1, 1, 1])  # shape 3,
    labels = tf.expand_dims(tf.convert_to_tensor([1, 0, 0]), 1)  # shape 3, 1

    weights = [[100], [1], [1]]  # shape 3, 1
    weights_placeholder = tf.placeholder(dtype=tf.int32, name='weights')
    feed_dict = {weights_placeholder: weights}

    with self.test_session() as sess:
      accuracy, update_op = metrics.streaming_accuracy(
          predictions, labels, weights_placeholder)

      sess.run(tf.local_variables_initializer())
      # If streaming_accuracy did not flatten the weights, the (3, 1) weights
      # would broadcast against the (3,) correctness values and accuracy would
      # be 0.33333334 (102 / 306). With flattening it is 100 / 102, which is
      # higher than .95.
      self.assertGreater(update_op.eval(feed_dict=feed_dict), .95)
      self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95)
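
The two figures in the comment can be reproduced with plain Python arithmetic. The sketch below is not TensorFlow's implementation. It assumes accuracy is a weighted mean of per-example correctness, and `flattened_accuracy` and `broadcast_accuracy` are hypothetical helper names introduced only for illustration.

```python
# Pure-Python sketch of the arithmetic behind the test; an illustration under
# the assumption that accuracy = sum(weight * correct) / sum(weight).
predictions = [1, 1, 1]
labels = [1, 0, 0]
weights = [[100], [1], [1]]  # shape (3, 1), as fed to the placeholder

# Per-example correctness: only the first prediction matches its label.
is_correct = [float(p == l) for p, l in zip(predictions, labels)]


def flattened_accuracy(is_correct, weights):
    """Weighted accuracy after flattening the (3, 1) weights to shape (3,)."""
    flat = [row[0] for row in weights]
    return sum(w * c for w, c in zip(flat, is_correct)) / sum(flat)


def broadcast_accuracy(is_correct, weights):
    """Weighted accuracy if (3, 1) weights broadcast against (3,) values.

    Broadcasting produces a 3x3 grid in which every weight is paired with
    every correctness value, so the large weight no longer singles out the
    one correct example.
    """
    numerator = sum(row[0] * c for row in weights for c in is_correct)
    denominator = sum(row[0] for row in weights for _ in is_correct)
    return numerator / denominator


print(flattened_accuracy(is_correct, weights))  # 100 / 102
print(broadcast_accuracy(is_correct, weights))  # 102 / 306
```

With flattening, the weight of 100 attaches to the single correct example, which is why the test can assert a value above .95.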