def test_create_from_flags(self):
  """Checks that Plan.create_from_flags reads its settings from tf.flags.FLAGS.

  Phase 1 builds a plan whose loss is a trainable scalar variable, with
  num_multiprocess_processes=4, master='foo', batches_per_epoch=123 and
  truncate_examples=3. It checks that:
    * those values are copied onto the plan;
    * compute_summaries and is_chief_trainer are True;
    * logdir/rundir/plandir point under /tmp/plan;
    * only the first three of five examples are kept;
    * one run of the generated train_op lowers the loss below 4.

  Phase 2 sets num_multiprocess_processes=0 and task=42 and passes an
  explicit train_op. It checks that the explicit op is kept and that
  compute_summaries and is_chief_trainer are both False.

  FLAGS is process-global. The previous values of every flag written here
  are saved and restored after the test (pass or fail), so later tests in
  the same process do not inherit them.
  """
  flags = tf.flags.FLAGS

  # Snapshot every flag this test overrides and register a cleanup that
  # writes the old values back. addCleanup runs even if an assertion below
  # fails, which a trailing reset would not.
  overridden = ('mode', 'truncate_examples', 'num_multiprocess_processes',
                'master', 'batches_per_epoch', 'task')
  saved = {name: getattr(flags, name) for name in overridden}

  def _restore_flags():
    for name, value in saved.items():
      setattr(flags, name, value)

  self.addCleanup(_restore_flags)

  # Phase 1: multiprocess training plan whose loss is a trainable variable
  # initialised to 4, so a single optimisation step should reduce it.
  flags.mode = plan.Plan.mode_keys.TRAIN
  flags.truncate_examples = 3
  flags.num_multiprocess_processes = 4
  flags.master = 'foo'
  flags.batches_per_epoch = 123
  foo = tf.get_variable('foo', [], tf.float32, tf.constant_initializer(4))
  p = plan.Plan.create_from_flags(_setup_plan(
      compiler=block_compiler.Compiler.create(blocks.Scalar()),
      losses={'foo': foo},
      examples=xrange(5)))
  self.assertEqual(p.num_multiprocess_processes, 4)
  self.assertEqual(p.master, 'foo')
  self.assertEqual(p.batches_per_epoch, 123)
  self.assertEqual(p.compute_summaries, True)
  self.assertEqual(p.is_chief_trainer, True)
  self.assertEqual(p.logdir, os.path.join('/tmp/', 'plan', 'run_0', 'train'))
  self.assertEqual(p.rundir, os.path.join('/tmp/', 'plan', 'run_0'))
  self.assertEqual(p.plandir, os.path.join('/tmp/', 'plan'))
  # truncate_examples=3 keeps only the first three of the five examples.
  self.assertEqual([0, 1, 2], list(p.examples))
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    self.assertEqual(4, p.loss_total.eval())
    sess.run(p.train_op)  # should make loss smaller
    self.assertLess(p.loss_total.eval(), 4)

  # Phase 2: single-process plan with task=42 and a caller-supplied
  # train_op. NOTE(review): is_chief_trainer presumably becomes False
  # because task is non-zero; confirm against Plan.create_from_flags.
  flags.num_multiprocess_processes = 0
  flags.task = 42
  train_op = tf.no_op()
  p = plan.Plan.create_from_flags(_setup_plan(
      compiler=block_compiler.Compiler.create(blocks.Scalar()),
      losses={'foo': tf.constant(3.14)},
      train_op=train_op,
      examples=xrange(5)))
  self.assertEqual(p.num_multiprocess_processes, 0)
  self.assertEqual(p.compute_summaries, False)
  self.assertEqual(p.is_chief_trainer, False)
  # An explicitly passed train_op must be used as-is, not replaced.
  self.assertEqual(p.train_op, train_op)
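# Comment list (stray web-page navigation text, not part of the test)
# Table of contents (stray web-page navigation text, not part of the test)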