metric_ops_test.py 文件源码

python
阅读 128 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def testSingleUpdateWithErrorAndWeights2(self):
    np_predictions = np.matrix(('1 0 0;'
                                '0 0 -1;'
                                '1 0 0'))
    np_labels = np.matrix(('1 0 0;'
                           '0 0 1;'
                           '0 1 0'))

    predictions = tf.constant(np_predictions, shape=(3, 1, 3), dtype=tf.float32)
    labels = tf.constant(np_labels, shape=(3, 1, 3), dtype=tf.float32)
    weights = tf.constant([0, 1, 1], shape=(3, 1, 1), dtype=tf.float32)

    error, update_op = metrics.streaming_mean_cosine_distance(
        predictions, labels, dim=2, weights=weights)

    with self.test_session() as sess:
      sess.run(tf.initialize_local_variables())
      self.assertEqual(1.5, update_op.eval())
      self.assertEqual(1.5, error.eval())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号