conv_test.py source code

python

Project: sonnet  Author: deepmind
def testMask4D(self):
    """4D masks are applied correctly."""

    # This mask, applied on an image filled with 1, should result in an image
    # filled with 17, as there are 18 weights but we zero out one of them.
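    # This 4D mask matches the full kernel shape:
    # [kernel_height, kernel_width, in_channels, out_channels].
    mask = np.ones([3, 3, 2, 1], dtype=np.float32)
    mask[0, 0, 0, :] = 0
    # NHWC input: batch of 1, 5x5 spatial extent, 2 channels, every value 1.0.
    inputs = tf.constant(1.0, shape=(1, 5, 5, 2))
    conv1 = snt.Conv2D(
        output_channels=1,
        kernel_shape=3,
        mask=mask,
        padding=snt.VALID,
        use_bias=False,
        initializers=create_constant_initializers(1.0, 0.0, use_bias=False))
    out = conv1(inputs)
    # With VALID padding, a 3x3 kernel over a 5x5 input gives a
    # (5 - 3 + 1) x (5 - 3 + 1) = 3x3 output, every entry equal to 17.
    expected_out = np.array([[17] * 3] * 3)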
    with self.test_session():
      tf.variables_initializer([conv1.w]).run()
      self.assertAllClose(np.reshape(out.eval(), [3, 3]), expected_out)
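As a sanity check on the arithmetic in the test's comment (a 3x3x2 kernel has 18 weights, masking one leaves 17, and a 5x5 input with VALID padding yields a 3x3 output), here is a minimal NumPy reference sketch. It assumes, as the test implies, that the mask is multiplied element-wise with the weights before convolving, and that the convolution is a stride-1 cross-correlation (no kernel flip) like `tf.nn.conv2d`. The function name `masked_conv2d_valid` is hypothetical and is not part of Sonnet.

```python
import numpy as np


def masked_conv2d_valid(inputs, weights, mask):
    """Hypothetical reference: masked 2D conv, VALID padding, stride 1, no bias.

    inputs:  [batch, height, width, in_channels]  (NHWC, as in the test)
    weights: [kh, kw, in_channels, out_channels]
    mask:    same shape as weights; multiplied element-wise with them.
    """
    w = weights * mask  # zeroed mask entries switch off the matching weights
    kh, kw, _, out_c = w.shape
    batch, h, wid, _ = inputs.shape
    out_h, out_w = h - kh + 1, wid - kw + 1  # VALID padding, stride 1
    out = np.zeros((batch, out_h, out_w, out_c), dtype=inputs.dtype)
    for i in range(out_h):
        for j in range(out_w):
            patch = inputs[:, i:i + kh, j:j + kw, :]  # [batch, kh, kw, in_c]
            # Sum over kh, kw and in_channels, leaving [batch, out_c].
            out[:, i, j, :] = np.tensordot(patch, w, axes=([1, 2, 3], [0, 1, 2]))
    return out


# Reproduce the test setup: all-ones input and weights, one weight masked out.
mask = np.ones([3, 3, 2, 1], dtype=np.float32)
mask[0, 0, 0, :] = 0
inputs = np.ones((1, 5, 5, 2), dtype=np.float32)
weights = np.ones([3, 3, 2, 1], dtype=np.float32)

out = masked_conv2d_valid(inputs, weights, mask)
print(out.shape)         # (1, 3, 3, 1)
print(out[0, :, :, 0])   # every entry is 17.0
```

Because the input is uniform, each output value is simply the sum of the unmasked weights (18 - 1 = 17), regardless of which single weight the mask zeroes out.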