def logaddexp(x1: "T.FloatTensor", x2: "T.FloatTensor") -> "T.FloatTensor":
    """
    Elementwise logaddexp function: log(exp(x1) + exp(x2)), computed stably.

    Args:
        x1: A tensor.
        x2: A tensor (same shape as ``x1``, or broadcastable with it).

    Returns:
        tensor: Elementwise logaddexp of ``x1`` and ``x2``. Equal infinite
        inputs yield that infinity (e.g. ``logaddexp(-inf, -inf) == -inf``).
    """
    # log(exp(x1) + exp(x2))
    # = log( exp(x1) (1 + exp(x2 - x1))) = x1 + log(1 + exp(x2 - x1))
    # = log( exp(x2) (exp(x1 - x2) + 1)) = x2 + log(1 + exp(x1 - x2))
    # Factoring out the larger argument keeps the exponent <= 0, so exp()
    # never overflows: result = max(x1, x2) + log1p(exp(-|x1 - x2|)).
    delta = x1 - x2
    # Where x1 == x2 (including the same infinity on both sides, where
    # x1 - x2 would be inf - inf = nan), the true difference is 0. Forcing
    # it to 0 makes the result max(x1, x2) + log(2), which is the infinity
    # itself for infinite inputs instead of nan.
    delta = torch.where(x1 == x2, torch.zeros_like(delta), delta)
    return torch.max(x1, x2) + torch.log1p(torch.exp(-torch.abs(delta)))
评论列表
文章目录