gmm_ops.py source file

python
Project: lsdc  Author: febert
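# Excerpt of a class method; uses the TensorFlow 1.x API (tf.cholesky, tf.to_float).
import numpy as np
import tensorflow as tf

def _define_full_covariance_probs(self, shard_id, shard):
    """Defines the full covariance log probabilities of each example under each class.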

    Updates a matrix with dimension num_examples X num_classes.

    Args:
      shard_id: id of the current shard.
      shard: current data shard, 1 X num_examples X dimensions.
    """
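    # Difference between every example and every class mean.
    diff = shard - self._means
    # Cholesky factor L of each regularized covariance: cov + min_var = L L^T.
    cholesky = tf.cholesky(self._covs + self._min_var)
    # log|cov| = 2 * sum(log(diag(L))).
    log_det_covs = 2.0 * tf.reduce_sum(tf.log(tf.matrix_diag_part(cholesky)), 1)
    # Solve L z = diff^T; the squared entries of z sum to the squared
    # Mahalanobis distance (x - mu)^T cov^-1 (x - mu).
    x_mu_cov = tf.square(
        tf.matrix_triangular_solve(
            cholesky, tf.transpose(
                diff, perm=[0, 2, 1]), lower=True))
    diag_m = tf.transpose(tf.reduce_sum(x_mu_cov, 1))
    # Gaussian log-density: -0.5 * (distance + d * log(2 * pi) + log|cov|).
    self._probs[shard_id] = -0.5 * (
        diag_m + tf.to_float(self._dimensions) * tf.log(2 * np.pi) +
        log_det_covs)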
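
The method above evaluates the multivariate Gaussian log-density through a Cholesky factor instead of an explicit matrix inverse. A minimal NumPy sketch of the same computation follows, for readers who want to check the arithmetic outside a TensorFlow graph. The function name `full_covariance_log_probs`, its argument layout, and the sample data are hypothetical. It also assumes the regularizer is a scalar `min_var` added to the diagonal only, which may differ from how `self._min_var` is defined in the original class. `np.linalg.solve` stands in for the triangular solve, since NumPy has no dedicated triangular solver.

```python
import numpy as np


def full_covariance_log_probs(x, means, covs, min_var=1e-3):
    """Per-example, per-class Gaussian log-densities with full covariances.

    x:     (num_examples, dims)
    means: (num_classes, dims)
    covs:  (num_classes, dims, dims), symmetric positive definite
    Returns an array of shape (num_examples, num_classes).
    """
    dims = x.shape[1]
    # Assumption: regularize by adding min_var to the diagonal only.
    chol = np.linalg.cholesky(covs + min_var * np.eye(dims))
    # log|Sigma| = 2 * sum(log(diag(L))) for Sigma = L L^T.
    log_det = 2.0 * np.sum(np.log(np.diagonal(chol, axis1=1, axis2=2)), axis=1)
    # diff[k, n, :] = x[n] - means[k]
    diff = x[None, :, :] - means[:, None, :]
    # Solve L z = diff^T per class; sum(z**2) over dims is the squared
    # Mahalanobis distance of each example to each class mean.
    z = np.linalg.solve(chol, np.transpose(diff, (0, 2, 1)))
    mahalanobis = np.sum(np.square(z), axis=1).T  # (num_examples, num_classes)
    return -0.5 * (mahalanobis + dims * np.log(2.0 * np.pi) + log_det)


# Hypothetical sample data: 5 examples, 2 classes, 3 dimensions.
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
means = rng.normal(size=(2, 3))
a = rng.normal(size=(2, 3, 3))
covs = a @ np.transpose(a, (0, 2, 1)) + np.eye(3)  # positive definite by construction
log_probs = full_covariance_log_probs(x, means, covs)
print(log_probs.shape)
```

Working with the Cholesky factor yields both the log-determinant and the quadratic form from a single factorization, and it avoids forming the inverse covariance explicitly.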