def kl_divergence(p_samples, q_samples):
    """Estimate KL(P || Q) from two sample sets using kernel density estimates.

    A ``sc.gaussian_kde`` is fitted to each sample set, both densities are
    evaluated on one evenly spaced grid spanning the joint range of the
    samples, and the grid values are passed to ``sc.entropy(p, q)``.
    (``sc`` is presumably ``scipy.stats``, whose ``entropy`` normalises both
    inputs before computing the relative entropy -- confirm against the
    file's imports.)

    Args:
        p_samples: Samples drawn from P, or a tuple ``(idx, samples)``. In
            the tuple form the fitted KDE is stored in the module-level
            ``_cached_p_pdf`` under ``idx`` and reused by later calls that
            pass the same ``idx``. On a cache hit the supplied samples are
            only used for the grid bounds and size, not for fitting, so the
            caller has to keep ``idx`` and its samples consistent.
        q_samples: Samples drawn from Q.

    Both sample sets are assumed to be 1-D arrays exposing ``.shape``.

    Returns:
        The divergence estimate, capped at ``MAX_KL``.
    """
    if isinstance(p_samples, tuple):
        idx, p_samples = p_samples
        try:
            p_pdf = _cached_p_pdf[idx]
        except KeyError:
            # First time this idx is seen: fit once and remember the estimator.
            p_pdf = _cached_p_pdf[idx] = sc.gaussian_kde(p_samples)
    else:
        p_pdf = sc.gaussian_kde(p_samples)
    q_pdf = sc.gaussian_kde(q_samples)

    # Joint support of both sample sets. np.min/np.max reduce inside NumPy
    # instead of iterating over the arrays element by element in Python.
    left = min(np.min(p_samples), np.min(q_samples))
    right = max(np.max(p_samples), np.max(q_samples))

    # Quantise: one grid point per sample of the larger set, bounded by
    # MAX_GRID_POINTS to keep the KDE evaluation affordable.
    grid_size = min(
        max(p_samples.shape[0], q_samples.shape[0]),
        MAX_GRID_POINTS,
    )
    grid = np.linspace(left, right, grid_size)

    p = p_pdf.pdf(grid)
    q = q_pdf.pdf(grid)

    # The cap turns an infinite divergence into MAX_KL.
    # NOTE(review): a NaN estimate is not clamped by min() and is returned
    # as-is -- confirm callers tolerate that.
    return min(sc.entropy(p, q), MAX_KL)
评论列表
文章目录