conv_test.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testInputTypeError(self):
    """Errors are thrown for invalid input types."""
    conv1 = snt.Conv3D(output_channels=1,
                       kernel_shape=3,
                       stride=1,
                       padding=snt.SAME,
                       name="conv1",
                       initializers={
                           "w": tf.constant_initializer(1.0),
                           "b": tf.constant_initializer(1.0),
                       })

    for dtype in (tf.float16, tf.float64):
      x = tf.constant(np.ones([1, 5, 5, 5, 1]), dtype=dtype)
      self.assertRaisesRegexp(TypeError, "Input must have dtype tf.float32.*",
                              conv1, x)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号