def test_reduce_argmin():
    """Check the 'ArgMin' results of ``import_and_compute`` against NumPy.

    For every axis of a 3x2x2 float32 tensor the operator is evaluated twice:

    * without a ``keepdims`` attribute -- compared with a reference that
      keeps the reduced axis (length 1);
    * with ``keepdims=0`` -- compared with a reference that drops the
      reduced axis.
    """

    def argmin(ndarray, axis, keepdims=False):
        """NumPy reference: indices of the minimum values along ``axis``.

        When ``keepdims`` is true the reduced axis is re-inserted with
        length 1, so the result has the same rank as ``ndarray``.
        """
        res = np.argmin(ndarray, axis=axis)
        if keepdims:
            res = np.expand_dims(res, axis=axis)
        return res

    data = np.array(
        [[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]],
        dtype=np.float32,
    )

    # data has three axes (0, 1, 2); each one is reduced in turn.
    for axis in range(data.ndim):
        # No keepdims attribute given: the expected output keeps the reduced
        # axis, i.e. it matches the keepdims=True reference.
        assert np.array_equal(
            import_and_compute('ArgMin', data, axis=axis),
            argmin(data, keepdims=True, axis=axis),
        )
        # keepdims=0: the reduced axis is removed from the output.
        assert np.array_equal(
            import_and_compute('ArgMin', data, axis=axis, keepdims=0),
            argmin(data, keepdims=False, axis=axis),
        )
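# 评论列表 ("comment list") -- web-page navigation text, not part of the code
# 文章目录 ("table of contents") -- web-page navigation text, not part of the code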