rbf.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:smt 作者: SMTorg 项目源码 文件源码
def _predict_output_derivatives(self, x):
        n = x.shape[0]
        nt = self.nt
        ny = self.training_points[None][0][1].shape[1]
        num = self.num

        dy_dstates = np.empty(n * num['dof'])
        self.rbfc.compute_jac(n, x.flatten(), dy_dstates)
        dy_dstates = dy_dstates.reshape((n, num['dof']))

        dstates_dytl = np.linalg.inv(self.mtx)

        ones = np.ones(self.nt)
        arange = np.arange(self.nt)
        dytl_dyt = csc_matrix((ones, (arange, arange)), shape=(num['dof'], self.nt))

        dy_dyt = (dytl_dyt.T.dot(dstates_dytl.T).dot(dy_dstates.T)).T
        dy_dyt = np.einsum('ij,k->ijk', dy_dyt, np.ones(ny))
        return {None: dy_dyt}
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号