diagonal_gaussian.py source file

python

Project: third_person_im  Author: bstadie
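# Assumed import: the excerpt uses TT without showing where it comes from.
import theano.tensor as TT


# Method of the distribution class (class definition not shown in this excerpt).
def kl_sym(self, old_dist_info_vars, new_dist_info_vars):
    """
    Compute the symbolic KL divergence KL(old || new) between two multivariate
    Gaussian distributions with diagonal covariance matrices.
    """
    old_means = old_dist_info_vars["mean"]
    old_log_stds = old_dist_info_vars["log_std"]
    new_means = new_dist_info_vars["mean"]
    new_log_stds = new_dist_info_vars["log_std"]
    old_std = TT.exp(old_log_stds)
    new_std = TT.exp(new_log_stds)
    # means: shape (N, A)
    # stds: shape (N, A)
    # Per-dimension formula (1 = old, 2 = new):
    # { (\mu_1 - \mu_2)^2 + \sigma_1^2 - \sigma_2^2 } / (2\sigma_2^2) +
    # ln(\sigma_2/\sigma_1)
    numerator = TT.square(old_means - new_means) + \
        TT.square(old_std) - TT.square(new_std)
    # The small constant guards against division by zero.
    denominator = 2 * TT.square(new_std) + 1e-8
    # Sum the per-dimension terms over the last axis: one KL value per row.
    return TT.sum(
        numerator / denominator + new_log_stds - old_log_stds, axis=-1)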
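
The formula in the comments above can be checked numerically without Theano. The following is a minimal sketch, not part of the original project: the helper name `diag_gaussian_kl` and the list-based inputs are assumptions made for illustration, and it handles a single pair of distributions rather than an (N, A) batch.

```python
import math


def diag_gaussian_kl(old_mean, old_log_std, new_mean, new_log_std, eps=1e-8):
    """KL(old || new) for one pair of diagonal Gaussians.

    Hypothetical helper for illustration: each argument is a list with one
    entry per dimension, mirroring the "mean" and "log_std" entries above.
    """
    total = 0.0
    for m1, ls1, m2, ls2 in zip(old_mean, old_log_std, new_mean, new_log_std):
        s1, s2 = math.exp(ls1), math.exp(ls2)
        # { (mu_1 - mu_2)^2 + sigma_1^2 - sigma_2^2 } / (2 sigma_2^2)
        numerator = (m1 - m2) ** 2 + s1 ** 2 - s2 ** 2
        denominator = 2.0 * s2 ** 2 + eps
        # ln(sigma_2 / sigma_1) equals new_log_std - old_log_std
        total += numerator / denominator + ls2 - ls1
    return total


# Identical distributions: every term cancels.
same = diag_gaussian_kl([0.0, 1.0], [0.0, -0.5], [0.0, 1.0], [0.0, -0.5])

# Unit variances, means one unit apart: the closed form gives 1/2
# (up to the small eps in the denominator).
shifted = diag_gaussian_kl([0.0], [0.0], [1.0], [0.0])

# KL is not symmetric: swapping old and new changes the value.
forward = diag_gaussian_kl([0.0], [0.0], [0.0], [math.log(2.0)])
backward = diag_gaussian_kl([0.0], [math.log(2.0)], [0.0], [0.0])
```

Because the divergence is asymmetric, the argument order matters: the method above measures the old distribution against the new one.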