def testAggregateMultipleMetricsReturnsListsInOrder(self):
    """Checks that aggregate_metrics preserves the order of its input metrics.

    Every prediction is 1 and every label is 3, so the mean absolute error is
    |1 - 3| = 2 and the mean squared error is (1 - 3) ** 2 = 4. Index 0 of
    both returned lists must therefore evaluate to 2 (the MAE metric, passed
    first) and index 1 to 4 (the MSE metric, passed second).
    """
    preds = tf.ones((10, 4))
    targets = tf.ones((10, 4)) * 3

    # Construct the metrics in the same order they are handed to
    # aggregate_metrics, so graph construction order matches the call order.
    mae_metric = metrics.streaming_mean_absolute_error(preds, targets)
    mse_metric = metrics.streaming_mean_squared_error(preds, targets)
    values, updates = metrics.aggregate_metrics(mae_metric, mse_metric)

    self.assertEqual(len(values), 2)
    self.assertEqual(len(updates), 2)

    expected = (2, 4)  # (MAE, MSE) for the constant inputs above
    with self.test_session() as session:
        session.run(tf.initialize_local_variables())
        # Run the update ops first (in list order), then read back the
        # value tensors; both must line up with the input metric order.
        for want, update_op in zip(expected, updates):
            self.assertEqual(want, update_op.eval())
        for want, value in zip(expected, values):
            self.assertEqual(want, value.eval())
评论列表
文章目录