def testInputTypeError(self, batch_size, in_length, in_channels, out_channels,
                       kernel_shape, padding, use_bias, out_shape,
                       stride_shape):
    """Verify that non-float32 inputs are rejected with a TypeError.

    A single `snt.Conv1DTranspose` module is built from the supplied
    configuration and then called on an all-ones tensor of shape
    `[batch_size, in_length, in_channels]`, once as `tf.float16` and once
    as `tf.float64`. Each call must raise a `TypeError` whose message
    matches "Input must have dtype tf.float32.*".

    Args:
      batch_size: First dimension of the input tensor.
      in_length: Second dimension of the input tensor.
      in_channels: Third dimension of the input tensor.
      out_channels: Forwarded to the module as `output_channels`.
      kernel_shape: Forwarded to the module as `kernel_shape`.
      padding: Forwarded to the module as `padding`.
      use_bias: Forwarded to the module as `use_bias`.
      out_shape: Forwarded to the module as `output_shape`.
      stride_shape: Forwarded to the module as `stride`.
    """
    # NOTE(review): the arguments are presumably injected by a parameterized
    # test decorator that is outside this view -- confirm against the caller.
    module_config = dict(
        output_channels=out_channels,
        output_shape=out_shape,
        kernel_shape=kernel_shape,
        padding=padding,
        stride=stride_shape,
        name="conv1",
        use_bias=use_bias,
    )
    transpose_conv = snt.Conv1DTranspose(**module_config)

    # Neither the expected message pattern nor the input shape depends on
    # the dtype under test, so both are computed once up front.
    expected_message = "Input must have dtype tf.float32.*"
    input_shape = [batch_size, in_length, in_channels]

    for unsupported_dtype in (tf.float16, tf.float64):
        bad_input = tf.constant(np.ones(input_shape), dtype=unsupported_dtype)
        # The same module instance is reused for every dtype.
        with self.assertRaisesRegexp(TypeError, expected_message):
            transpose_conv(bad_input)
评论列表
文章目录