def log_sum_exp(vecs):
    """Compute a numerically stable ``log(sum(exp(vecs)))`` for 1-D or 2-D input.

    Args:
        vecs: A 1-D tensor of shape ``(n,)``, or a 2-D tensor of shape
            ``(rows, n)`` whose rows are reduced independently.

    Returns:
        For 1-D input, a tensor of shape ``(1,)``.
        For 2-D input, a tensor of shape ``(rows, 1)`` holding each row's
        log-sum-exp (``(1, 1)`` for a single row, as before).
    """
    # torch.logsumexp subtracts the maximum internally, so large scores do not
    # overflow exp(). An all -inf input yields -inf rather than NaN.
    if vecs.dim() == 1:
        # keepdim preserves the historical length-1 result for 1-D input.
        return torch.logsumexp(vecs, dim=0, keepdim=True)
    # Reduce each row on its own. The previous max/index_select approach picked
    # the argmax columns from every row and summed over the whole tensor, so
    # it was only correct for a single row.
    return torch.logsumexp(vecs, dim=1, keepdim=True)
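# 评论列表 ("comment list") — scraped web-page residue, not code
# 文章目录 ("table of contents") — scraped web-page residue, not code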