plan_test.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:fold 作者: tensorflow 项目源码 文件源码
def test_create_from_flags(self):
  """Checks that Plan.create_from_flags builds a plan from command-line flags.

  First, a training plan is created from flags set here (4 multiprocess
  processes, master 'foo', 123 batches per epoch, truncation to 3 examples).
  The test checks the copied attributes, the derived plan/run/log
  directories, example truncation, and that one step of the generated
  train_op lowers the loss below its initial value of 4.

  Second, with task=42 and num_multiprocess_processes=0, a plan given an
  explicit train_op is created. The test checks that summaries and
  chief-trainer status are off and that the supplied train_op is kept.

  tf.flags.FLAGS is process-global, so every flag assigned here is restored
  to its entry value with addCleanup. Otherwise (e.g. task=42) it would
  leak into, and make order-dependent, any test that runs afterwards.
  """
  flags = tf.flags.FLAGS
  # Snapshot and schedule restoration of every flag this test overwrites.
  for flag_name in ('mode', 'truncate_examples', 'num_multiprocess_processes',
                    'master', 'batches_per_epoch', 'task'):
    self.addCleanup(setattr, flags, flag_name, getattr(flags, flag_name))

  flags.mode = plan.Plan.mode_keys.TRAIN
  flags.truncate_examples = 3
  flags.num_multiprocess_processes = 4
  flags.master = 'foo'
  flags.batches_per_epoch = 123
  foo = tf.get_variable('foo', [], tf.float32, tf.constant_initializer(4))
  p = plan.Plan.create_from_flags(_setup_plan(
      compiler=block_compiler.Compiler.create(blocks.Scalar()),
      losses={'foo': foo},
      examples=xrange(5)))
  self.assertEqual(p.num_multiprocess_processes, 4)
  self.assertEqual(p.master, 'foo')
  self.assertEqual(p.batches_per_epoch, 123)
  self.assertEqual(p.compute_summaries, True)
  self.assertEqual(p.is_chief_trainer, True)
  # NOTE(review): the '/tmp/' + 'plan' + 'run_0' layout presumably comes from
  # defaults configured by _setup_plan / other flags not visible here.
  self.assertEqual(p.logdir, os.path.join('/tmp/', 'plan', 'run_0', 'train'))
  self.assertEqual(p.rundir, os.path.join('/tmp/', 'plan', 'run_0'))
  self.assertEqual(p.plandir, os.path.join('/tmp/', 'plan'))
  # truncate_examples=3 keeps only the first three of the five examples.
  self.assertEqual([0, 1, 2], list(p.examples))
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    self.assertEqual(4, p.loss_total.eval())
    sess.run(p.train_op)  # One training step should make the loss smaller.
    self.assertLess(p.loss_total.eval(), 4)

  # Second plan: non-chief task (42), no multiprocessing, and an explicit
  # train_op that create_from_flags must keep instead of generating its own.
  flags.num_multiprocess_processes = 0
  flags.task = 42
  train_op = tf.no_op()
  p = plan.Plan.create_from_flags(_setup_plan(
      compiler=block_compiler.Compiler.create(blocks.Scalar()),
      losses={'foo': tf.constant(3.14)},
      train_op=train_op,
      examples=xrange(5)))
  self.assertEqual(p.num_multiprocess_processes, 0)
  self.assertEqual(p.compute_summaries, False)
  self.assertEqual(p.is_chief_trainer, False)
  self.assertEqual(p.train_op, train_op)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号