def __call__(self, loc, val, y, train=True):
    """Compute the total loss for one batch and report diagnostics.

    The inputs are passed to ``self.forward``, which returns a prediction
    and three KL-divergence terms. The loss is the mean squared error of
    the prediction against ``y`` plus the lambda-weighted sum of the KL
    terms divided by ``self.total_nobs``.

    Args:
        loc: Forwarded unchanged to ``self.forward``.
        val: Forwarded unchanged to ``self.forward``.
        y: Target values; forwarded to ``self.forward`` and compared
            against the prediction.
        train: Forwarded to ``self.forward`` as the ``train`` keyword.

    Returns:
        The total loss (MSE plus scaled KL regularization).

    NOTE(review): ``F`` and ``reporter`` are presumably
    ``chainer.functions`` and ``chainer.reporter`` -- confirm against the
    file's imports.
    """
    pred, kld0, kld1, kld2 = self.forward(loc, val, y, train=train)

    # Data-fit term; the RMSE is derived only for reporting and does not
    # enter the loss.
    mse = F.mean_squared_error(pred, y)
    rmse = F.sqrt(mse)

    # Each KL term is weighted by its own coefficient before summing.
    kldt = kld0 * self.lambda0 + kld1 * self.lambda1 + kld2 * self.lambda2

    # The KL sum is scaled by 1 / total_nobs before being added to the MSE.
    # NOTE(review): presumably total_nobs is the dataset size, so the
    # regularizer is amortized per observation -- confirm.
    loss = mse + kldt * (1.0 / self.total_nobs)

    # Report the loss components, plus the summed bias parameter.
    logs = {'loss': loss, 'rmse': rmse, 'kld0': kld0, 'kld1': kld1,
            'kld2': kld2, 'kldt': kldt, 'bias': F.sum(self.bias_mu.b)}
    reporter.report(logs, self)
    return loss
# Comment list
# Table of contents