test_learner_hooks.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def test_sampling(self):
    """Verify TrainSampleHook writes sample files only on the configured steps.

    The hook is built with ``every_n_steps=10`` and driven through a
    ``monitored_session._HookedSession``. The global step is forced to 0, 9
    and 10 in turn. After each no-op run, ``self.sample_dir`` is inspected
    for the matching ``samples_XXXXXX.txt`` file.
    """
    sample_hook = learner_hooks.TrainSampleHook(
        params={"every_n_steps": 10},
        model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig(),
    )

    step_var = tf.contrib.framework.get_or_create_global_step()
    noop_op = tf.no_op()
    sample_hook.begin()

    with self.test_session() as session:
        # Create and run each initializer in turn (same order as before).
        for make_init in (tf.global_variables_initializer,
                          tf.local_variables_initializer,
                          tf.tables_initializer):
            session.run(make_init())

        hooked = monitored_session._HookedSession(session, [sample_hook])

        def run_at(step):
            # Force the global step, then run a no-op through the hooked
            # session so the hook's before/after-run callbacks execute.
            session.run(tf.assign(step_var, step))
            hooked.run(noop_op)

        def read_sample(filename):
            # Read the sample file as bytes and decode it as UTF-8.
            with open(os.path.join(self.sample_dir, filename), "rb") as fh:
                return fh.read().decode("utf-8")

        # The hook should trigger at step 0.
        run_at(0)
        self.assertIn("Prediction followed by Target @ Step 0",
                      read_sample("samples_000000.txt"))

        # The hook should not trigger at step 9, so no file may exist.
        run_at(9)
        self.assertFalse(os.path.exists(
            os.path.join(self.sample_dir, "samples_000009.txt")))

        # The hook should trigger again at step 10.
        run_at(10)
        self.assertIn("Prediction followed by Target @ Step 10",
                      read_sample("samples_000010.txt"))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号