def any(x, axis=None, keepdims=False):
    """Return whether any element of ``x`` is non-zero, optionally along an axis.

    The reduction itself is performed by the nested ``_any`` callable, which
    is handed to ``get_op`` (defined elsewhere in this project) together with
    a shape-inference callable; the object ``get_op`` returns is then applied
    to ``x``.

    Args:
        x: Input passed through ``get_op`` to the torch reduction.
            NOTE(review): presumably a tensor-like object understood by
            ``get_op`` / ``_get_shape`` -- confirm against those helpers.
        axis: Integer axis to reduce over, or ``None`` to reduce over every
            element.
        keepdims: If true, reduced axes are kept with length 1; otherwise
            they are removed from the result.

    Returns:
        Whatever the op built by ``get_op`` returns when called on ``x``.
    """

    def _any(x, axis=axis, keepdims=keepdims):
        nonzero = x != 0
        if axis is None:
            # Full reduction: squeezing on a ``None`` axis is not meaningful,
            # so reduce everything and restore the rank only if requested.
            result = torch.sum(nonzero) != 0
            if keepdims:
                result = torch.reshape(result, (1,) * nonzero.dim())
            return result
        # Pass ``keepdim`` explicitly. Current torch reductions drop the
        # reduced axis by default, so the previous "reduce, then squeeze"
        # approach lost the axis when keepdims=True and could squeeze an
        # unrelated size-1 axis when keepdims=False.
        return torch.sum(nonzero, axis, keepdim=keepdims) != 0

    def _compute_output_shape(x, axis=axis, keepdims=keepdims):
        shape = list(_get_shape(x))
        if axis is None:
            # Mirrors ``_any``: scalar result, or all-ones shape of same rank.
            return (1,) * len(shape) if keepdims else ()
        if keepdims:
            shape[axis] = 1
        else:
            del shape[axis]
        return tuple(shape)

    return get_op(
        _any,
        output_shape=_compute_output_shape,
        arguments=[axis, keepdims],
    )(x)
评论列表
文章目录