def loss(self, x, samples):
    # Negative log-likelihood of the sampled values under a FoldedNormal proposal.
    # FoldedNormal logpdf:
    # https://en.wikipedia.org/wiki/Folded_normal_distribution
    _, proposal_output = self.forward(x, samples)
    batch_size = len(samples)
    locations = proposal_output[:, 0]
    scales = proposal_output[:, 1]
    # Shared normalising constant 0.5 * log(2 * pi * scale); epsilon guards
    # against log(0) and division by zero.
    two_scales = 2 * scales + util.epsilon
    half_log_two_pi_scales = 0.5 * torch.log(math.pi * two_scales + util.epsilon)
    l = 0
    for b in range(batch_size):
        value = samples[b].value[0]
        if value < 0:
            # Outside the support of the folded normal; contributes nothing to the loss.
            continue
        location = locations[b]
        two_scale = two_scales[b]
        half_log_two_pi_scale = half_log_two_pi_scales[b]
        # log N(value; location, scale) and log N(-value; location, scale)
        logpdf_1 = -half_log_two_pi_scale - ((value - location)**2) / two_scale
        logpdf_2 = -half_log_two_pi_scale - ((value + location)**2) / two_scale
        # The folded normal logpdf is the log-sum-exp of the two Gaussian terms.
        l -= util.logsumexp(torch.cat([logpdf_1, logpdf_2]))
    return l
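
As a sanity check on the two-term log-sum-exp above, the same quantity can be compared against `scipy.stats.foldnorm`. This reads the network's second output as the variance sigma^2, which is what the `2 * scales` and `0.5 * log(pi * two_scales)` terms imply. A minimal sketch, assuming SciPy is available; the parameter values `mu`, `var`, `x` are made up for illustration:

import math
import torch
from scipy.stats import foldnorm

# Hypothetical proposal parameters and an observed value (illustration only).
mu, var, x = 1.3, 0.49, 0.8
sigma = math.sqrt(var)

# Same two Gaussian terms as in loss(), combined with log-sum-exp.
two_var = 2.0 * var
half_log_two_pi_var = 0.5 * math.log(math.pi * two_var)
terms = torch.tensor([
    -half_log_two_pi_var - (x - mu) ** 2 / two_var,
    -half_log_two_pi_var - (x + mu) ** 2 / two_var,
])
print(torch.logsumexp(terms, dim=0).item())

# Reference: SciPy's folded normal with shape c = mu / sigma and scale sigma.
print(foldnorm.logpdf(x, c=mu / sigma, scale=sigma))

Both prints should agree up to floating-point error, which confirms that the pair of shifted Gaussian log-densities reproduces the folded normal log-density on the positive half-line.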