def exKxz_pairwise(self, Z, Xmu, Xcov):
    """
    <x_t K_{x_{t-1}, Z}>_q_{x_{t-1:t}}

    Writing L = diag(lengthscales ** 2), S_n = Xcov[0, n], C_n = Xcov[1, n]
    and d_nm = Z[m] - Xmu[n], entry [n, m, :] of the result is

        variance * |I + S_n L^-1|^(-1/2)
                 * exp(-0.5 * d_nm^T (L + S_n)^-1 d_nm)
                 * (C_n^T (L + S_n)^-1 d_nm + Xmu[n + 1])

    :param Z: MxD inducing inputs
    :param Xmu: X mean (N+1xD)
    :param Xcov: 2x(N+1)xDxD; only the first N entries along the second
        axis are read. NOTE(review): Xcov[0] looks like the per-step
        covariance and Xcov[1] the covariance between consecutive steps --
        confirm against the caller.
    :return: NxMxD
    """
    float_type = settings.float_type

    # The shape checks are graph ops. Passing Xmu through an identity created
    # under the control dependency makes everything derived from Xmu below
    # wait on them.
    shape_checks = [
        tf.assert_equal(
            tf.shape(Xmu)[1], self.input_dim,
            message="Currently cannot handle slicing in exKxz_pairwise."),
        tf.assert_equal(
            tf.shape(Xmu), tf.shape(Xcov)[1:3],
            name="assert_Xmu_Xcov_shape"),
    ]
    with tf.control_dependencies(shape_checks):
        Xmu = tf.identity(Xmu)

    mu_shape = tf.shape(Xmu)
    num_pairs = mu_shape[0] - 1  # N
    dim = mu_shape[1]  # D

    # Row n of the "prev" tensors is paired with row n + 1 of Xmu.
    mu_prev = Xmu[:num_pairs, :]  # NxD
    mu_next = Xmu[1:, :]  # NxD
    cov_prev = Xcov[0, :num_pairs, :, :]  # NxDxD
    cov_cross = Xcov[1, :num_pairs, :, :]  # NxDxD

    if self.ARD:
        ell = self.lengthscales
    else:
        # Adding to a zeros vector of length D expands a shared lengthscale
        # to one entry per input dimension.
        ell = tf.zeros((dim,), dtype=float_type) + self.lengthscales

    # L + S_n
    scale_mat = tf.expand_dims(tf.matrix_diag(ell ** 2.0), 0) + cov_prev  # NxDxD

    # |I + S_n L^-1|: the (1, 1, D) factor rescales the columns of S_n.
    inv_sq_ell = tf.reshape(ell ** -2.0, (1, 1, -1))
    identity = tf.expand_dims(tf.eye(dim, dtype=float_type), 0)
    det = tf.matrix_determinant(identity + inv_sq_ell * cov_prev)  # N

    diff = tf.expand_dims(tf.transpose(Z), 0) - tf.expand_dims(mu_prev, 2)  # NxDxM
    solved = tf.matrix_solve(scale_mat, diff)  # NxDxM
    quad = tf.reduce_sum(solved * diff, axis=1)  # NxM

    mean_term = (tf.matmul(solved, cov_cross, transpose_a=True)
                 + tf.expand_dims(mu_next, 1))  # NxMxD
    det_factor = tf.reshape(det ** -0.5, (num_pairs, 1, 1))  # Nx1x1
    exp_factor = tf.expand_dims(tf.exp(-0.5 * quad), 2)  # NxMx1

    # Multiply left to right in this order so rounding matches exactly.
    return self.variance * mean_term * det_factor * exp_factor
评论列表
文章目录