def testComparison(self):
    """Check snt.SliceByDim against an equivalent hand-built tf.slice.

    SliceByDim slices only dims 0 and 2 of a random [2, 3, 4] tensor. The
    reference tf.slice gets full-rank begin/size lists. These use the same
    start/length on the sliced dims, and start 0 with size -1 (the whole
    dimension) on every other dim.
    """
    shape = [2, 3, 4]
    x = tf.random_uniform(shape=shape)

    sliced_dims = [0, 2]
    starts = [1, 2]
    lengths = [1, 2]
    slicer = snt.SliceByDim(dims=sliced_dims, begin=starts, size=lengths)
    result = slicer(x)

    # Expand the per-dim spec into tf.slice's full-rank form: dims not being
    # sliced begin at 0 and use size -1 ("everything remaining").
    ref_begin = [0] * len(shape)
    ref_size = [-1] * len(shape)
    for dim, start, length in zip(sliced_dims, starts, lengths):
        ref_begin[dim] = start
        ref_size[dim] = length
    # Here ref_begin == [1, 0, 2] and ref_size == [1, -1, 2].
    reference = tf.slice(x, begin=ref_begin, size=ref_size)

    with self.test_session() as sess:
        got, want = sess.run([result, reference])
        self.assertAllEqual(got, want)
# Scraped web-page residue, not code: "评论列表" (Comment list), "文章目录" (Table of contents)