conv_test.py source code

python

Project: sonnet  Author: deepmind
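import numpy as np
import sonnet as snt
import tensorflow as tf


# Method of a tf.test.TestCase subclass (Sonnet 1.x / TensorFlow 1.x API).
# The arguments are supplied by a parameterized-test decorator that is not
# shown in this excerpt.
def testInputTypeError(self, batch_size, in_length, in_channels, out_channels,
                         kernel_shape, padding, use_bias, out_shape,
                         stride_shape):
    """Errors are thrown for invalid input types."""
    conv1 = snt.Conv1DTranspose(
        output_channels=out_channels,
        output_shape=out_shape,
        kernel_shape=kernel_shape,
        padding=padding,
        stride=stride_shape,
        name="conv1",
        use_bias=use_bias)

    # Only tf.float32 inputs are accepted; other float dtypes must raise.
    for dtype in (tf.float16, tf.float64):
      x = tf.constant(np.ones([batch_size, in_length,
                               in_channels]), dtype=dtype)
      err = "Input must have dtype tf.float32.*"
      # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias.
      with self.assertRaisesRegex(TypeError, err):
        conv1(x)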
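The same testing pattern (loop over rejected dtypes, assert a `TypeError` whose message matches a regex) can be sketched without TensorFlow or Sonnet. This is a minimal stdlib-only illustration, not Sonnet's code: `HypotheticalConv1DTranspose` is a made-up stand-in that takes a dtype name as a string, and its error message is invented only so that it matches the regex used in the test above. `subTest` is added so each failing dtype is reported separately.

```python
import unittest


class HypotheticalConv1DTranspose:
    """Hypothetical stand-in for a module that only accepts float32 inputs."""

    SUPPORTED_DTYPES = ("float32",)

    def __call__(self, x_dtype):
        # Validate the input dtype before doing any work; the message is
        # hypothetical and only needs to match the test's regex.
        if x_dtype not in self.SUPPORTED_DTYPES:
            raise TypeError(
                "Input must have dtype tf.float32, but dtype was " + x_dtype)
        return "ok"


class InputTypeTest(unittest.TestCase):

    def test_input_type_error(self):
        conv1 = HypotheticalConv1DTranspose()
        err = "Input must have dtype tf.float32.*"
        for dtype in ("float16", "float64"):
            # subTest reports each rejected dtype as its own result instead
            # of stopping at the first failure.
            with self.subTest(dtype=dtype):
                with self.assertRaisesRegex(TypeError, err):
                    conv1(dtype)

    def test_float32_accepted(self):
        # The supported dtype must not raise.
        self.assertEqual(HypotheticalConv1DTranspose()("float32"), "ok")
```

Saved to a file, this can be run with `python -m unittest <file>`. Note that `assertRaisesRegex` uses `re.search`, so the trailing `.*` in the pattern is not strictly required.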