def log_beta(t):
    """
    Computes the log of the (multivariate) Beta function.

    For a 1-D input ``t`` this returns
    ``sum_i log_gamma(t_i) - log_gamma(sum_i t_i)``, which equals
    ``log B(t)`` assuming ``log_gamma`` is the elementwise log-Gamma
    function (defined elsewhere -- confirm). For a 2-D input the same
    quantity is computed independently for each row.

    :param t: a single parameter vector (1-D) or a batch of parameter
        vectors stacked along dimension 0 (2-D).
    :type t: torch.autograd.Variable of dimension 1 or 2
    :return: for ``t.dim() == 1``, the result of reducing over every element
        of ``t``; for ``t.dim() == 2``, a result with one entry per row of
        ``t`` (the reduction runs along dimension 1).
    :raises ValueError: if ``t`` is neither 1- nor 2-dimensional.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let a 3-D input silently fall through to the
    # 2-D branch below and be reduced over the wrong axis.
    if t.dim() not in (1, 2):
        raise ValueError(
            "log_beta expects a 1- or 2-dimensional input, got dim={}".format(
                t.dim()))
    if t.dim() == 1:
        # Single vector: reduce over all of its elements.
        numer = torch.sum(log_gamma(t))
        denom = log_gamma(torch.sum(t))
    else:
        # Batch of vectors: reduce along dimension 1 so each row is
        # handled independently.
        numer = torch.sum(log_gamma(t), 1)
        denom = log_gamma(torch.sum(t, 1))
    return numer - denom
评论列表
文章目录