def testComputationValidMultiChannel(self, use_bias):
    """Run through for something with a known answer using snt.VALID padding.

    A 3x3 SeparableConv2D (channel multiplier 1, three output channels) is
    applied with VALID padding to an all-ones 1x5x5x3 input, giving a 3x3
    spatial output. Every output element is expected to be 28 when a bias
    is used and 27 otherwise.

    NOTE(review): the 28/27 values assume create_separable_constant_initializers
    sets the depthwise weights, pointwise weights and (optionally) the bias to
    the given constants (1.0 each). That gives 3 channels * 9 taps = 27, plus
    a bias of 1. Confirm against the helper's definition.

    Args:
      use_bias: Whether the module under test is built with a bias term.
    """
    initializers = create_separable_constant_initializers(
        1.0, 1.0, 1.0, use_bias)
    conv1 = snt.SeparableConv2D(
        output_channels=3,
        channel_multiplier=1,
        kernel_shape=[3, 3],
        padding=snt.VALID,
        use_bias=use_bias,
        initializers=initializers)

    inputs = tf.constant(np.ones([1, 5, 5, 3], dtype=np.float32))
    out = conv1(inputs)

    # The bias contributes exactly +1 to every output element.
    expected_value = 28 if use_bias else 27
    expected_out = np.full([3, 3, 3], expected_value)

    # Only initialise the variables the module actually created; conv1.b is
    # not touched when the module has no bias.
    variables = [conv1.w_dw, conv1.w_pw]
    if use_bias:
        variables.append(conv1.b)

    with self.test_session():
        tf.variables_initializer(variables).run()
        # Drop the batch dimension of size 1 before comparing.
        actual = np.reshape(out.eval(), [3, 3, 3])
        self.assertAllClose(actual, expected_out)
评论列表
文章目录