def testInputTypeError(self, use_bias):
    """Check that non-float32 inputs make the module raise a TypeError.

    A `snt.SeparableConv2D` is built with constant initializers and then
    called on float16 and float64 constants; each call is expected to raise
    a `TypeError` whose message matches the float32 dtype complaint.

    Args:
      use_bias: Forwarded both to the `SeparableConv2D` constructor and to
        the constant-initializer helper.
    """
    constant_initializers = create_separable_constant_initializers(
        1.0, 1.0, 1.0, use_bias)
    separable_conv = snt.SeparableConv2D(
        output_channels=3,
        channel_multiplier=1,
        kernel_shape=3,
        padding=snt.SAME,
        use_bias=use_bias,
        initializers=constant_initializers)

    # The expected message does not depend on the dtype under test, so it is
    # defined once outside the loop.
    expected_message = "Input must have dtype tf.float32.*"
    for unsupported_dtype in (tf.float16, tf.float64):
        inputs = tf.constant(np.ones([1, 5, 5, 1]), dtype=unsupported_dtype)
        with self.assertRaisesRegexp(TypeError, expected_message):
            separable_conv(inputs)
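# Comment list (web-page navigation residue; not part of the test code)
# Table of contents (web-page navigation residue; not part of the test code)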