test_tractable.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:fang 作者: rgrosse 项目源码 文件源码
def check_partition_function():
    with misc.gnumpy_conversion_check('allow'):
        rbm = random_rbm()
        total = -np.infty

        for vis_ in itertools.product(*[[0, 1]] * NVIS):
            vis = gnp.garray(vis_)
            for hid_ in itertools.product(*[[0, 1]] * NHID):
                hid = gnp.garray(hid_)
                total = np.logaddexp(total, rbm.energy(vis[nax, :], hid[nax, :])[0])

        assert np.allclose(tractable.exact_partition_function(rbm, batch_units=BATCH_UNITS), total)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号