def test_run_key_fn(self):
  """InferPlan.run should pair every output with the key produced by key_fn.

  Five scalar examples (0..4) are fed through a block that negates them.
  With ``key_fn=str``, each result is expected as ``(str(x), (-x,))``.
  The test also expects ``results_fn`` to be invoked exactly once. That
  single call should receive an iterable yielding all five keyed results
  in example order.
  """
  infer_plan = plan.InferPlan()
  negate = blocks.Scalar() >> blocks.Function(tf.negative)
  infer_plan.compiler = block_compiler.Compiler.create(negate)
  infer_plan.logdir = self.get_temp_dir()
  infer_plan.examples = xrange(5)
  # Request every output tensor the compiled block exposes (a single one
  # here, hence the 1-tuples in the expected results below).
  infer_plan.outputs = infer_plan.compiler.output_tensors

  received = []
  infer_plan.results_fn = received.append
  infer_plan.key_fn = str
  # Neither size divides the five examples evenly; presumably this exercises
  # partial batches/chunks -- TODO confirm intent.
  infer_plan.batch_size = 3
  infer_plan.chunk_size = 2

  expected = [(str(x), (-x,)) for x in range(5)]
  with self.test_session() as session:
    infer_plan.run(session=session)
    # NOTE(review): assertions stay inside the session in case the iterable
    # handed to results_fn is consumed lazily -- confirm against plan.py.
    self.assertEqual(1, len(received))
    self.assertEqual(expected, list(received[0]))
评论列表
文章目录