设置csr_matrix的行

发布于 2021-01-29 16:51:56

我有一个稀疏的csr_matrix,我想将单行的值更改为不同的值。但是,我找不到简单有效的实现方式。这是它要做的:

A = csr_matrix([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]])
new_row = np.array([-1, -1, -1])
print(set_row_csr(A, 2, new_row).todense())

>>> [[ 0,  1, 0],
     [ 1,  0, 1],
     [-1, -1, -1]]

这是我目前的实现set_row_csr

def set_row_csr(A, row_idx, new_row):
    A[row_idx, :] = new_row
    return A

但是,这给了我一个SparseEfficiencyWarning。有没有一种方法可以在没有手动索引操作的情况下完成此操作,或者这是我唯一的出路?

关注者
0
被浏览
44
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    最后,我设法通过索引处理来完成此任务。

    def set_row_csr(A, row_idx, new_row):
        '''
        Replace a row in a CSR sparse matrix A.
    
        Parameters
        ----------
        A: csr_matrix
            Matrix to change
        row_idx: int
            index of the row to be changed
        new_row: np.array
            list of new values for the row of A
    
        Returns
        -------
        None (the matrix A is changed in place)
    
        Prerequisites
        -------------
        The row index shall be smaller than the number of rows in A
        The number of elements in new row must be equal to the number of columns in matrix A
        '''
        assert sparse.isspmatrix_csr(A), 'A shall be a csr_matrix'
        assert row_idx < A.shape[0], \
                'The row index ({0}) shall be smaller than the number of rows in A ({1})' \
                .format(row_idx, A.shape[0])
        try:
            N_elements_new_row = len(new_row)
        except TypeError:
            msg = 'Argument new_row shall be a list or numpy array, is now a {0}'\
            .format(type(new_row))
            raise AssertionError(msg)
        N_cols = A.shape[1]
        assert N_cols == N_elements_new_row, \
                'The number of elements in new row ({0}) must be equal to ' \
                'the number of columns in matrix A ({1})' \
                .format(N_elements_new_row, N_cols)
    
        idx_start_row = A.indptr[row_idx]
        idx_end_row = A.indptr[row_idx + 1]
        additional_nnz = N_cols - (idx_end_row - idx_start_row)
    
        A.data = np.r_[A.data[:idx_start_row], new_row, A.data[idx_end_row:]]
        A.indices = np.r_[A.indices[:idx_start_row], np.arange(N_cols), A.indices[idx_end_row:]]
        A.indptr = np.r_[A.indptr[:row_idx + 1], A.indptr[(row_idx + 1):] + additional_nnz]
    


知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看