def weights_rbf(self, unit_sp, hypers):
    """Compute Bayesian-quadrature weights for an RBF (squared-exponential)
    GP kernel with derivative observations.

    Parameters
    ----------
    unit_sp : ndarray, shape (D, N)
        Unit sigma-points, one point per column (D dims, N points) — shape
        inferred from the unpacking on the first line.
    hypers : dict
        Kernel hyper-parameters with keys:
        'sig_var'     -- alpha; used as alpha**2, so presumably the signal
                         standard deviation (TODO confirm against kern_eq_der),
        'lengthscale' -- per-dimension lengthscales el, len(el) == D,
        'noise_var'   -- jitter added to the kernel diagonal for stable inversion.

    Returns
    -------
    wm : ndarray, shape (N + N*D,)
        Weights for the transformed mean.
    Wc : ndarray, shape (N + N*D, N + N*D)
        Weights for the transformed covariance.
    Wcc : ndarray, shape (D, N + N*D)
        Weights for the input-output cross-covariance.

    Side effects
    ------------
    Sets ``self.model_var`` (integral/model variance of the GP quadrature).

    Notes
    -----
    Relies on module-level helpers not visible in this chunk: ``na``
    (presumably an alias for ``np.newaxis`` — verify), ``det``, ``solve``,
    ``cho_factor``/``cho_solve`` (presumably from ``scipy.linalg``), and
    ``maha`` (pairwise Mahalanobis-distance helper — TODO confirm semantics).
    """
    d, n = unit_sp.shape
    # GP kernel hyper-parameters
    alpha, el, jitter = hypers['sig_var'], hypers['lengthscale'], hypers['noise_var']
    # one lengthscale per input dimension (ARD kernel)
    # NOTE(review): assert is stripped under `python -O`; an explicit raise would be safer.
    assert len(el) == d
    # pre-allocation for convenience; eye_y spans function values + all N*D derivative obs
    eye_d, eye_n, eye_y = np.eye(d), np.eye(n), np.eye(n + d * n)
    K = self.kern_eq_der(unit_sp, hypers)  # evaluate kernel matrix BOTTLENECK
    # Invert the jitter-regularized kernel via Cholesky (SPD assumed) BOTTLENECK
    iK = cho_solve(cho_factor(K + jitter * eye_y), eye_y)
    Lam = np.diag(el ** 2)    # Lambda (squared lengthscales)
    iLam = np.diag(el ** -1)  # sqrt(Lambda^-1)
    iiLam = np.diag(el ** -2)  # Lambda^-1
    inn = iLam.dot(unit_sp)  # (x-m)^T*iLam # (N, D)
    B = iiLam + eye_d  # P*Lambda^-1+I, (P+Lam)^-1 = Lam^-1*(P*Lam^-1+I)^-1 # (D, D)
    cho_B = cho_factor(B)
    t = cho_solve(cho_B, inn)  # dot(inn, inv(B)) # (x-m)^T*iLam*(P+Lambda)^-1 # (D, N)
    l = np.exp(-0.5 * np.sum(inn * t, 0))  # (N, 1)
    # kernel-mean (q) vector: integral of the kernel against the unit Gaussian
    q = (alpha ** 2 / np.sqrt(det(B))) * l  # (N, 1)
    Sig_q = cho_solve(cho_B, eye_d)  # B^-1*I
    eta = Sig_q.dot(unit_sp)  # (D,N) Sig_q*x
    mu_q = iiLam.dot(eta)  # (D,N)
    # derivative part of the kernel mean
    r = q[na, :] * iiLam.dot(mu_q - unit_sp)  # -t.dot(iLam) * q # (D, N)
    # stack function-value and derivative parts into one weight-generating vector
    q_tilde = np.hstack((q.T, r.T.ravel()))  # (1, N+N*D)
    # weights for mean
    wm = q_tilde.dot(iK)
    # quantities for cross-covariance "weights"
    iLamSig = iiLam.dot(Sig_q)  # (D,D)
    r_tilde = (q[na, na, :] * iLamSig[..., na] + mu_q[na, ...] * r[:, na, :]).T.reshape(n * d, d).T  # (D, N*D)
    R_tilde = np.hstack((q[na, :] * mu_q, r_tilde))  # (D, N+N*D)
    # input-output covariance (cross-covariance) "weights"
    Wcc = R_tilde.dot(iK)  # (D, N+N*D)
    # quantities for covariance weights
    zet = 2 * np.log(alpha) - 0.5 * np.sum(inn * inn, 0)  # (D,N) 2log(alpha) - 0.5*(x-m)^T*Lambda^-1*(x-m)
    # NOTE(review): `inn` is deliberately rebound here (Lambda^-1 * x, not Lambda^-1/2 * x);
    # everything below uses the new meaning. A distinct name would be clearer.
    inn = iiLam.dot(unit_sp)  # inp / el[:, na]**2
    R = 2 * iiLam + eye_d  # 2P*Lambda^-1 + I
    # Q: pairwise double-integral kernel matrix (N,N)
    Q = (1.0 / np.sqrt(det(R))) * np.exp((zet[:, na] + zet[:, na].T) + maha(inn.T, -inn.T, V=0.5 * solve(R, eye_d)))
    cho_LamSig = cho_factor(Lam + Sig_q)
    Sig_Q = cho_solve(cho_LamSig, Sig_q).dot(iiLam)  # (D,D) Lambda^-1 (Lambda*(Lambda+Sig_q)^-1*Sig_q) Lambda^-1
    eta_tilde = iiLam.dot(cho_solve(cho_LamSig, eta))  # Lambda^-1(Lambda+Sig_q)^-1*eta
    ETA = eta_tilde[..., na] + eta_tilde[:, na, :]  # (D,N,N) pairwise sum of pre-multiplied eta's (D,N,N)
    # mu_Q = ETA + in_mean[:, na]  # (D,N,N)
    xnmu = inn[..., na] - ETA  # (D,N,N) x_n - mu^Q_nm
    # xmmu = sigmas[:, na, :] - mu_Q  # x_m - mu^Q_nm
    # function-derivative cross block of the double-integral matrix
    E_dff = (-Q[na, ...] * xnmu).swapaxes(0, 1).reshape(d * n, n)
    # (D,D,N,N) (x_n - mu^Q_nm)(x_m - mu^Q_nm)^T + Sig_Q
    T = xnmu[:, na, ...] * xnmu.swapaxes(1, 2)[na, ...] + Sig_Q[..., na, na]
    # derivative-derivative block
    E_dffd = (Q[na, na, ...] * T).swapaxes(0, 3).reshape(d * n, -1)  # (N*D, N*D)
    # assemble full double-integral matrix over [values | derivatives]
    Q_tilde = np.vstack((np.hstack((Q, E_dff.T)), np.hstack((E_dff, E_dffd))))  # (N+N*D, N+N*D)
    # weights for covariance
    iKQ = iK.dot(Q_tilde)
    Wc = iKQ.dot(iK)
    # model variance
    # NOTE(review): np.diag of the (d, 1) array produced by np.ones((d, 1)) EXTRACTS its
    # diagonal (length 1) instead of building a (d, d) diagonal matrix; np.ones(d) was
    # probably intended — verify against downstream uses of self.model_var.
    self.model_var = np.diag((alpha ** 2 - np.trace(iKQ)) * np.ones((d, 1)))
    return wm, Wc, Wcc
# NOTE(review): removed stray web-page artifacts ("评论列表" = comment list,
# "文章目录" = article table of contents) left over from a copy-paste; as bare
# identifiers they would raise NameError when this module is imported.