def test_fit_gibbs():
    """Fit a small BernoulliRBM and check Gibbs sampling recreates the input.

    The model is trained on the two samples ``[[0], [1]]``; one call to
    ``gibbs`` on that same input is then expected to give the input back.
    The learned ``components_`` are also compared against fixed reference
    values to four decimal places.

    Returns the fitted model (presumably so sibling tests can reuse it --
    confirm against callers before removing the return).
    """
    seeded_rng = np.random.RandomState(42)
    samples = np.array([[0.], [1.]])

    # n_iter=42: the original author notes that this many iterations are
    # needed for the checks below to hold.
    model = BernoulliRBM(
        n_components=2,
        batch_size=2,
        n_iter=42,
        random_state=seeded_rng,
    )
    model.fit(samples)

    # Reference weights for this seed / configuration.
    expected_components = np.array([[0.02649814], [0.02009084]])
    assert_almost_equal(model.components_, expected_components, decimal=4)

    # Gibbs sampling from the training input should return that input.
    assert_almost_equal(model.gibbs(samples), samples)
    return model
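# [web-page navigation residue, not part of the test] Comment list
# [web-page navigation residue, not part of the test] Article table of contents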