def testShapesNotKnown(self, use_bias):
    """Test that output shapes are correct when the input shape is not known.

    The input placeholder fixes only the channel dimension; the other three
    dimensions are left as ``None``. A SeparableConv2D with SAME padding and a
    channel multiplier of 1 is connected to it. The graph is then evaluated on
    a zero-filled array of shape ``self.input_shape``, and the resulting
    shape is compared with ``tuple(self.output_shape)``.

    Args:
      use_bias: Passed through to the module; when True the bias variable
        ``b`` is initialized alongside the depthwise and pointwise weights.
    """
    unknown_dims_input = tf.placeholder(
        tf.float32, shape=[None, None, None, self.in_channels], name="inputs")

    sep_conv = snt.SeparableConv2D(
        output_channels=self.out_channels_dw,
        channel_multiplier=1,
        kernel_shape=self.kernel_shape,
        padding=snt.SAME,
        use_bias=use_bias)
    result = sep_conv(unknown_dims_input)

    with self.test_session():
        # Initialize only this module's variables; the bias variable is
        # included only when the module was built with use_bias=True.
        module_vars = [sep_conv.w_dw, sep_conv.w_pw]
        if use_bias:
            module_vars.append(sep_conv.b)
        tf.variables_initializer(module_vars).run()

        feed = {unknown_dims_input: np.zeros(self.input_shape)}
        evaluated = result.eval(feed)
        self.assertEqual(evaluated.shape, tuple(self.output_shape))
评论列表
文章目录