def _log_prob(self, given):
    """
    Evaluate the multivariate normal log-density at ``given``.

    The covariance is represented by its lower-triangular Cholesky factor
    ``cov_tril`` (Sigma = L L'), so the density splits into

    * a normalizer ``-N/2 * log(2*pi) - log(det(Sigma))/2``, where
      ``log(det(Sigma)) = 2 * sum(log(diag(L)))``; and
    * a quadratic term ``-1/2 * |x|^2`` with ``x`` solving ``L x = given - mean``.

    :param given: Tensor at which the density is evaluated. It must be
        compatible with ``mean`` for subtraction (exact shape contract is
        not visible in this block -- presumably
        ``[..., n_dim]``; confirm against the caller).
    :return: Tensor holding the sum of the normalizer and the quadratic
        term (one value per batch entry).
    """
    # Both parameters go through `self.path_param` before use; what that
    # hook does is defined outside this block.
    mean = self.path_param(self.mean)
    cov_tril = self.path_param(self.cov_tril)

    # log det(Sigma) from the diagonal of the Cholesky factor.
    tril_diag = tf.matrix_diag_part(cov_tril)
    log_det_cov = 2 * tf.reduce_sum(tf.log(tril_diag), axis=-1)

    n_dim = tf.cast(self._n_dim, self.dtype)
    log_two_pi = tf.log(2 * tf.constant(np.pi, dtype=self.dtype))
    # log_normalizer.shape == batch_shape
    log_normalizer = - n_dim / 2 * log_two_pi - log_det_cov / 2
    if self._check_numerics:
        log_normalizer = tf.check_numerics(log_normalizer, "log[det(Cov)]")

    # Quadratic form:
    #   (given-mean)' Sigma^{-1} (given-mean)
    #     = (g-m)' L^{-T} L^{-1} (g-m) = |x|^2,  with  L x = g - m.
    # The residual is turned into a column vector for the solver.
    residual = tf.expand_dims(given - mean, -1)
    # Only the (possibly broadcast) factor is kept; the second return value
    # is intentionally discarded and the un-broadcast residual is used.
    tril, _ = maybe_explicit_broadcast(
        cov_tril, residual, 'MultivariateNormalCholesky.cov_tril',
        'expand_dims(given, -1)')
    whitened = tf.squeeze(
        tf.matrix_triangular_solve(tril, residual, lower=True), -1)
    quad_term = -0.5 * tf.reduce_sum(tf.square(whitened), axis=-1)

    return log_normalizer + quad_term
# [web-page residue] Comment list
# [web-page residue] Table of contents