def compute_dr_wrt(self, wrt):
    """Return the sparse Jacobian of this node with respect to ``wrt``.

    ``self.v.r`` is read as rows of 3 components (it is reshaped with
    ``reshape(-1, 3)``). For each row ``x`` a 3x3 block is built as
    ``diag(s_inv) - x x^T * ss**(-3/2)``. Rows do not interact, so the
    Jacobian is block diagonal.

    NOTE(review): ``self.ss`` and ``self.s_inv`` are set outside this method.
    Presumably they are the per-row squared norm and inverse norm of ``v``.
    In that case each block is ``I/|x| - x x^T/|x|**3``, the Jacobian of
    ``x -> x/|x|``. Confirm against the code that maintains them.

    Args:
        wrt: the object to differentiate with respect to.

    Returns:
        ``None`` when ``wrt`` is not ``self.v``. Otherwise a
        ``scipy.sparse.csc_matrix`` of shape ``(v.size, v.size)`` that stores
        9 entries per row of ``v``: the 3x3 diagonal blocks.
    """
    if wrt is not self.v:
        return None

    v_values = self.v.r  # read the value once instead of on every use
    size = v_values.size
    v = v_values.reshape(-1, 3)

    # Per-row outer product v_n v_n^T, scaled by ss_n**(-3/2) and negated.
    blocks = -np.einsum('ij,ik->ijk', v, v) * (self.ss ** (-3. / 2.)).reshape((-1, 1, 1))
    for i in range(3):
        blocks[:, i, i] += self.s_inv

    # CSC layout: column 3n+b stores rows 3n, 3n+1, 3n+2 with values
    # blocks[n, 0:3, b]. Transposing each block makes those values contiguous,
    # so this does not depend on the blocks being symmetric.
    data = blocks.transpose((0, 2, 1)).ravel()
    # Row index of stored entry p = 9n + 3b + a is 3n + a.
    indices = np.repeat(np.arange(0, size, 3), 9) + np.tile(np.arange(3), size)
    # Every column holds exactly 3 stored entries.
    indptr = np.arange(0, 3 * size + 1, 3)
    return sp.csc_matrix((data, indices, indptr), shape=(size, size))
评论列表
文章目录