def test_reduce_argmax():
    """Compare the imported 'ArgMax' operator with a NumPy reference.

    Each axis of a 3-D float32 array is checked twice:

    * without passing ``keepdims``, where the result is expected to equal the
      reference computed with the reduced dimension kept;
    * with ``keepdims=0``, where the result is expected to equal the reference
      with the reduced dimension dropped.
    """

    def expected_argmax(values, axis, keepdims=False):
        """NumPy reference: indices of the maxima of ``values`` along ``axis``.

        When ``keepdims`` is true the reduced axis is re-inserted with size 1.
        """
        indices = np.argmax(values, axis=axis)
        return np.expand_dims(indices, axis=axis) if keepdims else indices

    data = np.array(
        [[[5, 1], [20, 2]],
         [[30, 1], [40, 2]],
         [[55, 1], [60, 2]]],
        dtype=np.float32,
    )

    # data is 3-D, so this visits axes 0, 1 and 2 in order.
    for axis in range(data.ndim):
        # Omitting keepdims is expected to behave like keepdims=True.
        assert np.array_equal(
            import_and_compute('ArgMax', data, axis=axis),
            expected_argmax(data, axis=axis, keepdims=True),
        )
        assert np.array_equal(
            import_and_compute('ArgMax', data, axis=axis, keepdims=0),
            expected_argmax(data, axis=axis, keepdims=False),
        )
评论列表
文章目录