multivariate.py source code

python

Project: zhusuan  Author: thu-ml
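```python
# Method of zhusuan's MultivariateNormalCholesky distribution
# (zhusuan/distributions/multivariate.py), written against the
# TensorFlow 1.x API (tf.log, tf.matrix_diag_part, tf.check_numerics, ...).
import numpy as np
import tensorflow as tf

from zhusuan.distributions.utils import maybe_explicit_broadcast


def _log_prob(self, given):
    """Log-density of N(mean, L L^T) at `given`, with L = cov_tril."""
    mean, cov_tril = (self.path_param(self.mean),
                      self.path_param(self.cov_tril))
    # log det(Cov) = 2 * sum(log(diag(L))), since Cov = L L^T.
    log_det = 2 * tf.reduce_sum(
        tf.log(tf.matrix_diag_part(cov_tril)), axis=-1)
    N = tf.cast(self._n_dim, self.dtype)
    # Log normalizing constant: -N/2 * log(2*pi) - log det(Cov) / 2.
    logZ = - N / 2 * tf.log(2 * tf.constant(np.pi, dtype=self.dtype)) - \
        log_det / 2
    # logZ.shape == batch_shape
    if self._check_numerics:
        logZ = tf.check_numerics(logZ, "log[det(Cov)]")
    # (given-mean)' Sigma^{-1} (given-mean) =
    # (g-m)' L^{-T} L^{-1} (g-m) = |x|^2, where Lx = g-m =: y.
    y = tf.expand_dims(given - mean, -1)
    # Broadcast cov_tril against y so the batched solve sees matching shapes.
    L, _ = maybe_explicit_broadcast(
        cov_tril, y, 'MultivariateNormalCholesky.cov_tril',
        'expand_dims(given, -1)')
    # Forward substitution: x = L^{-1} y.
    x = tf.matrix_triangular_solve(L, y, lower=True)
    x = tf.squeeze(x, -1)
    # Quadratic (Mahalanobis) term: -|x|^2 / 2.
    stoc_dist = -0.5 * tf.reduce_sum(tf.square(x), axis=-1)
    return logZ + stoc_dist
```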
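The quantities in `_log_prob` correspond to the standard multivariate normal log-density with covariance Sigma = L L^T (L being `cov_tril`, g the `given` point, m the `mean`). `logZ` holds the first two terms and `stoc_dist` the last one:

```latex
\begin{aligned}
\log p(g) &= -\frac{N}{2}\log(2\pi) - \frac{1}{2}\log\det\Sigma
             - \frac{1}{2}(g-m)^\top \Sigma^{-1}(g-m), \\
\log\det\Sigma &= \log\det(LL^\top) = 2\sum_{i=1}^{N}\log L_{ii}, \\
(g-m)^\top \Sigma^{-1}(g-m) &= (g-m)^\top L^{-\top} L^{-1}(g-m) = \lVert x \rVert^2,
  \qquad Lx = g - m =: y.
\end{aligned}
```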
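As a cross-check of the same three steps (log-determinant from the diagonal of L, normalizing constant, triangular solve for the quadratic term), here is a minimal NumPy sketch for a single unbatched point. The names `mvn_cholesky_log_prob` and `reference_log_prob` are hypothetical helpers, not part of zhusuan; `np.linalg.solve` stands in for a dedicated triangular solver, and the reference uses an explicit inverse purely for comparison.

```python
import numpy as np


def mvn_cholesky_log_prob(given, mean, cov_tril):
    """Hypothetical helper: log N(given | mean, L L^T) with L = cov_tril.

    Mirrors the steps of `_log_prob` for one unbatched point; cov_tril
    must be lower triangular with a positive diagonal.
    """
    n_dim = given.shape[-1]
    # log det(Sigma) = 2 * sum(log(diag(L))) because Sigma = L L^T.
    log_det = 2.0 * np.sum(np.log(np.diag(cov_tril)))
    # Log normalizing constant, the analogue of `logZ`.
    log_z = -0.5 * n_dim * np.log(2.0 * np.pi) - 0.5 * log_det
    # Solve L x = y with y = given - mean, so |x|^2 is the Mahalanobis term.
    # np.linalg.solve is a general solver standing in for a triangular one.
    x = np.linalg.solve(cov_tril, given - mean)
    # Analogue of `stoc_dist`: -|x|^2 / 2.
    return log_z - 0.5 * np.sum(x ** 2)


def reference_log_prob(given, mean, cov):
    """Hypothetical helper: direct formula with an explicit inverse (checking only)."""
    n_dim = given.shape[-1]
    _, log_det = np.linalg.slogdet(cov)
    diff = given - mean
    maha = diff @ np.linalg.inv(cov) @ diff
    return -0.5 * (n_dim * np.log(2.0 * np.pi) + log_det + maha)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    a = rng.normal(size=(3, 3))
    cov = a @ a.T + 3.0 * np.eye(3)      # symmetric positive definite
    cov_tril = np.linalg.cholesky(cov)   # lower-triangular L with L @ L.T == cov
    mean = rng.normal(size=3)
    given = rng.normal(size=3)
    print(np.isclose(mvn_cholesky_log_prob(given, mean, cov_tril),
                     reference_log_prob(given, mean, cov)))  # → True
```

Working from the Cholesky factor means the inverse covariance is never formed and the log-determinant reduces to a sum over the diagonal of L. With SciPy available, `scipy.linalg.solve_triangular(cov_tril, y, lower=True)` would exploit the triangular structure the way `tf.matrix_triangular_solve(..., lower=True)` does.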