def _build_predict(self, Xnew, full_cov=False):
    """
    Compute the mean and variance of the latent function at some new points
    Xnew.

    The prediction reuses terms from ``self._build_common_terms()``. ``Luu``
    and ``L`` are used here as lower-triangular factors (presumably Cholesky
    factors of Kuu and of the inner inducing-point matrix -- TODO confirm
    against ``_build_common_terms``). ``gamma`` is a right-hand side that is
    back-substituted through ``L^T`` (presumably M x R -- verify).

    :param Xnew: input points at which to predict.
    :param full_cov: if True, compute the full posterior covariance between
        the points in Xnew; otherwise compute only its diagonal.
    :return: tuple ``(mean, var)``.
        ``mean`` is ``w^T L^{-T} gamma + self.mean_function(Xnew)`` with
        ``w = Luu^{-1} Kus``.
        ``var`` is the latent variance replicated along a trailing axis of
        length ``R = tf.shape(self.Y)[1]``. That axis is the third one when
        ``full_cov`` is True and the second one otherwise.
    """
    _, _, Luu, L, _, _, gamma = self._build_common_terms()
    # The variance is replicated once per column of Y; compute that count once.
    num_outputs = tf.shape(self.Y)[1]

    Kus = self.feature.Kuf(self.kern, Xnew)  # size M x Xnew
    w = tf.matrix_triangular_solve(Luu, Kus, lower=True)  # Luu^{-1} Kus, size M x Xnew

    # Solve L^T tmp = gamma. For real-valued L, adjoint=True on the lower factor
    # is the same as solving with tf.transpose(L) as an upper-triangular matrix,
    # but without materialising the transposed copy.
    tmp = tf.matrix_triangular_solve(L, gamma, lower=True, adjoint=True)
    mean = tf.matmul(w, tmp, transpose_a=True) + self.mean_function(Xnew)

    intermediateA = tf.matrix_triangular_solve(L, w, lower=True)  # L^{-1} w

    if full_cov:
        # K(Xnew) - w^T w + A^T A, where A = L^{-1} w.
        var = (self.kern.K(Xnew)
               - tf.matmul(w, w, transpose_a=True)
               + tf.matmul(intermediateA, intermediateA, transpose_a=True))
        var = tf.tile(tf.expand_dims(var, 2), tf.stack([1, 1, num_outputs]))
    else:
        # Diagonal only: column-wise sums of squares equal diag(w^T w) and
        # diag(A^T A), so the full N x N products are never formed.
        var = (self.kern.Kdiag(Xnew)
               - tf.reduce_sum(tf.square(w), 0)
               + tf.reduce_sum(tf.square(intermediateA), 0))  # size Xnew,
        var = tf.tile(tf.expand_dims(var, 1), tf.stack([1, num_outputs]))
    return mean, var
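# [Scraped web-page residue, not code: "Comment list" heading]
# [Scraped web-page residue, not code: "Table of contents" heading]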