def forward_prop_random_thru_post_mm(self, model, mx, vx, mu, Su):
    """Combine kernel psi statistics with (mu, Su) into a mean and variance.

    Steps, as performed below:
      1. Build the kernel matrix over ``model.zu`` via ``compute_kernel``
         (called with ``2 * model.ls`` and ``2 * model.sf``), add ``jitter``
         on the diagonal, and invert it.
      2. Form ``A = Kuu^-1 mu`` and
         ``B = Kuu^-1 (Su + outer(mu, mu)) Kuu^-1 - Kuu^-1``.
      3. Obtain ``psi1`` and ``psi2`` from ``compute_psi_weave`` evaluated
         at ``(mx, vx)`` against ``model.zu``; ``psi0 = exp(2 * model.sf)``.
      4. Return ``mout = psi1 A`` and
         ``vout = psi0 + sum_ab(B[a, b] * psi2[n, a, b]) - mout**2``.

    NOTE(review): the name suggests moment matching of a random (Gaussian)
    input with mean ``mx`` and variance ``vx`` through a posterior described
    by ``mu`` / ``Su`` -- confirm against the caller.

    Args:
        model: object exposing ``ls``, ``sf`` and ``zu``.
        mx: first input forwarded to ``compute_psi_weave``.
        vx: second input forwarded to ``compute_psi_weave``.
        mu: array multiplied by the inverse kernel matrix; the einsum below
            requires the product to be 2-D.
        Su: array added to ``outer(mu, mu)``.

    Returns:
        Tuple ``(mout, vout)``. ``mout`` is the 2-D product of ``psi1`` and
        ``A``; ``vout`` is ``psi0 + Bpsi2 - mout**2`` with ``Bpsi2`` as a
        column vector broadcast against ``mout``.
    """
    inducing = model.zu

    # Kernel matrix over the inducing inputs, regularised on the diagonal.
    kernel_clean = compute_kernel(
        2 * model.ls, 2 * model.sf, inducing, inducing)
    kernel_reg = kernel_clean + jitter * np.eye(self.M)

    # TODO: remove inv
    kernel_inv = np.linalg.inv(kernel_reg)

    mean_weights = kernel_inv.dot(mu)
    second_moment = Su + np.outer(mu, mu)
    # Keep the original association: Kuu^-1 (Smm Kuu^-1), then subtract Kuu^-1.
    quad_weights = kernel_inv.dot(second_moment.dot(kernel_inv)) - kernel_inv

    psi0 = np.exp(2.0 * model.sf)
    psi1, psi2 = compute_psi_weave(
        2 * model.ls, 2 * model.sf, mx, vx, inducing)

    mout = np.einsum('nm,md->nd', psi1, mean_weights)
    # One scalar per row n, reshaped to a column so it broadcasts with mout.
    quad_term = np.einsum('ab,nab->n', quad_weights, psi2)[:, None]
    vout = psi0 + quad_term - mout**2

    return mout, vout
评论列表
文章目录