def test_train_override_saver(self):
    """Checks that `_monitored_train` uses a saver from the SAVERS collection.

    A mock saver is registered under `tf.GraphKeys.SAVERS` before a one-step
    training run. Afterwards the mock must have been built and asked to save
    exactly once. Because the mock writes nothing to disk, `_assert_ckpt`
    is expected to report no checkpoint both before and after training.
    """
    with tf.Graph().as_default() as graph, self.test_session(graph):
        mock_saver = tf.test.mock.Mock()
        tf.add_to_collection(tf.GraphKeys.SAVERS, mock_saver)

        # The step increment must run after the inference ops.
        inference_ops = self._build_inference_graph()
        with tf.control_dependencies(inference_ops):
            increment_step = tf.assign_add(
                tf.contrib.framework.get_global_step(), 1)

        # Nothing has been trained or saved yet.
        self._assert_ckpt(self._output_dir, False)

        final_loss = learn.graph_actions._monitored_train(  # pylint: disable=protected-access
            graph,
            output_dir=self._output_dir,
            train_op=increment_step,
            loss_op=tf.constant(2.0),
            steps=1)
        self.assertEqual(2.0, final_loss)

        # The mock stands in for the real saver, so no checkpoint file exists.
        self._assert_ckpt(self._output_dir, False)
        self.assertTrue(mock_saver.build.called)
        self.assertEqual(1, mock_saver.save.call_count)
# TODO(ispir): Remove the following tests once the deprecated `train` is removed.
评论列表
文章目录