conv2d_test.py source code

python

Project: DeepLearning_VirtualReality_BigData_Project  Author: rashmitripathi
# Imports this excerpt relies on. The method belongs to a test class (not
# shown) that provides test_session() and test_scope().
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops

def _VerifyValues(self, input_sizes, filter_sizes, stride, padding, expected):
    """Tests that tf.nn.conv2d produces the expected value.

    Args:
      input_sizes: Input tensor dimensions in
        [batch, input_rows, input_cols, input_depth].
      filter_sizes: Filter tensor dimensions in
        [kernel_rows, kernel_cols, input_depth, output_depth].
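      stride: Stride applied to both the row and column dimensions.
      padding: Padding type, passed through to conv2d (e.g. "VALID" or "SAME").
      expected: Expected output values as a flat sequence, compared against
        the flattened result.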
    """
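    # Fill the input and the filter with 1, 2, 3, ... so the result is deterministic.
    total_size_1 = np.prod(input_sizes)
    total_size_2 = np.prod(filter_sizes)
    x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes)
    x2 = np.arange(1, total_size_2 + 1, dtype=np.float32).reshape(filter_sizes)
    # NHWC strides: batch and depth are not strided; rows and cols share `stride`.
    strides = [1, stride, stride, 1]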

    with self.test_session() as sess:
      with self.test_scope():
        # Build the graph inside the test scope; the data is fed at run time.
        t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
        t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
        out = nn_ops.conv2d(
            t1, t2, strides=strides, padding=padding, data_format="NHWC")
      value = sess.run(out, {t1: x1, t2: x2})
      # Compare the flattened output with `expected`, element by element,
      # within a tolerance of 1e-3.
      self.assertArrayNear(expected, np.ravel(value), 1e-3)
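
The `expected` argument has to be worked out independently of TensorFlow. The following is a minimal NumPy sketch of how such values could be derived for the "VALID" padding case only, assuming the same 1, 2, 3, ... fill used by `_VerifyValues`. The helper names `conv2d_valid_nhwc` and `expected_values` are hypothetical and do not appear in the original file.

```python
import numpy as np


def conv2d_valid_nhwc(x, w, stride):
    """Naive NHWC convolution (no kernel flip) with VALID padding.

    x: [batch, input_rows, input_cols, input_depth]
    w: [kernel_rows, kernel_cols, input_depth, output_depth]
    """
    batch, in_rows, in_cols, in_depth = x.shape
    k_rows, k_cols, w_in_depth, out_depth = w.shape
    assert in_depth == w_in_depth, "input depth of x and w must match"

    # VALID padding: only positions where the kernel fits entirely.
    out_rows = (in_rows - k_rows) // stride + 1
    out_cols = (in_cols - k_cols) // stride + 1
    out = np.zeros((batch, out_rows, out_cols, out_depth), dtype=np.float32)

    for r in range(out_rows):
        for c in range(out_cols):
            r0, c0 = r * stride, c * stride
            patch = x[:, r0:r0 + k_rows, c0:c0 + k_cols, :]
            # Contract over kernel rows, kernel cols and input depth.
            out[:, r, c, :] = np.tensordot(patch, w, axes=([1, 2, 3], [0, 1, 2]))
    return out


def expected_values(input_sizes, filter_sizes, stride):
    """Flat expected output for tensors filled with 1, 2, 3, ..."""
    x = np.arange(1, np.prod(input_sizes) + 1, dtype=np.float32).reshape(input_sizes)
    w = np.arange(1, np.prod(filter_sizes) + 1, dtype=np.float32).reshape(filter_sizes)
    return np.ravel(conv2d_valid_nhwc(x, w, stride))


if __name__ == "__main__":
    # A 2x2 input with a 2x2 filter yields a single output value.
    print(expected_values([1, 2, 2, 1], [2, 2, 1, 1], stride=1))
```

The flat array returned by `expected_values` has the same layout as `np.ravel(value)` in the test, so it could be passed as `expected` for a "VALID" case. "SAME" padding would need the input to be zero-padded first, which this sketch does not do.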