def softplus_inverse(x, name=None):
  """Computes the inverse of softplus, so `softplus_inverse(softplus(x)) = x`.

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
  with ops.name_scope(name, "softplus_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")

    # Derivation of the numerically stable form. Starting from
    #   x = softplus(y) = Log[1 + exp{y}]        (hence x > 0)
    # we obtain
    #   exp{x} = 1 + exp{y}
    #   y = Log[exp{x} - 1]                                    (A)
    #     = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
    #     = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
    #     = Log[1 - exp{-x}] + x                               (B)
    # (A) is the "obvious" inverse, but (B) is more stable for large x.
    # For small x (e.g. x = 1e-10), (B) degenerates to -inf because
    # 1 - exp{-x} becomes zero; there we rely on 1 - exp{-x} ~= x for small
    # x > 0, which makes the result Log[x].
    #
    # On top of that, small/large inputs are clamped so that the behaviour is
    # congruent with the logic in:
    #   tensorflow/core/kernels/softplus_op.h
    log_eps_plus_two = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
    small_cutoff = np.exp(log_eps_plus_two)
    large_cutoff = -log_eps_plus_two

    below_cutoff = math_ops.less(x, small_cutoff)
    above_cutoff = math_ops.greater(x, large_cutoff)

    # Value returned for inputs below the small cutoff. It is taken from the
    # untouched input, before any surrogate substitution below.
    log_of_x = math_ops.log(x)

    # Wherever the input is out of range, feed the stable formula a surrogate
    # of one instead. This keeps the unselected codepath free of +/- inf,
    # which matters for gradients: the gradient of `where` behaves like
    # `pred*pred_true + (1-pred)*pred_false`, so an `inf` in an unselected
    # path would produce `0*inf=nan`. The surrogate only replaces entries
    # whose formula result is never selected below, so this substitution is
    # ultimately a NOP for the output. Ones are used rather than zeros
    # because `log(expm1(0.)) = -inf`.
    outside_range = math_ops.logical_or(below_cutoff, above_cutoff)
    surrogate = array_ops.ones_like(x)
    safe_x = array_ops.where(outside_range, surrogate, x)

    # Form (B) written with expm1; equal to log(expm1(safe_x)).
    stable_middle = safe_x + math_ops.log(-math_ops.expm1(-safe_x))

    # Large inputs pass through unchanged, small inputs use Log[x], and
    # everything in between uses the stable formula.
    large_or_middle = array_ops.where(above_cutoff, x, stable_middle)
    return array_ops.where(below_cutoff, log_of_x, large_or_middle)