def compute_posterior_grad_u(self, dmu, dSu):
    """Return grads wrt the u params and Kuuinv.

    Takes gradients with respect to the posterior mean (``dmu``) and
    covariance (``dSu``) and maps them onto the stored parameters
    ``theta_1_R`` (packed as its upper triangle) and the second parameter.

    Args:
        dmu: 2-D array indexed ``[d, a]`` (used with einsum subscripts
            ``'da'``).
        dSu: 3-D array indexed ``[d, a, b]`` (used with einsum subscripts
            ``'dbc'``). It is not modified by this method.

    Returns:
        A tuple ``(deta1_R, deta2, dKuuinv)``:

        * ``deta1_R`` -- float array of shape
          ``(self.Dout, self.M * (self.M + 1) // 2)`` holding, per output
          ``d``, the upper-triangular entries of the gradient with respect
          to ``self.theta_1_R[d]``, with the diagonal entries additionally
          multiplied by the diagonal of ``self.theta_1_R[d]``.
        * ``deta2`` -- ``Su[d] @ dmu[d]`` per output when ``self.nat_param``
          is true, otherwise ``dmu`` itself.
        * ``dKuuinv`` -- sum over outputs of ``-Su @ dSu @ Su`` when
          ``self.nat_param`` is true, otherwise the integer ``0``.
    """
    triu_ind = np.triu_indices(self.M)
    diag_ind = np.diag_indices(self.M)

    if self.nat_param:
        # The mean contributes to the covariance gradient through theta_2.
        # Build a new array instead of `dSu += ...` so the caller's dSu is
        # left untouched (and integer-typed inputs do not fail in place).
        dSu_via_m = np.einsum('da,db->dab', dmu, self.theta_2)
        dSu_total = dSu + dSu_via_m
        # -Su * dSu * Su: presumably the gradient wrt the inverse of Su
        # (matrix-inverse chain rule) -- confirm against the derivation.
        dSuinv = -np.einsum('dab,dbc,dce->dae', self.Su, dSu_total, self.Su)
        dKuuinv = np.sum(dSuinv, axis=0)
        dtheta1 = dSuinv
        deta2 = np.einsum('dab,db->da', self.Su, dmu)
    else:
        deta2 = dmu
        dtheta1 = dSu
        dKuuinv = 0

    # Gradient wrt theta_1_R: theta_1_R @ (dtheta1 + dtheta1^T) per output.
    dtheta1T = np.transpose(dtheta1, [0, 2, 1])
    dtheta1_R = np.einsum(
        'dab,dbc->dac', self.theta_1_R, dtheta1 + dtheta1T)

    # Integer division: a float dimension makes np.zeros raise on Python 3.
    n_triu = self.M * (self.M + 1) // 2
    deta1_R = np.zeros([self.Dout, n_triu])
    for d in range(self.Dout):
        # View into the local einsum result; editing it in place is safe.
        dtheta1_R_d = dtheta1_R[d, :, :]
        theta1_R_d = self.theta_1_R[d, :, :]
        # Scale the diagonal by the diagonal of theta_1_R (presumably
        # because the diagonal is stored in log space -- TODO confirm).
        dtheta1_R_d[diag_ind] = dtheta1_R_d[diag_ind] * theta1_R_d[diag_ind]
        # Keep only the upper triangle, flattened in row-major order.
        deta1_R[d, :] = dtheta1_R_d[triu_ind]
    return deta1_R, deta2, dKuuinv
评论列表
文章目录