def testComparison(self):
    """Checks snt.TileByDim against an equivalent plain `tf.tile` call.

    The module is configured to repeat axis 0 twice and axis 2 four times.
    The reference `tf.tile` call expresses the same tiling with one multiple
    per axis, using 1 for the untouched axis 1.
    """
    shape = [2, 3, 4]
    source = tf.random_uniform(shape=shape)

    # Module under test: only the listed axes receive a multiple.
    tile_module = snt.TileByDim(dims=[0, 2], multiples=[2, 4])
    tiled = tile_module(source)

    # Reference: explicit per-axis multiples, axis 1 left at 1.
    full_multiples = [2, 1, 4]
    reference = tf.tile(source, multiples=full_multiples)

    with self.test_session() as session:
        # Fetch both in a single run so they are computed from the same
        # random draw of `source`.
        fetched = session.run([tiled, reference])
        self.assertAllEqual(fetched[0], fetched[1])