def compute_cav_grad_u(self, dmu, dSu, alpha):
    """Return gradients w.r.t. the u parameters and Kuuinv.

    Propagates the incoming gradients ``dmu`` and ``dSu`` back to the
    parameters stored on ``self``. Which attributes are read depends on
    ``self.nat_param``: the truthy branch uses ``self.theta_2`` and
    ``self.Suhat``; the other branch uses ``self.Suhat``, ``self.Suinv``
    and ``self.mu``. Both branches use ``self.theta_1_R``, ``self.M``,
    ``self.N`` and ``self.Dout``.

    Args:
        dmu: 2-D array indexed ``(d, a)`` (as required by the einsum
            subscripts used below).
        dSu: 3-D array indexed ``(d, a, b)``. It is NOT modified.
        alpha: scalar used to form ``beta = (N - alpha) / N``.
            NOTE(review): presumably a power-EP style fraction -- confirm
            against the caller.

    Returns:
        Tuple ``(deta1_R, deta2, dKuuinv)`` where ``deta1_R`` has shape
        ``(Dout, M * (M + 1) // 2)`` (upper-triangular entries of the
        per-output gradient w.r.t. ``theta_1_R``, with the diagonal entries
        multiplied by the diagonal of ``theta_1_R``), ``deta2`` is indexed
        ``(d, a)``, and ``dKuuinv`` is the gradient summed over the
        leading axis.
    """
    M = self.M
    triu_ind = np.triu_indices(M)
    diag_ind = np.diag_indices(M)
    beta = (self.N - alpha) * 1.0 / self.N

    if self.nat_param:
        # Build a new array instead of `dSu += ...` so the caller's dSu is
        # left untouched, consistent with the other branch.
        dSu_total = dSu + np.einsum('da,db->dab', dmu, beta * self.theta_2)
        dSuinv = -np.einsum(
            'dab,dbc,dce->dae', self.Suhat, dSu_total, self.Suhat)
        dKuuinv = np.sum(dSuinv, axis=0)
        dtheta1 = beta * dSuinv
        deta2 = beta * np.einsum('dab,db->da', self.Suhat, dmu)
    else:
        Suhat = self.Suhat
        Suinv = self.Suinv
        mu = self.mu
        data_f_2 = np.einsum('dab,db->da', Suinv, mu)
        dSuhat = dSu + np.einsum('da,db->dab', dmu, beta * data_f_2)
        dSuhatinv = -np.einsum('dab,dbc,dce->dae', Suhat, dSuhat, Suhat)
        dSuinv_1 = beta * dSuhatinv
        Suhatdmu = np.einsum('dab,db->da', Suhat, dmu)
        dSuinv = dSuinv_1 + beta * np.einsum('da,db->dab', Suhatdmu, mu)
        dtheta1 = -np.einsum('dab,dbc,dce->dae', Suinv, dSuinv, Suinv)
        deta2 = beta * np.einsum('dab,db->da', Suinv, Suhatdmu)
        # NOTE(review): divides by beta, which is zero when alpha == N;
        # confirm callers never pass that value.
        dKuuinv = (1 - beta) / beta * np.sum(dSuinv_1, axis=0)

    dtheta1T = np.transpose(dtheta1, [0, 2, 1])
    dtheta1_R = np.einsum(
        'dab,dbc->dac', self.theta_1_R, dtheta1 + dtheta1T)

    # Floor division keeps the dimension an int; `/` yields a float on
    # Python 3, which np.zeros rejects as a shape entry.
    deta1_R = np.zeros([self.Dout, M * (M + 1) // 2])
    for d in range(self.Dout):
        dtheta1_R_d = dtheta1_R[d, :, :]
        theta1_R_d = self.theta_1_R[d, :, :]
        # Scale the diagonal entries by the diagonal of theta_1_R.
        # NOTE(review): looks like a chain-rule factor for a
        # log-parameterised diagonal -- confirm.
        dtheta1_R_d[diag_ind] = dtheta1_R_d[diag_ind] * theta1_R_d[diag_ind]
        # Keep only the upper-triangular entries, flattened row by row.
        deta1_R[d, :] = dtheta1_R_d[triu_ind]
    return deta1_R, deta2, dKuuinv
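# [web-page residue, not part of the code] Comment list
# [web-page residue, not part of the code] Article table of contents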