def test_sampling(self):
    """Check that TrainSampleHook writes sample files only on trigger steps.

    The hook is built with ``every_n_steps`` set to 10. The global step is
    forced to 0, 9 and 10 in turn, and a no-op is run through a hooked
    session after each change. A sample file is expected under
    ``self.sample_dir`` for steps 0 and 10, and none for step 9.
    """
    sample_hook = learner_hooks.TrainSampleHook(
        params={"every_n_steps": 10},
        model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())
    step_var = tf.contrib.framework.get_or_create_global_step()
    noop = tf.no_op()
    sample_hook.begin()

    with self.test_session() as sess:
        # Build and run each initializer one at a time, in the same order
        # as before: global variables, local variables, then tables.
        for make_init_op in (tf.global_variables_initializer,
                             tf.local_variables_initializer,
                             tf.tables_initializer):
            sess.run(make_init_op())

        hooked_sess = monitored_session._HookedSession(sess, [sample_hook])

        def run_at_step(step, filename):
            # Force the global step through the raw session, then run the
            # no-op through the hooked session so the hook is invoked.
            sess.run(tf.assign(step_var, step))
            hooked_sess.run(noop)
            return os.path.join(self.sample_dir, filename)

        def read_text(path):
            # Read as bytes and decode explicitly as UTF-8.
            with open(path, "rb") as handle:
                return handle.read().decode("utf-8")

        # Should trigger for step 0
        step0_path = run_at_step(0, "samples_000000.txt")
        self.assertIn("Prediction followed by Target @ Step 0",
                      read_text(step0_path))

        # Should not trigger for step 9
        step9_path = run_at_step(9, "samples_000009.txt")
        self.assertFalse(os.path.exists(step9_path))

        # Should trigger for step 10
        step10_path = run_at_step(10, "samples_000010.txt")
        self.assertIn("Prediction followed by Target @ Step 10",
                      read_text(step10_path))
# Comment list
# Article table of contents