def test_transform():
    """Check that ``transform`` agrees with ``_mean_hiddens`` on the same data.

    A ``BernoulliRBM`` is fitted on the first 100 rows of ``Xdigits`` with a
    fixed ``random_state``; the output of the public ``transform`` method
    must then be element-wise equal to the output of the private
    ``_mean_hiddens`` method for that same input.
    """
    samples = Xdigits[:100]
    model = BernoulliRBM(
        n_components=16,
        batch_size=5,
        n_iter=5,
        random_state=42,
    )
    model.fit(samples)

    # Same call order as before: public API first, then the private helper.
    via_transform = model.transform(samples)
    via_mean_hiddens = model._mean_hiddens(samples)

    assert_array_equal(via_transform, via_mean_hiddens)
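# Web-page navigation residue (not part of the test code):
#   评论列表 -- "Comment list"
#   文章目录 -- "Article table of contents"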