plan_test.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:fold 作者: tensorflow 项目源码 文件源码
def test_dequeue(self):
    """Run a chief TrainPlan fed from a pre-filled queue and check its log.

    The plan is configured with a compiled ``blocks.Scalar()`` block, a batch
    size of 3 and 2 batches per epoch. The loom inputs built from the value
    7 are repeated (3 * 4 times) and pushed into the plan's queue with a
    single ``enqueue_many`` op before training starts. The plan's only loss,
    'foo', is the compiler's single output tensor and the train op is a
    no-op, so the test only exercises dequeuing and stats reporting.

    After running for 2 epochs, the text written to ``print_file`` must
    contain the expected training report as one contiguous block.
    """
    train_plan = plan.TrainPlan()
    compiler = block_compiler.Compiler()
    train_plan.compiler = compiler.compile(blocks.Scalar())
    train_plan.is_chief_trainer = True
    train_plan.batch_size = 3
    train_plan.batches_per_epoch = 2
    train_plan.queue_capacity = 12
    train_plan.num_dequeuers = 1
    train_plan.ps_tasks = 1

    # Create the queue and wire the plan up to dequeue from it.
    queue = train_plan._create_queue(0)
    train_plan._setup_dequeuing([queue])

    # Build the op that fills the queue with repeated copies of the input.
    loom_inputs = list(train_plan.compiler.build_loom_inputs([7]))
    batch = loom_inputs * 3
    enqueue_op = queue.enqueue_many([batch * 4])

    # The unpacking requires the compiler to expose exactly one output tensor.
    (train_plan.losses['foo'],) = train_plan.compiler.output_tensors
    train_plan.train_op = tf.no_op()
    train_plan.finalize_stats()
    train_plan.logdir = self.get_temp_dir()
    train_plan.epochs = 2
    # Capture everything the plan prints so it can be inspected below.
    train_plan.print_file = six.StringIO()

    initializer = tf.global_variables_initializer()
    supervisor = train_plan.create_supervisor()
    with self.test_session() as session:
        session.run(initializer)
        session.run(enqueue_op)
        train_plan.run(supervisor, session)

    expected_lines = [
        'running train',
        'train_size: 6',
        'epoch:    1 train[loss: 7.000e+00]',
        'epoch:    2 train[loss: 7.000e+00]',
        'final model saved in file: %s' % train_plan.logdir,
    ]
    expected_log = '\n'.join(expected_lines)
    self.assertIn(expected_log, train_plan.print_file.getvalue())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号