def test_finalize_stats_summaries(self):
    """Check the loss and metric summaries produced by `Plan.finalize_stats`.

    The test pins these expectations, as asserted below:
      * `loss_total` is the sum of all loss elements (1 + 2 + 3 = 6).
      * Each loss summary is its element sum divided by the batch size fed
        to `batch_size_placeholder`.
      * Scalar metrics are reported unscaled as `simple_value`.
      * Non-scalar metrics are reported as a histogram.
    """
    p = plan.Plan(None)
    p.save_summaries_secs = 42
    p.losses['foo'] = tf.constant([1.0])
    p.losses['bar'] = tf.constant([2.0, 3.0])
    p.metrics['baz'] = tf.constant(4)
    p.metrics['qux'] = tf.constant([5.0, 6.0])
    p.finalize_stats()

    with self.test_session():
        self.assertEqual(6, p.loss_total.eval({p.batch_size_placeholder: 1}))

        actual = tf.Summary()
        actual.ParseFromString(p.summaries.eval({p.batch_size_placeholder: 1}))

        # Build a reference histogram for 'qux' with the stock summary op.
        # This lets the comparison not depend on the exact bucket layout.
        reference = tf.Summary()
        reference.ParseFromString(tf.summary.histogram('qux', [5, 6]).eval())
        qux_histogram = reference.value[0].histo

        def expected_for(batch_size):
            # Loss summaries scale with 1 / batch_size; metrics do not.
            # All quotients used here (1, 0.5, 5, 2.5, 6, 3) are exact in
            # float32, so proto equality is reliable.
            return [
                tf.Summary.Value(tag='foo', simple_value=1.0 / batch_size),
                tf.Summary.Value(tag='bar', simple_value=5.0 / batch_size),
                tf.Summary.Value(tag='loss_total',
                                 simple_value=6.0 / batch_size),
                tf.Summary.Value(tag='baz', simple_value=4),
                tf.Summary.Value(tag='qux', histo=qux_histogram),
            ]

        six.assertCountEqual(self, expected_for(1), actual.value)

        # ParseFromString clears `actual` before parsing, so reusing it is safe.
        actual.ParseFromString(p.summaries.eval({p.batch_size_placeholder: 2}))
        six.assertCountEqual(self, expected_for(2), actual.value)
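# Stray web-page navigation residue, kept as a comment so the module parses:
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")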