def mvnquad(func, means, covs, H, Din, Dout=()):
    """
    Computes N Gaussian expectation integrals of a single function 'func'
    using Gauss-Hermite quadrature.

    :param func: integrand function. Takes one input of shape ?xD.
    :param means: NxD
    :param covs: NxDxD
    :param H: Number of Gauss-Hermite evaluation points.
    :param Din: Number of input dimensions. Needs to be known at call-time.
    :param Dout: Shape of the output dimensions, given as a tuple or list
        (not a bare int). Defaults to (). Dout is assumed to leave out the
        item index, i.e. func actually maps (?xD)->(?x*Dout).
    :return: quadratures (N,*Dout)
    """
    # Normalise Dout so that any sequence (e.g. a list) can be concatenated
    # with the leading (H**Din, N) dimensions below; a tuple is unchanged.
    Dout = tuple(Dout)

    # Evaluation points and weights from the helper; the shape comments
    # below assume xn is (H**D)xD -- confirm against mvhermgauss.
    xn, wn = mvhermgauss(H, Din)
    N = tf.shape(means)[0]

    # transform points based on Gaussian parameters
    cholXcov = tf.cholesky(covs)  # NxDxD
    tiled_xn = tf.tile(xn[None, :, :], (N, 1, 1))  # one copy of the points per Gaussian
    Xt = tf.matmul(cholXcov, tiled_xn, transpose_b=True)  # NxDxH**D
    # sqrt(2) scaling, then shift every point by its Gaussian's mean
    X = 2.0 ** 0.5 * Xt + tf.expand_dims(means, 2)  # NxDxH**D
    # Put the quadrature-point axis first so the flattened func output can be
    # reshaped back to (H**D, N, *Dout).
    Xr = tf.reshape(tf.transpose(X, [2, 0, 1]), (-1, Din))  # (H**D*N)xD

    # perform quadrature
    fX = tf.reshape(func(Xr), (H ** Din, N,) + Dout)
    # Scale the weights by pi**(-Din/2) and add trailing singleton axes so
    # they broadcast over the N and Dout axes of fX.
    wr = np.reshape(wn * np.pi ** (-Din * 0.5),
                    (-1,) + (1,) * (1 + len(Dout)))
    # Weighted sum over the quadrature-point axis.
    return tf.reduce_sum(fX * wr, 0)
# (web-page residue, not part of the code: "comment list" / "article table of contents")