def Linear_RBF_eKxzKzx(self, Ka, Kb, Z, Xmu, Xcov):
    """Closed-form Linear x RBF cross term of the expectation of Kxz^T Kzx.

    One of ``Ka``/``Kb`` must be a ``Linear`` kernel and the other an ``RBF``
    kernel; their order does not matter. For each input distribution
    N(Xmu[n], Xcov[n]) the returned slice is ``a[n] + a[n]^T`` where

        a[n, m, m'] = E[ k_lin(z_m, x) * k_rbf(x, z_m') ],   x ~ N(Xmu[n], Xcov[n]).

    (Presumably this is the cross term of E[Kxz^T Kzx] for a sum kernel --
    confirm against the caller.)

    Args:
        Ka, Kb: the two kernels, one ``Linear`` and one ``RBF``.
        Z: MxD inducing inputs.
        Xmu: NxD input means.
        Xcov: NxDxD input covariances.

    Returns:
        NxMxM tensor, symmetric in its last two axes.

    Raises:
        TypeError: if the kernels are not a Linear/RBF pair.
        NotImplementedError: for an ARD Linear kernel or non-slice active
            dims (the message indicates the caller switches to quadrature).
    """
    Xcov = self._slice_cov(Xcov)
    Z, Xmu = self._slice(Z, Xmu)
    lin, rbf = (Ka, Kb) if isinstance(Ka, Linear) else (Kb, Ka)
    # Fail fast: applying the RBF closed form to any other kernel is wrong.
    if not isinstance(lin, Linear):
        raise TypeError("{in_lin} is not {linear}".format(
            in_lin=str(type(lin)), linear=str(Linear)))
    if not isinstance(rbf, RBF):
        raise TypeError("{in_rbf} is not {rbf}".format(
            in_rbf=str(type(rbf)), rbf=str(RBF)))
    if lin.ARD or type(lin.active_dims) is not slice or type(rbf.active_dims) is not slice:
        raise NotImplementedError("Active dims and/or Linear ARD not implemented. "
                                  "Switching to quadrature.")
    D = tf.shape(Xmu)[1]
    N = tf.shape(Xmu)[0]

    if rbf.ARD:
        lengthscales = rbf.lengthscales
    else:
        # Broadcast the shared lengthscale to a length-D vector.
        lengthscales = tf.zeros((D, ), dtype=settings.float_type) + rbf.lengthscales
    lengthscales2 = lengthscales ** 2.0  # diagonal of Lambda = diag(l^2)

    # Both kernel variances times |Lambda|^{1/2}.
    const = rbf.variance * lin.variance * tf.reduce_prod(lengthscales)
    gaussmat = Xcov + tf.matrix_diag(lengthscales2)[None, :, :]  # NxDxD, S + Lambda
    det = tf.matrix_determinant(gaussmat) ** -0.5  # N, |S + Lambda|^{-1/2}
    cgm = tf.cholesky(gaussmat)  # NxDxD, lower Cholesky factor of S + Lambda

    # Overlap term exp(-0.5 (z_m - mu_n)^T (S_n + Lambda)^{-1} (z_m - mu_n)).
    # All M differences are solved as columns of one right-hand side, so the
    # Cholesky factor does not have to be tiled M times.
    vecmin = Z[None, :, :] - Xmu[:, None, :]  # NxMxD
    d = tf.matrix_triangular_solve(cgm, tf.transpose(vecmin, [0, 2, 1]))  # NxDxM
    exp = tf.exp(-0.5 * tf.reduce_sum(d ** 2.0, 1))  # NxM

    # Mean of the normalised product N(x; mu_n, S_n) * exp(-0.5 (x-z_m)^T Lambda^{-1} (x-z_m)):
    #   Lambda (S_n + Lambda)^{-1} (mu_n + S_n Lambda^{-1} z_m)
    # This equals (S_n^{-1} + Lambda^{-1})^{-1} (S_n^{-1} mu_n + Lambda^{-1} z_m) but never
    # inverts S_n, so singular input covariances (e.g. zero variance) are handled.
    zscaled = tf.tile((Z / lengthscales2[None, :])[None, :, :], [N, 1, 1])  # NxMxD
    rhs = Xmu[:, :, None] + tf.matmul(Xcov, zscaled, transpose_b=True)  # NxDxM
    mean = tf.cholesky_solve(cgm, rhs)  # NxDxM
    mean = tf.transpose(mean, [0, 2, 1]) * lengthscales2[None, None, :]  # NxMxD

    # a[n, m, m'] = sum_d Z[m, d] * (scaled expectation of x for z_m')[d]
    a = tf.matmul(tf.tile(Z[None, :, :], [N, 1, 1]),
                  mean * exp[:, :, None] * det[:, None, None] * const,
                  transpose_b=True)  # NxMxM
    return a + tf.transpose(a, [0, 2, 1])
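# (Extraction residue from the source web page -- "Comment list" and
#  "Table of contents" headings; not part of the code.)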