base_models.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:geepee 作者: thangbui 项目源码 文件源码
def compute_posterior_grad_u(self, dmu, dSu):
    """Map gradients w.r.t. the posterior of u onto the u parameters and Kuuinv.

    Args:
        dmu: gradient w.r.t. the posterior means, indexed ``[d, a]``.
        dSu: gradient w.r.t. the posterior covariances, indexed ``[d, a, b]``.
            This array is NOT modified.

    Returns:
        A tuple ``(deta1_R, deta2, dKuuinv)``:

        deta1_R: array of shape ``(self.Dout, M * (M + 1) // 2)``. It holds the
            gradient w.r.t. the upper-triangular entries of ``self.theta_1_R``,
            flattened in ``np.triu_indices(M)`` order. Diagonal entries are
            additionally multiplied by the matching diagonal of
            ``self.theta_1_R``.
        deta2: ``self.Su @ dmu`` per ``d`` when ``self.nat_param`` is set,
            otherwise ``dmu`` unchanged.
        dKuuinv: sum over ``d`` of ``-Su dSu Su`` when ``self.nat_param`` is
            set, otherwise the scalar ``0``.
    """
    if self.nat_param:
        # Add the part of dmu routed through Su: the outer product of dmu with
        # theta_2. Build a new array (not ``+=``) so the caller's dSu is left
        # untouched.
        dSu = dSu + np.einsum('da,db->dab', dmu, self.theta_2)
        # -Su @ dSu @ Su for each d. This is the gradient w.r.t. Su's inverse
        # (presumably the natural parameter theta_1 -- TODO confirm).
        dSuinv = -np.einsum('dab,dbc,dce->dae', self.Su, dSu, self.Su)
        dKuuinv = np.sum(dSuinv, axis=0)
        dtheta1 = dSuinv
        deta2 = np.einsum('dab,db->da', self.Su, dmu)
    else:
        deta2 = dmu
        dtheta1 = dSu
        dKuuinv = 0

    # R @ (G + G^T). This matches the chain rule if theta_1 = R^T R with
    # R = theta_1_R upper triangular (presumed from usage -- TODO confirm).
    dtheta1_R = np.einsum('dab,dbc->dac', self.theta_1_R,
                          dtheta1 + np.transpose(dtheta1, [0, 2, 1]))

    # Scale diagonal entries by R's diagonal (presumably because the diagonal
    # is stored in log space -- TODO confirm).
    diag = np.arange(self.M)
    dtheta1_R[:, diag, diag] = (dtheta1_R[:, diag, diag]
                                * self.theta_1_R[:, diag, diag])

    # Keep only the upper triangle, one row per output dimension. Integer
    # division is required: M * (M + 1) / 2 is a float in Python 3, and
    # np.zeros rejects float dimensions.
    rows, cols = np.triu_indices(self.M)
    deta1_R = np.zeros([self.Dout, self.M * (self.M + 1) // 2])
    deta1_R[:] = dtheta1_R[:self.Dout, rows, cols]

    return deta1_R, deta2, dKuuinv
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号