auto_vfm.py source code

python

Project: vfm  Author: cemoody
import chainer.functions as F
from chainer import reporter


def __call__(self, loc, val, y, train=True):
    # Forward pass: the prediction plus the KL-divergence (kld*) and
    # hyperparameter (hyp*) regularization terms
    ret = self.forward(loc, val, y, train=train)
    pred, kld0, kld1, kldg, kldi, hypg, hypi = ret

    # Minibatch mean squared error; RMSE is only used for reporting
    mse = F.mean_squared_error(pred, y)
    rmse = F.sqrt(mse)

    # Total regularization: lambda-weighted kld0 and kld1 plus the
    # unweighted kldg, kldi, hypg and hypi terms
    kldt = kld0 * self.lambda0 + kld1 * self.lambda1
    kldt += kldg + kldi + hypg + hypi

    # Total loss: per-observation MSE plus the dataset-level
    # regularization spread over all observations
    loss = mse + kldt * (1.0 / self.total_nobs)

    # Report the loss components and a few parameter summaries
    # to Chainer's reporter
    logs = {'loss': loss, 'rmse': rmse, 'kld0': kld0, 'kld1': kld1,
            'kldg': kldg, 'kldi': kldi, 'hypg': hypg, 'hypi': hypi,
            'hypglv': F.sum(self.hyper_feat_lv_vec.b),
            'hypilv': F.sum(self.hyper_feat_delta_lv.b),
            'kldt': kldt, 'bias': F.sum(self.bias_mu.b)}
    reporter.report(logs, self)
    return loss
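
The loss assembled in __call__ is loss = MSE + (lambda0 * kld0 + lambda1 * kld1 + kldg + kldi + hypg + hypi) / total_nobs. Dividing the dataset-level KL and hyperparameter terms by total_nobs appears to follow the usual minibatch variational-inference scaling: the MSE is a per-observation average, so the regularizer is spread over every observation to put both terms on the same per-observation scale. The sketch below mirrors only this arithmetic in plain Python so it runs without Chainer. The function name vfm_style_loss and its argument list are hypothetical, and the KL and hyperparameter terms are passed in as plain numbers rather than computed by a model's forward pass.

```python
import math
from typing import Dict, Sequence


def vfm_style_loss(pred: Sequence[float], y: Sequence[float],
                   kld0: float, kld1: float, kldg: float, kldi: float,
                   hypg: float, hypi: float,
                   lambda0: float, lambda1: float,
                   total_nobs: int) -> Dict[str, float]:
    """Hypothetical plain-Python mirror of the loss arithmetic in __call__."""
    if not pred or len(pred) != len(y):
        raise ValueError("pred and y must be non-empty and the same length")
    if total_nobs <= 0:
        raise ValueError("total_nobs must be positive")

    # Mean of squared differences over all elements of the minibatch
    mse = sum((p - t) ** 2 for p, t in zip(pred, y)) / len(pred)
    rmse = math.sqrt(mse)  # reported only, not part of the loss

    # Lambda-weighted kld0/kld1 plus the unweighted remaining terms
    kldt = kld0 * lambda0 + kld1 * lambda1 + kldg + kldi + hypg + hypi

    # Spread the dataset-level regularizer over every observation
    loss = mse + kldt / total_nobs
    return {"loss": loss, "rmse": rmse, "kldt": kldt}
```

For example, vfm_style_loss([1.0, 2.0], [1.5, 1.5], kld0=2.0, kld1=4.0, kldg=1.0, kldi=1.0, hypg=0.5, hypi=0.5, lambda0=0.5, lambda1=0.25, total_nobs=100) gives an MSE of 0.25 (RMSE 0.5), kldt = 5.0, and loss = 0.25 + 5.0 / 100 = 0.3. As total_nobs grows, the regularizer's share of each minibatch loss shrinks, which is the intended effect of the 1 / total_nobs factor.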