def test_dequeue(self):
    """TrainPlan trains from a pre-filled input queue and logs per-epoch loss.

    The plan compiles a `blocks.Scalar()` block whose single output tensor is
    used directly as the loss. Twelve loom inputs, each built from the scalar
    7, are enqueued up front. The plan is configured for 2 epochs of 2 batches
    of 3 examples, i.e. exactly the 12 enqueued items. The captured log must
    contain the run header, the train size, a loss of 7 for each epoch, and
    the save message naming the log directory.
    """
    train_plan = plan.TrainPlan()
    train_plan.compiler = block_compiler.Compiler().compile(blocks.Scalar())
    train_plan.is_chief_trainer = True
    train_plan.batch_size = 3
    train_plan.batches_per_epoch = 2
    train_plan.queue_capacity = 12
    train_plan.num_dequeuers = 1
    train_plan.ps_tasks = 1

    # A single queue (index 0) that the plan is set up to dequeue from.
    input_queue = train_plan._create_queue(0)
    train_plan._setup_dequeuing([input_queue])

    # One batch is 3 copies of the loom input for the scalar 7. Enqueuing 4
    # batches gives 12 items: exactly queue_capacity, and exactly
    # epochs (2) x batches_per_epoch (2) x batch_size (3).
    one_batch = list(train_plan.compiler.build_loom_inputs([7])) * 3
    enqueue_op = input_queue.enqueue_many([one_batch * 4])

    # Unpacking fails unless the compiler exposes exactly one output tensor;
    # that tensor becomes the loss.
    (scalar_output,) = train_plan.compiler.output_tensors
    train_plan.losses['foo'] = scalar_output
    train_plan.train_op = tf.no_op()
    train_plan.finalize_stats()

    train_plan.logdir = self.get_temp_dir()
    train_plan.epochs = 2
    train_plan.print_file = six.StringIO()  # capture the plan's log output

    initializer = tf.global_variables_initializer()
    supervisor = train_plan.create_supervisor()
    with self.test_session() as session:
        session.run(initializer)
        session.run(enqueue_op)
        train_plan.run(supervisor, session)

    expected_lines = [
        'running train',
        'train_size: 6',
        'epoch: 1 train[loss: 7.000e+00]',
        'epoch: 2 train[loss: 7.000e+00]',
        'final model saved in file: %s' % train_plan.logdir,
    ]
    printed = train_plan.print_file.getvalue()
    self.assertIn('\n'.join(expected_lines), printed)
# Comment list
# Table of contents