vfm.py 文件源码

python
阅读 69 收藏 0 点赞 0 评论 0

项目:vfm 作者: cemoody 项目源码 文件源码
def __call__(self, loc, val, y, train=True):
        bs = val.data.shape[0]
        pred, kld0, kld1, kld2 = self.forward(loc, val, y, train=train)

        # Compute MSE loss
        mse = F.mean_squared_error(pred, y)
        rmse = F.sqrt(mse)  # Only used for reporting

        # Now compute the total KLD loss
        kldt = kld0 * self.lambda0 + kld1 * self.lambda1 + kld2 * self.lambda2

        # Total loss is MSE plus regularization losses
        loss = mse + kldt * (1.0 / self.total_nobs)

        # Log the errors
        logs = {'loss': loss, 'rmse': rmse, 'kld0': kld0, 'kld1': kld1,
                'kld2': kld2, 'kldt': kldt, 'bias': F.sum(self.bias_mu.b)}
        reporter.report(logs, self)
        return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号