def loss(self, x, samples):
    """Return the summed negative log-likelihood of ``samples`` under the proposal.

    The proposal output for each batch row is read as a mixture of
    ``self.mixture_components`` normal components. Each component is
    renormalised to the interval ``[prior_min, prior_max]`` taken from the
    sample's distribution, which makes it a truncated normal.

    Args:
        x: Input forwarded unchanged to ``self.forward``.
        samples: Sequence of sample objects. Each must expose
            ``.value[0]`` (the observed value) and
            ``.distribution.prior_min`` / ``.distribution.prior_max``.

    Returns:
        ``sum_b -log(p_b + util.epsilon)``, where ``p_b`` is the mixture
        density of sample ``b``. This is a sum, not a mean, over the batch.
        If ``samples`` is empty, the Python int ``0`` is returned.
    """
    _, proposal_output = self.forward(x, samples)
    k = self.mixture_components
    # Output columns are laid out as [means | stds | coeffs], each block k wide.
    # NOTE(review): coeffs are presumably normalised mixture weights produced
    # by forward(); nothing here checks that they sum to 1.
    means, stds, coeffs = (proposal_output[:, i * k:(i + 1) * k] for i in range(3))

    def _normal_cdf(z):
        # Standard normal CDF via erf: Phi(z) = (1 + erf(z / sqrt(2))) / 2.
        return 0.5 * (1 + util.erf(z * util.one_over_sqrt_two))

    total = 0
    for row, sample in enumerate(samples):
        observed = sample.value[0]
        lower = sample.distribution.prior_min
        upper = sample.distribution.prior_max
        density = 0
        for comp in range(k):
            mu = means[row, comp]
            sigma = stds[row, comp]
            weight = coeffs[row, comp]
            z = (observed - mu) / sigma
            # Probability mass of N(mu, sigma) inside [lower, upper]; dividing
            # by it renormalises the component to the truncated interval.
            mass = _normal_cdf((upper - mu) / sigma) - _normal_cdf((lower - mu) / sigma)
            density += weight * util.one_over_sqrt_two_pi * torch.exp(-0.5 * z * z) / (sigma * mass)
        # epsilon keeps the log finite when the mixture density is ~0.
        total -= torch.log(density + util.epsilon)
    return total
评论列表
文章目录