def smallest_k_mx(matrix: mx.nd.NDArray, k: int,
                  only_first_row: bool = False) -> Tuple[Tuple[np.ndarray, np.ndarray], mx.nd.NDArray]:
    """
    Find the k smallest elements in an NDArray.

    The search runs over all elements of the matrix (``axis=None``), not per row.

    :param matrix: Matrix to search. The row/column interpretation of the returned
        indices assumes a 2-D matrix.
    :param k: The number of smallest elements to return.
    :param only_first_row: If True the search is constrained to the first row of the
        matrix. The returned row indices are then all 0.
    :return: A tuple ``((row_indices, col_indices), values)``. ``row_indices`` and
        ``col_indices`` are NumPy integer arrays produced by ``np.unravel_index``.
        ``values`` is the MXNet NDArray returned by ``mx.nd.topk``, in ascending order.
        Note that ``values`` is *not* converted to NumPy.
    """
    if only_first_row:
        # Keep a 2-D (1, n) shape so that unravel_index below still yields (row, col) pairs.
        matrix = mx.nd.reshape(matrix[0], shape=(1, -1))

    # pylint: disable=unbalanced-tuple-unpacking
    values, indices = mx.nd.topk(matrix, axis=None, k=k, ret_typ='both', is_ascend=True)

    # topk returns flat indices in a non-integer dtype. np.unravel_index requires integers,
    # so cast before mapping the flat positions back to (row, col) coordinates.
    return np.unravel_index(indices.astype(np.int32).asnumpy(), matrix.shape), values
评论列表
文章目录