def __call__(self, loc, val, y, train=True):
    """Compute the training loss for one minibatch and report diagnostics.

    Runs ``self.forward`` and combines the mean squared error of its
    prediction with a weighted sum of the KL-divergence / hyperprior terms
    that ``forward`` returns. All intermediate terms are passed to
    ``reporter.report`` for logging.

    Args:
        loc: First input, forwarded unchanged to ``self.forward``.
            NOTE(review): presumably feature indices/locations -- confirm.
        val: Second input, forwarded unchanged to ``self.forward``.
            NOTE(review): presumably feature values -- confirm.
        y: Targets, compared against the prediction with
            ``F.mean_squared_error``.
        train: Passed through to ``self.forward`` as ``train=``.

    Returns:
        The loss ``mse + kldt * (1.0 / self.total_nobs)``, where
        ``kldt = lambda0 * kld0 + lambda1 * kld1 + kldg + kldi + hypg + hypi``.
    """
    # forward() is expected to return exactly these seven terms, in order.
    pred, kld0, kld1, kldg, kldi, hypg, hypi = self.forward(
        loc, val, y, train=train)

    # Data-fit term.
    mse = F.mean_squared_error(pred, y)
    # Only used for reporting; it does not contribute to the loss.
    rmse = F.sqrt(mse)

    # Total regularization: kld0 and kld1 are weighted by lambda0/lambda1,
    # the remaining KLD and hyperprior terms enter unweighted.
    kldt = kld0 * self.lambda0 + kld1 * self.lambda1
    kldt += kldg + kldi + hypg + hypi

    # The regularizer is scaled by 1 / total_nobs before being added.
    # NOTE(review): presumably total_nobs is the training-set size, so the
    # whole-dataset regularizer is amortised per observation -- confirm.
    loss = mse + kldt * (1.0 / self.total_nobs)

    # Report every term, plus summed parameter values, for monitoring.
    logs = {'loss': loss, 'rmse': rmse, 'kld0': kld0, 'kld1': kld1,
            'kldg': kldg, 'kldi': kldi, 'hypg': hypg, 'hypi': hypi,
            'hypglv': F.sum(self.hyper_feat_lv_vec.b),
            'hypilv': F.sum(self.hyper_feat_delta_lv.b),
            'kldt': kldt, 'bias': F.sum(self.bias_mu.b)}
    reporter.report(logs, self)
    return loss
评论列表
文章目录