plan_test.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:fold 作者: tensorflow 项目源码 文件源码
def test_finalize_stats_summaries(self):
    """Check `loss_total` and the summaries produced by `finalize_stats`.

    Two losses ('foo', 'bar') and two metrics ('baz', 'qux') are registered
    on a plan before `finalize_stats()` is called.  The summaries are then
    evaluated with the batch size placeholder fed as 1 and as 2.  The
    expected values encode that the loss summaries (and 'loss_total') are
    halved when the fed batch size doubles, while the metric summaries
    ('baz' scalar, 'qux' histogram) are the same for both batch sizes.
    """
    stats_plan = plan.Plan(None)
    # NOTE(review): presumably a non-zero interval is what makes
    # finalize_stats() build summaries -- confirm against plan.Plan.
    stats_plan.save_summaries_secs = 42
    stats_plan.losses['foo'] = tf.constant([1.0])
    stats_plan.losses['bar'] = tf.constant([2.0, 3.0])
    stats_plan.metrics['baz'] = tf.constant(4)
    stats_plan.metrics['qux'] = tf.constant([5.0, 6.0])
    stats_plan.finalize_stats()

    def feed(batch_size):
      # Feed dict that sets only the plan's batch size placeholder.
      return {stats_plan.batch_size_placeholder: batch_size}

    def expected_values(foo, bar, loss_total, histogram):
      # Summary values expected for the given loss figures; the metric
      # entries do not vary between the two evaluations.
      return [
          tf.Summary.Value(tag='foo', simple_value=foo),
          tf.Summary.Value(tag='bar', simple_value=bar),
          tf.Summary.Value(tag='loss_total', simple_value=loss_total),
          tf.Summary.Value(tag='baz', simple_value=4),
          tf.Summary.Value(tag='qux', histo=histogram)]

    with self.test_session():
      self.assertEqual(6, stats_plan.loss_total.eval(feed(1)))

      actual = tf.Summary()
      actual.ParseFromString(stats_plan.summaries.eval(feed(1)))

      # Build the reference histogram by running an independent histogram
      # summary over the same 'qux' values and extracting its proto.
      reference = tf.Summary()
      reference.ParseFromString(tf.summary.histogram('qux', [5, 6]).eval())
      qux_histogram = reference.value[0].histo

      six.assertCountEqual(
          self, expected_values(1, 5, 6, qux_histogram), actual.value)

      # Re-evaluate with a batch size of 2, reusing the same proto object.
      actual.ParseFromString(stats_plan.summaries.eval(feed(2)))
      six.assertCountEqual(
          self, expected_values(0.5, 2.5, 3, qux_histogram), actual.value)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号