def log_Bernoulli(x, mean, average=False, dim=None, eps=1e-7):
    """Return the Bernoulli log-likelihood of ``x`` under probabilities ``mean``.

    Computes ``x * log(p) + (1 - x) * log(1 - p)`` element-wise, where ``p``
    is ``mean`` clamped to ``[eps, 1 - eps]`` so neither logarithm becomes
    ``-inf``. The result is then summed (or averaged) over ``dim``.

    Args:
        x: Observations; presumably binary or in [0, 1] -- confirm against
            callers. Must broadcast against ``mean``.
        mean: Bernoulli success probabilities.
        average: If True, reduce with ``torch.mean``; otherwise ``torch.sum``.
        dim: Dimension to reduce over. ``None`` reduces over every element
            and yields a 0-d tensor.
        eps: Margin used to clamp ``mean`` away from 0 and 1.

    Returns:
        Tensor holding the reduced log-likelihood.
    """
    probs = torch.clamp(mean, min=eps, max=1.0 - eps)
    # log1p(-p) is more accurate than log(1 - p) when p is small.
    log_bernoulli = x * torch.log(probs) + (1.0 - x) * torch.log1p(-probs)
    reduce = torch.mean if average else torch.sum
    # Older PyTorch releases reject an explicit dim=None, so only forward
    # dim when a specific reduction axis was requested.
    if dim is None:
        return reduce(log_bernoulli)
    return reduce(log_bernoulli, dim)
评论列表
文章目录