def testInputTypeError(self, use_bias):
    """Check that a Conv2D module raises TypeError for non-float32 inputs.

    A single-output-channel 3x3 SAME convolution is built once. It is then fed
    an all-ones input in each of tf.float16 and tf.float64. Every call must
    raise a TypeError whose message matches the expected dtype complaint.

    Args:
      use_bias: Forwarded to both `snt.Conv2D` and
        `create_constant_initializers`, so the test covers the module with and
        without a bias term.
    """
    constant_inits = create_constant_initializers(1.0, 1.0, use_bias)
    conv1 = snt.Conv2D(
        output_channels=1,
        kernel_shape=3,
        stride=1,
        padding=snt.SAME,
        name="conv1",
        use_bias=use_bias,
        initializers=constant_inits,
    )

    expected_error = "Input must have dtype tf.float32.*"
    rejected_dtypes = (tf.float16, tf.float64)
    for bad_dtype in rejected_dtypes:
        # The same 1x5x5x1 all-ones batch is used each time; only its dtype varies.
        bad_input = tf.constant(np.ones([1, 5, 5, 1]), dtype=bad_dtype)
        # NOTE(review): assertRaisesRegexp is a deprecated alias of
        # assertRaisesRegex (removed from unittest in Python 3.12). Switch to
        # the new name once Python 2 compatibility is no longer required.
        with self.assertRaisesRegexp(TypeError, expected_error):
            conv1(bad_input)
# Stray web-page navigation text ("Comment list", "Table of contents");
# not part of the test code.