def logsumexp(v, reduction_indices=None, keep_dims=False):
    """Compute ``log(sum(exp(v)))`` over the given axes in a stable way.

    Uses ``tf.reduce_logsumexp`` when the installed TensorFlow provides it and
    otherwise falls back to the max-shift formulation built from
    ``tf.reduce_max`` / ``tf.reduce_sum``.

    Args:
        v: Tensor to reduce.
        reduction_indices: Axis or list of axes to reduce over. ``None`` is
            forwarded unchanged to the TensorFlow reduction ops.
        keep_dims: Forwarded to the TensorFlow reductions; controls whether
            the reduced axes are kept in the result.

    Returns:
        The reduced tensor. On the native path the result is passed through
        ``handle_inf`` (defined elsewhere in this module); the fallback path
        adds ``SMALL_NUMBER`` inside the log instead.
    """
    # Feature-detect instead of parsing tf.__version__: slicing the version
    # string to four characters and calling float() on it raises ValueError
    # for versions such as "1.5.0" ("1.5.") or "0.9.0" ("0.9.").
    if hasattr(tf, "reduce_logsumexp"):
        if isinstance(reduction_indices, int):
            # tf v0.11 mishandles a bare int here; a one-element list works.
            reduction_indices = [reduction_indices]
        # The axes are passed positionally on purpose: the keyword was renamed
        # from `reduction_indices` to `axis` in tf v0.12.
        return handle_inf(
            tf.reduce_logsumexp(v, reduction_indices, keep_dims=keep_dims)
        )

    # Fallback for TensorFlow builds without reduce_logsumexp.
    # The shift must keep the reduced axes so that `v - shift` broadcasts
    # correctly regardless of which axes are reduced.
    shift = tf.reduce_max(v, reduction_indices=reduction_indices, keep_dims=True)
    summed = tf.reduce_sum(
        tf.exp(v - shift),
        reduction_indices=reduction_indices,
        keep_dims=keep_dims,
    )
    if keep_dims:
        offset = shift
    else:
        # The reduced axes of `shift` all have length 1, so reducing again
        # merely drops them and matches the shape of `summed`.
        offset = tf.reduce_max(
            shift, reduction_indices=reduction_indices, keep_dims=False
        )
    # Use SMALL_NUMBER to handle v = [] (keeps the log argument away from 0).
    return offset + tf.log(summed + SMALL_NUMBER)
评论列表
文章目录