def testAggregateMultipleMetricsReturnsListsInOrder(self):
    """Checks that aggregate_metric_map returns per-name values and updates.

    Despite "ReturnsLists" in the name, both returned objects are indexed by
    the metric names given in the input dict ('m1', 'm2'), not by position.
    With every prediction equal to 1 and every label equal to 3, the test
    expects 2 from 'm1' (mean absolute error) and 4 from 'm2' (mean squared
    error).
    """
    preds = tf.ones((10, 4))
    targets = tf.ones((10, 4)) * 3

    # Build the name -> metric mapping in the same m1, m2 order as before.
    metric_map = {
        'm1': metrics.streaming_mean_absolute_error(preds, targets),
        'm2': metrics.streaming_mean_squared_error(preds, targets),
    }
    values_by_name, updates_by_name = metrics.aggregate_metric_map(metric_map)

    # One entry per input metric in each of the two returned mappings.
    self.assertEqual(2, len(values_by_name))
    self.assertEqual(2, len(updates_by_name))

    # A tuple of pairs (not a dict) pins the evaluation order to m1, then m2.
    expected = (('m1', 2), ('m2', 4))
    with self.test_session() as sess:
        # Initialize local variables (presumably the metrics' accumulators)
        # before any update op runs.
        sess.run(tf.initialize_local_variables())
        # Evaluate each update op exactly once, in name order.
        for name, want in expected:
            self.assertEqual(want, updates_by_name[name].eval())
        # Then read the value tensors; they must report the same numbers.
        for name, want in expected:
            self.assertEqual(want, values_by_name[name].eval())
# [Scraped web-page residue, not code: "Comment list" and
#  "Table of contents" navigation labels.]