def testMask4D(self):
    """Check that a rank-4 mask zeroes out exactly the masked kernel weights.

    The kernel has 3 * 3 * 2 * 1 = 18 weights and the mask disables one of
    them. Convolving an all-ones input therefore should produce 17 at every
    output position (assumes `create_constant_initializers(1.0, 0.0, ...)`
    sets every kernel weight to 1.0).
    """
    kernel_mask = np.ones([3, 3, 2, 1], dtype=np.float32)
    # Zero the weight at kernel position (0, 0) for input channel 0; the
    # trailing ':' spans the (single) last axis of the mask.
    kernel_mask[0, 0, 0, :] = 0

    ones_image = tf.constant(1.0, shape=(1, 5, 5, 2))
    initializers = create_constant_initializers(1.0, 0.0, use_bias=False)
    masked_conv = snt.Conv2D(
        output_channels=1,
        kernel_shape=3,
        mask=kernel_mask,
        padding=snt.VALID,
        use_bias=False,
        initializers=initializers)
    output = masked_conv(ones_image)

    # A 5x5 input with a 3x3 kernel and VALID padding yields a 3x3 output.
    expected = np.full((3, 3), 17)

    with self.test_session():
        tf.variables_initializer([masked_conv.w]).run()
        actual = np.reshape(output.eval(), [3, 3])
        self.assertAllClose(actual, expected)
评论列表
文章目录