def testTupleSelect(self):
    """Test where idx is a tuple.

    Selecting indices (0, 2) from three inputs should yield exactly the first
    and third inputs, in that order.
    """
    shape0 = [1, 2]
    shape1 = [1, 2, 3]
    shape2 = [1, 2, 3, 4]
    input0 = tf.random_uniform(shape=shape0)
    input1 = tf.random_uniform(shape=shape1)
    input2 = tf.random_uniform(shape=shape2)
    mod = snt.SelectInput(idx=(0, 2))
    output = mod(input0, input1, input2)
    # Reference tensors for the expected selections. They are fetched in the
    # same sess.run call as `output`, so each random op is evaluated once and
    # both sides of the comparison see the same sampled values.
    output0 = tf.identity(input0)
    output2 = tf.identity(input2)
    with self.test_session() as sess:
        selected, expected = sess.run([output, [output0, output2]])
        # The element-wise checks below only look at the first two entries,
        # so also pin the count to catch extra inputs leaking through.
        self.assertEqual(len(selected), 2)
        self.assertAllEqual(selected[0], expected[0])
        self.assertAllEqual(selected[1], expected[1])
# Scraped web-page navigation residue (not code):
# 评论列表 ("Comment list") / 文章目录 ("Table of contents")