terpret_tf_log_runtime.py source file

python

Project: TerpreT  Author: 51alg
def logsumexp(v, reduction_indices=None, keep_dims=False):
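    # Compare (major, minor) as integers: float(tf.__version__[:4]) raises ValueError on versions such as "0.9.0" or "1.0.0"
    if tuple(int(p) for p in tf.__version__.split(".")[:2]) >= (0, 11): # reduce_logsumexp does not exist below tfv0.11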
        if isinstance(reduction_indices, int): # due to a bug in tfv0.11
            reduction_indices = [reduction_indices]
        return handle_inf(
                 tf.reduce_logsumexp(v,
                  reduction_indices, # this is a bit fragile. reduction_indices got renamed to axis in tfv0.12
                  keep_dims=keep_dims)
                 )
    else:
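        # Caveat: with keep_dims=False, m has the reduced shape, so v - m only broadcasts as intended
        # when reducing over everything or over leading axes
        m = tf.reduce_max(v, reduction_indices=reduction_indices, keep_dims=keep_dims)
        # Use SMALL_NUMBER to handle v = [] (keeps the argument of tf.log away from 0)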
        return m + tf.log(tf.reduce_sum(tf.exp(v - m), 
                        reduction_indices=reduction_indices,
                        keep_dims=keep_dims) + SMALL_NUMBER)
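
The fallback branch above uses the usual max-shift trick: subtract the maximum before exponentiating, then add it back after the log. A minimal plain-Python sketch of the same idea follows. It is not TerpreT code; the name `logsumexp_stable` is hypothetical, it works on a flat list rather than a tensor, and it handles the empty and all-`-inf` cases explicitly instead of adding `SMALL_NUMBER`.

```python
import math


def logsumexp_stable(values):
    """Max-shifted log(sum(exp(x))) over a flat list of floats (illustrative helper)."""
    if not values:
        # log of an empty sum: treat as log(0)
        return float("-inf")
    m = max(values)
    if math.isinf(m):
        # all entries are -inf, or some entry is +inf; x - m would produce nan
        return m
    # After the shift the largest exponent is exp(0) = 1, so nothing overflows
    total = sum(math.exp(x - m) for x in values)
    return m + math.log(total)


if __name__ == "__main__":
    # math.exp(1000.0) alone would raise OverflowError; the shifted form stays finite
    print(logsumexp_stable([1000.0, 1000.0]))
```

Because the shifted sum always contains the term `exp(0) = 1`, the argument of the log is at least 1 for any non-empty finite input, which is why this sketch needs no additive constant.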