def check_partition_function():
    """Check ``tractable.exact_partition_function`` against brute-force enumeration.

    Builds a random RBM with ``random_rbm()``, enumerates every binary visible
    configuration (``NVIS`` units) and every binary hidden configuration
    (``NHID`` units), and accumulates ``log(sum(exp(rbm.energy(v, h))))`` over
    all pairs with a running ``np.logaddexp``. The result must agree (per
    ``np.allclose``) with the value returned by
    ``tractable.exact_partition_function(rbm, batch_units=BATCH_UNITS)``.

    NOTE(review): the accumulation exponentiates ``rbm.energy(...)`` directly,
    so this check presumes ``energy`` returns the unnormalized log-probability
    (i.e. negative energy) and that ``exact_partition_function`` returns the
    *log* partition function -- TODO confirm against the RBM/tractable code.

    Raises:
        AssertionError: if the brute-force and exact values disagree.
    """
    with misc.gnumpy_conversion_check('allow'):
        rbm = random_rbm()

        # The hidden configurations do not depend on the visible state, so
        # build their garrays once rather than once per visible vector.
        hid_configs = [gnp.garray(hid_)
                       for hid_ in itertools.product([0, 1], repeat=NHID)]

        # Running log-sum-exp, starting from log(0) = -inf. Use ``np.inf``:
        # the ``np.infty`` alias was removed in NumPy 2.0.
        total = -np.inf
        for vis_ in itertools.product([0, 1], repeat=NVIS):
            vis = gnp.garray(vis_)
            for hid in hid_configs:
                # ``nax`` is presumably ``np.newaxis``: each vector becomes a
                # batch of one, and ``[0]`` takes that single energy value.
                total = np.logaddexp(total, rbm.energy(vis[nax, :], hid[nax, :])[0])

        assert np.allclose(tractable.exact_partition_function(rbm, batch_units=BATCH_UNITS), total)
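# 评论列表 ("Comment list") / 文章目录 ("Table of contents"):
# web-page navigation labels left over from scraping; not part of the code.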