def check_call(self):
    """Check that calling the link forwards inputs to its predictor and records results.

    The link's ``predictor`` is replaced by a ``MagicMock`` that returns a
    fixed random ``(5, 7)`` float32 ``Variable``. The link is then called with
    ``self.x_num`` input variables followed by the target ``t``. The check
    asserts that:

    * the predictor was called with exactly the input variables (not ``t``);
    * ``self.link.y`` exists and is not ``None``;
    * ``self.link.loss`` exists and matches the loss returned by the call;
    * ``self.link.accuracy`` exists, and is ``None`` exactly when
      ``self.compute_accuracy`` is false.

    Reads ``self.link``, ``self.x``, ``self.t``, ``self.x_num`` and
    ``self.compute_accuracy``. These are presumably set by the test's
    ``setUp``/parameterization, which is not visible here.

    Raises:
        ValueError: if ``self.x_num`` is not 1 or 2.
    """
    # Use the link's own array module so the same check works for whichever
    # backend the link lives on (numpy on CPU; presumably cupy on GPU).
    xp = self.link.xp
    y = chainer.Variable(xp.random.uniform(
        -1, 1, (5, 7)).astype(numpy.float32))
    # Stub the predictor so only the wrapper's call/bookkeeping is exercised.
    self.link.predictor = mock.MagicMock(return_value=y)
    x = chainer.Variable(xp.asarray(self.x))
    t = chainer.Variable(xp.asarray(self.t))
    if self.x_num == 1:
        loss = self.link(x, t)
        self.link.predictor.assert_called_with(x)
    elif self.x_num == 2:
        # A second, distinct Variable holding a copy of the same input data.
        x_ = chainer.Variable(xp.asarray(self.x.copy()))
        loss = self.link(x, x_, t)
        self.link.predictor.assert_called_with(x, x_)
    else:
        # Without this branch `loss` would stay unbound and the test would
        # fail later with a confusing, unrelated-looking error.
        raise ValueError(
            'unsupported x_num: {!r} (expected 1 or 2)'.format(self.x_num))
    self.assertTrue(hasattr(self.link, 'y'))
    self.assertIsNotNone(self.link.y)
    self.assertTrue(hasattr(self.link, 'loss'))
    # The loss cached on the link must equal the loss returned by the call.
    xp.testing.assert_allclose(self.link.loss.data, loss.data)
    self.assertTrue(hasattr(self.link, 'accuracy'))
    if self.compute_accuracy:
        self.assertIsNotNone(self.link.accuracy)
    else:
        self.assertIsNone(self.link.accuracy)
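# --- Metadata left over from the web page this snippet was copied from ---
# Source file: test_sklearn_wrapper_classifier.py
# Language: python
# Views: 23
# Favorites: 0
# Likes: 0
# Comments: 0
# (page navigation labels: "comment list", "table of contents")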