def test_predict(self):
    """Check the output shapes of ``RandomForestWithInstances.predict``.

    Trains the model on the first 10 of 20 seeded random points with 10
    features each, then predicts on the remaining 10 points. Both the
    predicted means and variances must come back with shape (10, 1).
    """
    n_features = 10
    n_train = 10
    n_test = 10
    rs = np.random.RandomState(1)  # fixed seed -> deterministic test data
    X = rs.rand(n_train + n_test, n_features)
    Y = rs.rand(n_train, 1)
    # NOTE(review): all-zero types presumably mark every feature as
    # continuous — confirm against RandomForestWithInstances.
    types = np.zeros((n_features,), dtype=np.uint)
    # One (0, 10) row per feature; with dtype=object numpy builds an
    # object array of shape (n_features, 2).
    bounds = np.array([(0, 10)] * n_features, dtype=object)
    model = RandomForestWithInstances(types, bounds=bounds)
    model.train(X[:n_train], Y[:n_train])
    m_hat, v_hat = model.predict(X[n_train:])
    self.assertEqual(m_hat.shape, (n_test, 1))
    self.assertEqual(v_hat.shape, (n_test, 1))
评论列表
文章目录