def log_sum_exp(input, keepdim=False):
    """Compute a numerically stable log-sum-exp over the last dimension.

    The row-wise maximum is subtracted before exponentiating so that large
    scores do not overflow, and is added back after taking the log.

    Args:
        input: A 2-D tensor; the reduction runs over its last dimension.
        keepdim: If True, the reduced dimension is kept with size 1;
            otherwise it is removed from the result.

    Returns:
        A tensor holding ``log(sum(exp(input), dim=-1))`` for each row, with
        the last dimension kept (size 1) or dropped according to ``keepdim``.

    Raises:
        AssertionError: If ``input`` is not 2-dimensional.
    """
    assert input.dim() == 2
    # keepdim=True here so the max broadcasts against every column of input.
    max_scores, _ = input.max(dim=-1, keepdim=True)
    shifted = input - max_scores
    log_sum = torch.log(torch.sum(torch.exp(shifted), dim=-1, keepdim=keepdim))
    if not keepdim:
        # Drop the kept dimension so the max lines up element-wise with the
        # reduced sum; otherwise (N, 1) + (N,) would broadcast to (N, N).
        max_scores = max_scores.squeeze(-1)
    return max_scores + log_sum
评论列表
文章目录