def kl(self, old_dist_info, new_dist_info):
    """Compute KL(old || new) between two diagonal Gaussian distributions.

    Both distributions are multivariate Gaussians with diagonal covariance
    matrices, parameterized by a mean and a log standard deviation.

    Per dimension, with subscript 1 = old and subscript 2 = new:

        {(mu_1 - mu_2)^2 + sigma_1^2 - sigma_2^2} / (2 * sigma_2^2)
            + ln(sigma_2 / sigma_1)

    Args:
        old_dist_info: Mapping with keys ``"mean"`` and ``"log_std"`` for
            the old distribution.
        new_dist_info: Mapping with keys ``"mean"`` and ``"log_std"`` for
            the new distribution.

    Returns:
        The per-dimension KL terms summed over the last axis.

    Note:
        The original comments describe means and stds as shaped (N*A);
        presumably N samples by A dimensions -- confirm against callers.
    """
    old_means = old_dist_info["mean"]
    old_log_stds = old_dist_info["log_std"]
    new_means = new_dist_info["mean"]
    new_log_stds = new_dist_info["log_std"]

    old_std = np.exp(old_log_stds)
    new_std = np.exp(new_log_stds)

    numerator = (
        np.square(old_means - new_means)
        + np.square(old_std)
        - np.square(new_std)
    )
    # The small epsilon keeps the division finite if new_std underflows to 0.
    denominator = 2 * np.square(new_std) + 1e-8

    # ln(sigma_2 / sigma_1) is taken directly from the log-stds rather than
    # from the exponentiated values.
    return np.sum(
        numerator / denominator + new_log_stds - old_log_stds, axis=-1)
评论列表
文章目录