def exKxz_pairwise(self, Z, Xmu, Xcov):
    """
    Computes <x_{t-1} K_{x_t z}>_q(x) for each pair of consecutive X's in
    Xmu & Xcov.
    :param Z: Fixed inputs (MxD).
    :param Xmu: X means (T+1xD).
    :param Xcov: 2xT+1xDxD. [0, t, :, :] contains covariances for x_t. [1, t, :, :] contains the cross covariances
    for t and t+1.
    :return: (TxMxD).
    """
    self._check_quadrature()
    # No input slicing is performed here on purpose: even when the kernel
    # ignores some input dimensions, the result must still be an NxMxD tensor
    # containing the outer product of the mean of x_{t-1} with K_{x_t Z}.
    # Running the quadrature over the full joint distribution of
    # (x_{t-1}, x_t) achieves this automatically — the kernel simply does not
    # depend on the ignored coordinates. The price is that we must know the
    # full input *size* of Xmu here, not just `input_dim`.
    num_inducing = tf.shape(Z)[0]
    D = self.input_size if hasattr(self, 'input_size') else self.input_dim  # Number of actual input dimensions
    shape_check = tf.assert_equal(
        tf.shape(Xmu)[1], tf.constant(D, dtype=settings.tf_int),
        message="Numerical quadrature needs to know correct shape of Xmu.")
    with tf.control_dependencies([shape_check]):
        Xmu = tf.identity(Xmu)

    # Expand the compact (mean, cov, cross-cov) representation into the full
    # joint Gaussian over each consecutive pair (x_{t-1}, x_t).
    joint_mean = tf.concat((Xmu[:-1, :], Xmu[1:, :]), 1)  # Nx2D
    cross_cov = Xcov[1, :-1, :, :]
    cov_top = tf.concat((Xcov[0, :-1, :, :], cross_cov), 2)  # NxDx2D
    cov_bottom = tf.concat((tf.transpose(cross_cov, (0, 2, 1)), Xcov[0, 1:, :, :]), 2)
    joint_cov = tf.concat((cov_top, cov_bottom), 1)  # Nx2Dx2D

    def integrand(x):
        # Outer product of K(x_t, Z) (second half of x) with x_{t-1}
        # (first half of x), evaluated per quadrature point.
        return tf.expand_dims(self.K(x[:, :D], Z), 2) * tf.expand_dims(x[:, D:], 1)

    return mvnquad(integrand, joint_mean, joint_cov,
                   self.num_gauss_hermite_points, 2 * D, Dout=(num_inducing, D))