def test_train_loss(self):
    """Train a toy graph for a fixed number of steps and check the outputs.

    The "loss" is a local variable initialised to 10.0. Each run of the train
    op decrements it by 1.0 and increments the global step. After training
    for six steps, the test expects ``learn.graph_actions.train`` to return
    4.0. It then re-checks the summaries (expecting this graph) and the
    checkpoint state via the class's ``_assert_summaries`` / ``_assert_ckpt``
    helpers.
    """
    initial_loss = 10.0
    num_steps = 6
    with tf.Graph().as_default() as graph, self.test_session(graph):
        tf.contrib.framework.create_global_step()
        loss_counter = tf.contrib.framework.local_variable(initial_loss)
        global_step = tf.contrib.framework.get_global_step()
        step_and_decrement = tf.group(
            tf.assign_add(global_step, 1),
            tf.assign_add(loss_counter, -1.0),
        )

        # Pre-training checks. Presumably these assert that no summaries and
        # no checkpoint exist yet. The helpers are defined elsewhere in the
        # test class; confirm their semantics there.
        self._assert_summaries(self._output_dir)
        self._assert_ckpt(self._output_dir, False)

        final_loss = learn.graph_actions.train(
            graph,
            output_dir=self._output_dir,
            train_op=step_and_decrement,
            loss_op=loss_counter.value(),
            steps=num_steps,
        )

        # TODO(ebrevdo,ptucker,ispir): the MetaGraphDef produced by
        # meta_graph.create_meta_graph_def() lacks the SaverDef, so it cannot
        # yet be added to the summary assertion below.
        self.assertEqual(initial_loss - num_steps, final_loss)
        self._assert_summaries(self._output_dir, expected_graphs=[graph])
        self._assert_ckpt(self._output_dir, True)
# Comment list
# Table of contents