设置csr_matrix的行
发布于 2021-01-29 16:51:56
我有一个稀疏的csr_matrix,我想将单行的值更改为不同的值。但是,我找不到简单有效的实现方式。这是它要做的:
A = csr_matrix([[0, 1, 0],
[1, 0, 1],
[0, 1, 0]])
new_row = np.array([-1, -1, -1])
print(set_row_csr(A, 2, new_row).todense())
>>> [[ 0, 1, 0],
[ 1, 0, 1],
[-1, -1, -1]]
这是我目前的实现set_row_csr
:
def set_row_csr(A, row_idx, new_row):
A[row_idx, :] = new_row
return A
但是,这给了我一个SparseEfficiencyWarning
。有没有一种方法可以在没有手动索引操作的情况下完成此操作,或者这是我唯一的出路?
关注者
0
被浏览
44
1 个回答
-
最后,我设法通过索引处理来完成此任务。
def set_row_csr(A, row_idx, new_row): ''' Replace a row in a CSR sparse matrix A. Parameters ---------- A: csr_matrix Matrix to change row_idx: int index of the row to be changed new_row: np.array list of new values for the row of A Returns ------- None (the matrix A is changed in place) Prerequisites ------------- The row index shall be smaller than the number of rows in A The number of elements in new row must be equal to the number of columns in matrix A ''' assert sparse.isspmatrix_csr(A), 'A shall be a csr_matrix' assert row_idx < A.shape[0], \ 'The row index ({0}) shall be smaller than the number of rows in A ({1})' \ .format(row_idx, A.shape[0]) try: N_elements_new_row = len(new_row) except TypeError: msg = 'Argument new_row shall be a list or numpy array, is now a {0}'\ .format(type(new_row)) raise AssertionError(msg) N_cols = A.shape[1] assert N_cols == N_elements_new_row, \ 'The number of elements in new row ({0}) must be equal to ' \ 'the number of columns in matrix A ({1})' \ .format(N_elements_new_row, N_cols) idx_start_row = A.indptr[row_idx] idx_end_row = A.indptr[row_idx + 1] additional_nnz = N_cols - (idx_end_row - idx_start_row) A.data = np.r_[A.data[:idx_start_row], new_row, A.data[idx_end_row:]] A.indices = np.r_[A.indices[:idx_start_row], np.arange(N_cols), A.indices[idx_end_row:]] A.indptr = np.r_[A.indptr[:row_idx + 1], A.indptr[(row_idx + 1):] + additional_nnz]