utils.py source code

python

Project: zhusuan  Author: thu-ml
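import tensorflow as tf


def log_sum_exp(x, axis=None, keep_dims=False):
    """
    Deprecated: Use tf.reduce_logsumexp().

    Numerically stable log-sum-exp of `x` over `axis`, computed in TensorFlow.

    :param x: A Tensor or numpy array.
    :param axis: An int, list, or tuple. The dimensions to reduce.
        If `None` (the default), reduces all dimensions.
    :param keep_dims: Bool. If True, retains reduced dimensions with length 1.
        Defaults to False.

    :return: A Tensor holding the log-sum-exp of `x` along the given axes.
    """
    x = tf.cast(x, dtype=tf.float32)
    # Subtract the per-slice maximum before exponentiating so exp() cannot overflow.
    x_max = tf.reduce_max(x, axis=axis, keepdims=True)
    # If a slice is entirely -inf (or contains +inf), x - x_max would be nan;
    # shift by 0 there instead, which still yields the correct -inf / +inf result.
    x_max = tf.where(tf.math.is_finite(x_max), x_max, tf.zeros_like(x_max))
    ret = tf.math.log(tf.reduce_sum(tf.exp(x - x_max), axis=axis,
                                    keepdims=True)) + x_max
    if not keep_dims:
        # The reduced dimensions have length 1, so summing over them simply
        # drops them (equivalent to squeezing exactly those axes).
        ret = tf.reduce_sum(ret, axis=axis)
    return ret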
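For readers without TensorFlow at hand, the max-shift trick can be checked in plain Python. It rests on the identity log Σ exp(x_i) = m + log Σ exp(x_i − m) with m = max_i x_i, which is what the `x_max` subtraction in `log_sum_exp` implements. The helper names `log_sum_exp_reference` and `log_sum_exp_naive` are hypothetical and not part of zhusuan, and this sketch only handles a flat 1-D sequence rather than arbitrary axes.

```python
import math
from typing import Iterable


def log_sum_exp_reference(values: Iterable[float]) -> float:
    """Hypothetical plain-Python log-sum-exp over a 1-D sequence (max-shift trick)."""
    xs = list(values)
    if not xs:
        return float("-inf")  # log of an empty sum, i.e. log(0)
    m = max(xs)
    if math.isinf(m):
        # All entries -inf gives -inf; any +inf entry gives +inf.
        return m
    # log(sum(exp(x))) == m + log(sum(exp(x - m))); every exp() argument is <= 0,
    # so no term can overflow and the largest term is exactly exp(0) == 1.
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_sum_exp_naive(values: Iterable[float]) -> float:
    """Hypothetical direct formula; math.exp raises OverflowError for large inputs."""
    return math.log(sum(math.exp(x) for x in values))
```

For example, `log_sum_exp_reference([1000.0, 1000.0])` returns approximately 1000.693 (that is, 1000 + log 2), whereas `log_sum_exp_naive` on the same input raises `OverflowError` because `math.exp(1000.0)` is not representable as a float. On small inputs the two agree up to rounding.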