conv_test.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testShapesNotKnown(self, use_bias):
    """The generated shapes are correct when input shape not known.

    Builds an `snt.Conv1D` on a placeholder whose batch and length dimensions
    are unknown at graph-construction time. It then feeds a concrete
    zero-valued batch and checks the shape of the evaluated output.

    Args:
      use_bias: Passed through to `snt.Conv1D`; when true, the module's bias
        variable `conv1.b` is initialized alongside the weights `conv1.w`.
    """
    # Pairwise-distinct sizes, so that an output with mixed-up axes (e.g. batch
    # swapped with channels, or input channels leaking through instead of
    # output channels) cannot accidentally match the expected shape.
    batch_size = 3
    in_length = 32
    in_channels = 4
    out_channels = 5
    kernel_shape = 3

    # Only the channel dimension is static; batch and length are left unknown.
    inputs = tf.placeholder(
        tf.float32,
        shape=[None, None, in_channels],
        name="inputs")

    conv1 = snt.Conv1D(
        name="conv1",
        output_channels=out_channels,
        kernel_shape=kernel_shape,
        padding=snt.SAME,
        stride=1,
        use_bias=use_bias)

    output = conv1(inputs)

    with self.test_session():
      # Only touch conv1.b when the module was built with a bias.
      tf.variables_initializer(
          [conv1.w, conv1.b] if use_bias else [conv1.w]).run()

      # Feed float32 explicitly to match the placeholder's dtype.
      output_eval = output.eval({
          inputs: np.zeros([batch_size, in_length, in_channels],
                           dtype=np.float32)})

      # SAME padding with stride 1 is expected to keep the length dimension.
      self.assertEqual(
          output_eval.shape,
          (batch_size, in_length, out_channels))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号