# test_ops_reduction.py — file source code
#
# Language: python
# Views: 26 · Favorites: 0 · Likes: 0 · Comments: 0
#
# Project: ngraph · Author: NervanaSystems · project source · file source
def test_reduce_argmax():
    """Compare the imported ``ArgMax`` op against a NumPy reference on every axis.

    A 3-D float32 tensor is reduced along each of its axes in turn. For each
    axis the op is evaluated twice via ``import_and_compute``:

    * with no explicit ``keepdims`` attribute, checked against a reference
      that keeps the reduced axis (size 1) -- i.e. this test expects the op
      to keep dimensions by default;
    * with ``keepdims=0``, checked against a reference with the axis dropped.
    """
    def reference(values, axis, keepdims=False):
        # np.argmax always removes the reduced axis; put it back as a
        # length-1 dimension when the caller wants dimensions kept.
        indices = np.argmax(values, axis=axis)
        return np.expand_dims(indices, axis=axis) if keepdims else indices

    tensor = np.array([[[5, 1], [20, 2]],
                       [[30, 1], [40, 2]],
                       [[55, 1], [60, 2]]], dtype=np.float32)

    for axis in range(tensor.ndim):
        # Default attributes: the reduced axis is expected to be retained.
        assert np.array_equal(import_and_compute('ArgMax', tensor, axis=axis),
                              reference(tensor, keepdims=True, axis=axis))
        # keepdims=0: the reduced axis is expected to be removed.
        assert np.array_equal(import_and_compute('ArgMax', tensor, axis=axis, keepdims=0),
                              reference(tensor, keepdims=False, axis=axis))
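# Comment list
# Table of contents
#
#
# Questions
#
#
# Interview experiences
#
#
# Articles
#
# WeChat
# Official account
#
# Scan the QR code to follow the official account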