def test_train_summaries(self):
    """Run one monitored training step and verify summaries and checkpoint.

    The summary/checkpoint state of ``self._output_dir`` is checked twice:
    once before training, with no expectations passed and the checkpoint
    flag ``False``, and once after ``_monitored_train`` runs a single step.
    After that step:

    - the returned loss must equal the constant loss op (2.0);
    - the summary writer must report the graph, a meta graph snapshot, and
      a ``'loss'`` scalar of 2.0 at step 1;
    - a checkpoint must be present.
    """
    with tf.Graph().as_default() as graph, self.test_session(graph):
        # Ops built inside this context carry control dependencies on
        # whatever _build_inference_graph() returns (presumably its
        # output ops/tensors; confirm against the helper).
        with tf.control_dependencies(self._build_inference_graph()):
            step_increment = tf.assign_add(
                tf.contrib.framework.get_global_step(), 1)
        constant_loss = tf.constant(2.0)
        tf.summary.scalar('loss', constant_loss)

        summary_writer = learn.graph_actions.get_summary_writer(
            self._output_dir)
        # Pre-training baseline: no expected summaries, no checkpoint yet.
        self._assert_summaries(self._output_dir, summary_writer)
        self._assert_ckpt(self._output_dir, False)

        final_loss = learn.graph_actions._monitored_train(  # pylint: disable=protected-access
            graph,
            output_dir=self._output_dir,
            train_op=step_increment,
            loss_op=constant_loss,
            steps=1)
        # Taken after training so the snapshot reflects the graph as it
        # stands once _monitored_train has returned.
        expected_meta = meta_graph.create_meta_graph_def()

        self.assertEqual(2.0, final_loss)
        step_one_summaries = {1: {'loss': 2.0}}
        self._assert_summaries(
            self._output_dir,
            summary_writer,
            expected_graphs=[graph],
            expected_meta_graphs=[expected_meta],
            expected_summaries=step_one_summaries)
        self._assert_ckpt(self._output_dir, True)
评论列表
文章目录