def test_capture(self):
    """MetadataCaptureHook writes its artifacts only after the configured step.

    The hook is configured with ``params={"step": 5}``. The test sets the
    global step by hand and inspects ``self.model_dir`` after each run:

    * global step 0: a run leaves the directory empty;
    * global step 5: the first run still leaves it empty, and the following
      run produces exactly ``run_meta``, ``tfprof_log`` and ``timeline.json``.
    """
    global_step = tf.contrib.framework.get_or_create_global_step()

    # A trivial computation for the hooked session to execute.
    some_weights = tf.get_variable("weights", [2, 128])
    computation = tf.nn.softmax(some_weights)

    hook = hooks.MetadataCaptureHook(
        params={"step": 5},
        model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())
    hook.begin()

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        # _HookedSession is a protected member of monitored_session (hence
        # the pylint disable); it wraps `sess` together with the hook so that
        # runs issued through `mon_sess` are seen by the hook.
        # pylint: disable=W0212
        mon_sess = monitored_session._HookedSession(sess, [hook])

        # Should not trigger for step 0.
        sess.run(tf.assign(global_step, 0))
        mon_sess.run(computation)
        self.assertEqual(gfile.ListDirectory(self.model_dir), [])

        # Should trigger *after* step 5: the run at step 5 writes nothing ...
        sess.run(tf.assign(global_step, 5))
        mon_sess.run(computation)
        self.assertEqual(gfile.ListDirectory(self.model_dir), [])

        # ... and the next run writes the metadata files. Compare as sets
        # because the directory listing order is not guaranteed.
        mon_sess.run(computation)
        self.assertEqual(
            set(gfile.ListDirectory(self.model_dir)),
            {"run_meta", "tfprof_log", "timeline.json"})
评论列表
文章目录