def loss(self, x, samples):
    """Return the summed Gaussian negative log-likelihood of ``samples``.

    Calls ``self.forward(x, samples)`` and reads column 0 of the second
    returned value as per-sample means and column 1 as per-sample standard
    deviations. For each sample the term

        0.5 * log(2 * pi * sigma**2) + (v - mu)**2 / (2 * sigma**2)

    is accumulated, with ``v = samples[b].value[0]``. ``util.epsilon`` is
    added inside the denominator and inside the log for numerical stability.
    The terms are summed over the batch (not averaged).

    Args:
        x: Input passed unchanged to ``self.forward``.
        samples: Sequence of objects exposing a ``value`` attribute; only
            ``value[0]`` is used. Row ``b`` of the proposal output is paired
            with ``samples[b]``.

    Returns:
        The summed loss, or ``0`` when ``samples`` is empty.
    """
    _, proposal_output = self.forward(x, samples)
    means = proposal_output[:, 0]
    stds = proposal_output[:, 1]

    # 2 * sigma^2; epsilon guards against division by zero when sigma == 0.
    two_std_squares = 2 * stds * stds + util.epsilon
    # 0.5 * log(2 * pi * sigma^2); the second epsilon keeps log() finite.
    half_log_two_pi_std_squares = 0.5 * torch.log(
        math.pi * two_std_squares + util.epsilon
    )

    total = 0
    # Index the tensors by position rather than zip()-ing them so that a
    # proposal output with fewer rows than ``samples`` still raises instead
    # of silently dropping samples.
    for b, sample in enumerate(samples):
        value = sample.value[0]
        squared_error = (value - means[b]) ** 2
        total += half_log_two_pi_std_squares[b] + squared_error / two_std_squares[b]
    return total
评论列表
文章目录