util.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def log_sum_exp(vecs):
    n = len(vecs.size())
    if n == 1:
        vecs = vecs.view(1, -1)
    _, idx = torch.max(vecs, 1)
    max_score = torch.index_select(vecs, 1, idx.view(-1))
    ret = max_score + torch.log(torch.sum(torch.exp(vecs - max_score.expand_as(vecs))))
    if n == 1:
        return ret.view(-1)
    return ret
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号