test_sklearn_wrapper_classifier.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:chainer_sklearn 作者: corochann 项目源码 文件源码
def check_call(self):
        """Check that calling the link forwards inputs to its predictor and
        records ``y``, ``loss`` and ``accuracy`` on itself.

        The predictor is replaced with a mock that returns a fixed random
        Variable. The link is then called with ``self.x_num`` input Variable(s)
        followed by the label Variable. The test then asserts that:

        * the predictor received exactly the input Variable(s);
        * ``link.y`` exists and is not None;
        * ``link.loss`` exists and matches the returned loss;
        * ``link.accuracy`` exists, and is set only when
          ``self.compute_accuracy`` is true.

        Only ``x_num`` values of 1 and 2 are supported. Any other value fails
        the test explicitly.
        """
        # Array module of the link (presumably numpy or cupy depending on
        # the device the link lives on — not visible here).
        xp = self.link.xp

        # Stub out the predictor so only the wrapper logic is exercised.
        # Every call returns this fixed float32 Variable of shape (5, 7).
        y = chainer.Variable(xp.random.uniform(
            -1, 1, (5, 7)).astype(numpy.float32))
        self.link.predictor = mock.MagicMock(return_value=y)

        # NOTE(review): self.x / self.t are presumably numpy arrays prepared
        # by the test fixture (setUp) — confirm against the enclosing class.
        x = chainer.Variable(xp.asarray(self.x))
        t = chainer.Variable(xp.asarray(self.t))
        if self.x_num == 1:
            inputs = [x]
        elif self.x_num == 2:
            # A second, independent Variable holding a copy of the same data.
            x_ = chainer.Variable(xp.asarray(self.x.copy()))
            inputs = [x, x_]
        else:
            # Without this branch, `loss` would be unbound below and the test
            # would die with an UnboundLocalError instead of a clear failure.
            self.fail('unsupported x_num: {!r}'.format(self.x_num))

        # The label is passed last; only the inputs must reach the predictor.
        # (Spelled as *(list) rather than `*inputs, t` to stay Python 2
        # compatible.)
        loss = self.link(*(inputs + [t]))
        self.link.predictor.assert_called_with(*inputs)

        self.assertTrue(hasattr(self.link, 'y'))
        self.assertIsNotNone(self.link.y)

        # The loss stored on the link must match the one returned by the call.
        self.assertTrue(hasattr(self.link, 'loss'))
        xp.testing.assert_allclose(self.link.loss.data, loss.data)

        # The attribute always exists; it is populated only when accuracy
        # computation is enabled.
        self.assertTrue(hasattr(self.link, 'accuracy'))
        if self.compute_accuracy:
            self.assertIsNotNone(self.link.accuracy)
        else:
            self.assertIsNone(self.link.accuracy)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号