def testInputExampleIndex(self):
    """Checks that `input_example_index` picks the input used to shape the output.

    With `n_dims=2`, the wrapped module sees each input with its two leading
    dimensions merged (shapes [3 * 5] and [3 * 9]). The module returns its
    second input unchanged. So only the second input has the same leading
    dimensions as the output and can serve as the input example.
    """
    first_input = tf.random_normal((3, 5))
    second_input = tf.random_normal((3, 9))

    def build(inputs):
        # Each input arrives with its two leading dims merged into one.
        lhs, rhs = inputs
        lhs.get_shape().assert_is_compatible_with([3 * 5])
        rhs.get_shape().assert_is_compatible_with([3 * 9])
        # Forward the second input untouched.
        return rhs

    pass_second = snt.Module(build)
    both_inputs = (first_input, second_input)

    # Using the first input as the example must fail: its leading dims differ
    # from those of the output.
    with self.assertRaises(ValueError):
        snt.BatchApply(pass_second, n_dims=2, input_example_index=0)(both_inputs)

    # Using the second input as the example shares the output's leading dims,
    # so the output should equal that input. Fetch both in one run so they
    # come from the same random sample.
    output = snt.BatchApply(pass_second, n_dims=2, input_example_index=1)(both_inputs)
    with self.test_session() as sess:
        expected, actual = sess.run([second_input, output])
    self.assertAllEqual(expected, actual)
评论列表
文章目录