test_net.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目: skorch 作者: dnouri 项目源码 文件源码
def test_change_get_loss(self, net_cls, module_cls, data):
        """Check that overriding ``get_loss`` in a subclass changes the loss.

        ``MyNet`` replaces the loss with the sum of a mean-absolute-error
        term (``loss_a``) and a mean-squared-error term (``loss_b``), and
        records both terms per batch while training. After fitting for one
        epoch, every batch's recorded ``train_loss`` must equal
        ``loss_a + loss_b`` (within ``atol=1e-7``).

        ``net_cls`` is the net class being subclassed, ``module_cls`` the
        module passed to it, and ``data`` an ``(X, y)`` pair (presumably
        all pytest fixtures; confirm against the test module's conftest).
        """
        from skorch.utils import to_var

        class MyNet(net_cls):
            # pylint: disable=unused-argument
            def get_loss(self, y_pred, y_true, X=None, training=False):
                # Wrap the target for torch arithmetic; use_cuda=False
                # presumably keeps it on the CPU.
                y_true = to_var(y_true, use_cuda=False)
                # Only column 1 of the prediction is compared with the target
                # (presumably the positive-class output; confirm against
                # module_cls).
                loss_a = torch.abs(y_true.float() - y_pred[:, 1]).mean()
                loss_b = ((y_true.float() - y_pred[:, 1]) ** 2).mean()
                if training:
                    # Record the components so the assertion below can
                    # compare them against the recorded train_loss.
                    self.history.record_batch('loss_a', to_numpy(loss_a)[0])
                    self.history.record_batch('loss_b', to_numpy(loss_b)[0])
                return loss_a + loss_b

        X, y = data
        net = MyNet(module_cls, max_epochs=1)
        net.fit(X, y)

        # One (train_loss, loss_a, loss_b) triple per batch of the last epoch.
        all_losses = net.history[
            -1, 'batches', :, ('train_loss', 'loss_a', 'loss_b')]
        diffs = [total - a - b for total, a, b in all_losses]
        assert np.allclose(diffs, 0, atol=1e-7)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号