conv_test.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testSharing(self, batch_size, in_length, in_channels, out_channels,
                kernel_shape, padding, use_bias, out_shape, stride_shape):
  """Checks that repeated calls of one Conv1DTranspose share variables.

  A single module instance is connected to two constant inputs built from
  the same random data. The two outputs must agree with the freshly
  initialized weights. They must also still agree after `w` has been
  overwritten with new random values.
  """
  transpose_conv = snt.Conv1DTranspose(
      name="conv1",
      output_channels=out_channels,
      output_shape=out_shape,
      kernel_shape=kernel_shape,
      stride=stride_shape,
      padding=padding,
      use_bias=use_bias)

  input_values = np.random.randn(batch_size, in_length, in_channels)
  first_input = tf.constant(input_values, dtype=np.float32)
  second_input = tf.constant(input_values, dtype=np.float32)

  first_output = transpose_conv(first_input)
  second_output = transpose_conv(second_input)

  def assert_outputs_match():
    # Both tensors are evaluated anew on every call, so each check reflects
    # the variables' current values.
    self.assertAllClose(first_output.eval(), second_output.eval())

  with self.test_session():
    # Only initialize the bias when the module was built with one.
    module_variables = [transpose_conv.w]
    if use_bias:
      module_variables.append(transpose_conv.b)
    tf.variables_initializer(module_variables).run()

    assert_outputs_match()

    # Overwrite the kernel in place. Both outputs come from the same module,
    # so they must still agree afterwards.
    # NOTE(review): the (1, kernel, out, in) layout is taken as given here;
    # it presumably mirrors the module's internal weight shape - confirm.
    new_kernel = np.random.randn(1, kernel_shape, out_channels, in_channels)
    transpose_conv.w.assign(new_kernel).eval()

    assert_outputs_match()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号