def test_create_from_params(self):
  params = plan.plan_default_params()
  params.update({
      'mode': plan.Plan.mode_keys.TRAIN,
      'truncate_examples': 3,
      'num_multiprocess_processes': 4,
      'master': 'foo',
      'batches_per_epoch': 123})
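  # A single variable 'foo', initialized to 4, serves as the plan's only loss.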
  foo = tf.get_variable('foo', [], tf.float32, tf.constant_initializer(4))
  p = plan.Plan.create_from_params(_setup_plan(
      compiler=block_compiler.Compiler.create(blocks.Scalar()),
      losses={'foo': foo},
      examples=xrange(5)), params)
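  # The values overridden in params should be reflected on the plan, along
  # with the logdir/rundir/plandir layout built from the defaults.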
  self.assertEqual(p.num_multiprocess_processes, 4)
  self.assertEqual(p.master, 'foo')
  self.assertEqual(p.batches_per_epoch, 123)
  self.assertEqual(p.compute_summaries, True)
  self.assertEqual(p.is_chief_trainer, True)
  self.assertEqual(p.logdir, os.path.join('/tmp/', 'plan', 'run_0', 'train'))
  self.assertEqual(p.rundir, os.path.join('/tmp/', 'plan', 'run_0'))
  self.assertEqual(p.plandir, os.path.join('/tmp/', 'plan'))
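  # truncate_examples=3 limits the example stream to its first three items.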
  self.assertEqual([0, 1, 2], list(p.examples))
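  # Exercise the training graph end to end: initialize variables, read the
  # loss, and take one training step.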
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    self.assertEqual(4, p.loss_total.eval())
    sess.run(p.train_op)  # should make loss smaller
    self.assertLess(p.loss_total.eval(), 4)

def test_create_from_flags_no_chief(self):
  # With a nonzero task id this worker is not the chief trainer and should
  # not compute summaries.
  tf.flags.FLAGS.mode = plan.Plan.mode_keys.TRAIN
  tf.flags.FLAGS.num_multiprocess_processes = 0
  tf.flags.FLAGS.task = 42
  train_op = tf.no_op()
  p = plan.Plan.create_from_flags(_setup_plan(
      compiler=block_compiler.Compiler.create(blocks.Scalar()),
      losses={'foo': tf.constant(3.14)},
      train_op=train_op,
      examples=xrange(5)))
  self.assertEqual(p.num_multiprocess_processes, 0)
  self.assertEqual(p.compute_summaries, False)
  self.assertEqual(p.is_chief_trainer, False)
  self.assertEqual(p.train_op, train_op)