plan_test.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:fold 作者: tensorflow 项目源码 文件源码
def test_create_from_params(self):
    """Check plan fields after `create_from_params`, then `create_from_flags`.

    Part one overrides a handful of default params, builds a training plan
    around a single scalar variable used as the loss, and verifies that the
    overrides, directory layout, example truncation and train op are all
    reflected on the resulting plan.

    Part two sets two flags directly and builds a plan from flags with an
    explicit `train_op`, verifying the plan reports those values back.
    """
    # --- Part 1: plan built from an explicit params dict. ---
    params = plan.plan_default_params()
    overrides = {
        'mode': plan.Plan.mode_keys.TRAIN,
        'truncate_examples': 3,
        'num_multiprocess_processes': 4,
        'master': 'foo',
        'batches_per_epoch': 123,
    }
    params.update(overrides)

    # Scalar variable initialised to 4; it doubles as the only loss.
    loss_var = tf.get_variable(
        'foo', [], tf.float32, tf.constant_initializer(4))
    params_compiler = block_compiler.Compiler.create(blocks.Scalar())
    setup_from_params = _setup_plan(
        compiler=params_compiler,
        losses={'foo': loss_var},
        examples=xrange(5))
    p = plan.Plan.create_from_params(setup_from_params, params)

    self.assertEqual(p.num_multiprocess_processes, 4)
    self.assertEqual(p.master, 'foo')
    self.assertEqual(p.batches_per_epoch, 123)
    self.assertEqual(p.compute_summaries, True)
    self.assertEqual(p.is_chief_trainer, True)

    # Expected directory layout; presumably rooted at the default log base
    # ('/tmp/plan') since no override is given above -- confirm in plan.py.
    expected_plandir = os.path.join('/tmp/', 'plan')
    expected_rundir = os.path.join('/tmp/', 'plan', 'run_0')
    expected_logdir = os.path.join('/tmp/', 'plan', 'run_0', 'train')
    self.assertEqual(p.logdir, expected_logdir)
    self.assertEqual(p.rundir, expected_rundir)
    self.assertEqual(p.plandir, expected_plandir)

    # 'truncate_examples': 3 should keep only the first three of five.
    self.assertEqual([0, 1, 2], list(p.examples))

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      self.assertEqual(4, p.loss_total.eval())
      sess.run(p.train_op)  # should make loss smaller
      self.assertLess(p.loss_total.eval(), 4)

    # --- Part 2: plan built from flags with a caller-supplied train op. ---
    tf.flags.FLAGS.num_multiprocess_processes = 0
    tf.flags.FLAGS.task = 42
    train_op = tf.no_op()
    flags_compiler = block_compiler.Compiler.create(blocks.Scalar())
    setup_from_flags = _setup_plan(
        compiler=flags_compiler,
        losses={'foo': tf.constant(3.14)},
        train_op=train_op,
        examples=xrange(5))
    p = plan.Plan.create_from_flags(setup_from_flags)

    self.assertEqual(p.num_multiprocess_processes, 0)
    self.assertEqual(p.compute_summaries, False)
    self.assertEqual(p.is_chief_trainer, False)
    # The supplied op must be kept as-is rather than replaced.
    self.assertEqual(p.train_op, train_op)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号