metric_ops_test.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def testAggregateMultipleMetricsReturnsListsInOrder(self):
    predictions = tf.ones((10, 4))
    labels = tf.ones((10, 4)) * 3
    names_to_values, names_to_updates = metrics.aggregate_metric_map(
        {
            'm1': metrics.streaming_mean_absolute_error(
                predictions, labels),
            'm2': metrics.streaming_mean_squared_error(
                predictions, labels),
        })

    self.assertEqual(2, len(names_to_values))
    self.assertEqual(2, len(names_to_updates))

    with self.test_session() as sess:
      sess.run(tf.initialize_local_variables())
      self.assertEqual(2, names_to_updates['m1'].eval())
      self.assertEqual(4, names_to_updates['m2'].eval())
      self.assertEqual(2, names_to_values['m1'].eval())
      self.assertEqual(4, names_to_values['m2'].eval())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号