def kl(dist_a, dist_b, allow_nan=False, name=None):
    """Compute the batchwise KL-divergence KL(dist_a || dist_b).

    The implementation is looked up among the registered KL methods using
    `type(dist_a)` and `type(dist_b)`. When nothing is registered for that
    exact pair of types, the class hierarchies of both types are searched:

    * If exactly one registered method matches a pair of classes drawn from
      the two hierarchies, that method is used.
    * If several match, the method whose registered classes have the shortest
      summed MRO path to the input types is used.
    * If several share that shortest path, the first one found by the search
      is used, which favors a shorter MRO distance to `type(dist_a)`.

    Args:
      dist_a: The first distribution.
      dist_b: The second distribution.
      allow_nan: If `False` (default), a runtime error is raised when the KL
        is NaN for any batch entry of the given distributions. If `True`,
        the KL is returned as computed and may contain NaN entries.
      name: (optional) Name scope to use for created operations.

    Returns:
      A Tensor with the batchwise KL-divergence between `dist_a` and `dist_b`.

    Raises:
      NotImplementedError: If no KL method is defined for the distribution
        types of `dist_a` and `dist_b`.
    """
    type_a = type(dist_a)
    type_b = type(dist_b)
    divergence_fn = _registered_kl(type_a, type_b)
    if divergence_fn is None:
        raise NotImplementedError(
            "No KL(dist_a || dist_b) registered for dist_a type %s and dist_b "
            "type %s" % (type_a.__name__, type_b.__name__))

    with ops.name_scope("KullbackLeibler"):
        raw_kl = divergence_fn(dist_a, dist_b, name=name)
        if allow_nan:
            # Caller opted out of NaN checking: hand back the raw result.
            return raw_kl

        # Give the unchecked result a stable name before attaching the check.
        named_kl = array_ops.identity(raw_kl, name="kl")

        # Build the assertion pieces in the order: predicate, then message.
        contains_nan = math_ops.reduce_any(math_ops.is_nan(named_kl))
        nan_free = math_ops.logical_not(contains_nan)
        failure_data = [
            "KL calculation between %s and %s returned NaN values "
            "(and was called with allow_nan=False). Values:"
            % (dist_a.name, dist_b.name),
            named_kl,
        ]
        nan_guard = control_flow_ops.Assert(nan_free, failure_data)

        # The returned tensor depends on the assertion, so evaluating it
        # triggers the NaN check first.
        with ops.control_dependencies([nan_guard]):
            return array_ops.identity(named_kl, name="checked_kl")
# Web-page navigation residue (not part of the module):
# "Comment list" / "Table of contents".