test_ops_convpool.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_pool_average_3d(ndarray_1x1x4x4):
    x = np.broadcast_to(ndarray_1x1x4x4, (1, 1, 4, 4, 4))

    node = onnx.helper.make_node('AveragePool', inputs=['x'], outputs=['y'],
                                 kernel_shape=(2, 2, 2), strides=(2, 2, 2))
    y = np.array([[[13.5, 15.5],
                   [21.5, 23.5]],

                  [[13.5, 15.5],
                   [21.5, 23.5]]], dtype=np.float32).reshape(1, 1, 2, 2, 2)
    ng_results = convert_and_calculate(node, [x], [y])

    assert np.array_equal(ng_results, [y])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号