def test_compiler_input_tensor(self):
    """Feed a compiled block from a string variable via `input_tensor`.

    The block maps each input string to its length as a scalar. The test
    checks the outputs against the variable's initial contents. It then
    reassigns the variable and checks that the next run reflects the new
    strings.
    """
    source_strings = tf.Variable(
        ['foobar', 'baz'], dtype=tf.string, name='input_variable')
    # Created before compile()/init_loom(); this position is kept deliberately.
    initialize = tf.global_variables_initializer()

    length_block = tdb.InputTransform(len) >> tdb.Scalar()
    block_compiler = tdc.Compiler()
    block_compiler.compile(length_block)
    block_compiler.init_loom(max_depth=1, input_tensor=source_strings)
    (lengths_tensor,) = block_compiler.output_tensors

    with self.test_session() as sess:

        def expect_lengths(expected):
            # Evaluate the compiled output once; compare size, then each value.
            observed = sess.run(lengths_tensor)
            self.assertEqual(len(observed), len(expected))
            for got, want in zip(observed, expected):
                self.assertEqual(got, want)

        sess.run(initialize)
        expect_lengths([6., 3.])
        sess.run(source_strings.assign(['foo', 'blah']))
        expect_lengths([3., 4.])
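# NOTE(review): two web-page navigation labels ("Comment list",
# "Table of contents") captured by a scrape stood here; they are not code.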