def test_train_max_steps_is_not_incremental(self):
  """Check that `max_steps` is a target step count, not an increment.

  Trains twice into the same `self._output_dir`, first with `max_steps=10`
  and then, in a fresh graph, with `max_steps=15`. After each run the global
  step is read back from the output directory. It must equal that run's
  `max_steps`: 15 after the second run, not 10 + 15.
  """

  def _train_until(max_steps):
    # Build a fresh graph and test session, train to `max_steps`, then
    # assert on the global step loaded from `self._output_dir`.
    with tf.Graph().as_default() as graph, self.test_session(graph):
      inference_ops = self._build_inference_graph()
      # The global-step increment is forced to run after the inference ops.
      with tf.control_dependencies(inference_ops):
        increment = tf.assign_add(tf.contrib.framework.get_global_step(), 1)
      learn.graph_actions.train(
          graph,
          output_dir=self._output_dir,
          train_op=increment,
          loss_op=tf.constant(2.0),
          max_steps=max_steps)
      restored_step = tf.contrib.framework.load_variable(
          self._output_dir, tf.contrib.framework.get_global_step().name)
      self.assertEqual(max_steps, restored_step)

  # Order matters: both runs write to the same output directory, so the
  # second run presumably continues from the global step saved by the first.
  for max_steps in (10, 15):
    _train_until(max_steps)
评论列表
文章目录