aep_models.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:geepee 作者: thangbui 项目源码 文件源码
def compute_cav_grad_u(self, dmu, dSu, alpha):
    """Return gradients w.r.t. the u parameters and Kuuinv from cavity grads.

    Args:
        dmu: gradient w.r.t. the cavity mean, indexed ``[d, a]`` by the
            einsums below -- presumably shape (Dout, M); TODO confirm.
        dSu: gradient w.r.t. the cavity covariance, indexed ``[d, a, b]``
            (presumably (Dout, M, M)). It is NOT modified in place.
        alpha: scalar; gradients are scaled by ``beta = (N - alpha) / N``.

    Returns:
        Tuple ``(deta1_R, deta2, dKuuinv)``:

        * ``deta1_R`` -- gradient w.r.t. the upper-triangular entries of
          ``self.theta_1_R`` for each leading index ``d``, packed in
          ``np.triu_indices(M)`` order, shape (D, M * (M + 1) // 2).
        * ``deta2`` -- shape (D, M).
        * ``dKuuinv`` -- summed over the leading axis, shape (M, M).
    """
    beta = (self.N - alpha) * 1.0 / self.N
    if self.nat_param:
        dSu_via_m = np.einsum('da,db->dab', dmu, beta * self.theta_2)
        # Out-of-place add so the caller's dSu array is left untouched.
        dSu = dSu + dSu_via_m
        dSuinv = -np.einsum('dab,dbc,dce->dae', self.Suhat, dSu, self.Suhat)
        dKuuinv = np.sum(dSuinv, axis=0)
        dtheta1 = beta * dSuinv
        deta2 = beta * np.einsum('dab,db->da', self.Suhat, dmu)
    else:
        Suhat = self.Suhat
        Suinv = self.Suinv
        mu = self.mu
        data_f_2 = np.einsum('dab,db->da', Suinv, mu)
        dSuhat_via_mhat = np.einsum('da,db->dab', dmu, beta * data_f_2)
        dSuhat = dSu + dSuhat_via_mhat
        dSuhatinv = -np.einsum('dab,dbc,dce->dae', Suhat, dSuhat, Suhat)
        dSuinv_1 = beta * dSuhatinv
        Suhatdmu = np.einsum('dab,db->da', Suhat, dmu)
        dSuinv = dSuinv_1 + beta * np.einsum('da,db->dab', Suhatdmu, mu)
        dtheta1 = -np.einsum('dab,dbc,dce->dae', Suinv, dSuinv, Suinv)
        deta2 = beta * np.einsum('dab,db->da', Suinv, Suhatdmu)
        # NOTE(review): divides by beta; alpha == N would give beta == 0.
        dKuuinv = (1 - beta) / beta * np.sum(dSuinv_1, axis=0)

    # Chain rule through theta_1_R: symmetrise dtheta1, then left-multiply.
    dtheta1T = np.transpose(dtheta1, [0, 2, 1])
    dtheta1_R = np.einsum('dab,dbc->dac', self.theta_1_R, dtheta1 + dtheta1T)

    # Diagonal gradients are additionally scaled by the matching diagonal of
    # theta_1_R (presumably the diagonal is exp-parameterised -- TODO confirm).
    diag_rows, diag_cols = np.diag_indices(self.M)
    dtheta1_R[:, diag_rows, diag_cols] *= self.theta_1_R[:, diag_rows, diag_cols]

    # Pack the upper triangle per leading index; integer-sized by construction
    # (the previous float-sized np.zeros failed on Python 3).
    triu_rows, triu_cols = np.triu_indices(self.M)
    deta1_R = dtheta1_R[:, triu_rows, triu_cols]

    return deta1_R, deta2, dKuuinv
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号