ais.py source code

python

Project: restricted-boltzmann-machine-deep-belief-network-deep-boltzmann-machine-in-pytorch
Author: wmingwei
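For reference, the quantity `ais` computes can be written out as follows. This assumes the usual AIS setup for a binary RBM, with b = `v_bias`, c = `h_bias`, W the transposed (n_visible × n_hidden) weight matrix, K = `step`, M = `M`, and intermediate models p_β(v, h) ∝ exp(β vᵀWh + bᵀv + cᵀh) in which only the weights are annealed. The `logZ0` line implies exactly such a base model (W = 0, the RBM's own biases). What `mcmc` does internally is not shown in this file, so the weight formula is the standard AIS weight, not a transcription of that helper.

$$
\begin{aligned}
p_\beta^*(v) &= e^{b^\top v}\prod_{j}\bigl(1 + e^{c_j + \beta (W^\top v)_j}\bigr),
  \qquad 0 = \beta_0 < \beta_1 < \dots < \beta_K = 1,\\
\log Z_0 &= \sum_i \log\bigl(1 + e^{b_i}\bigr) + \sum_j \log\bigl(1 + e^{c_j}\bigr),\\
\log w^{(m)} &= \sum_{k=1}^{K}\Bigl[\log p_{\beta_k}^*\bigl(v_k^{(m)}\bigr) - \log p_{\beta_{k-1}}^*\bigl(v_k^{(m)}\bigr)\Bigr],\\
\log Z &\approx \log Z_0 + \log\frac{1}{M}\sum_{m=1}^{M} w^{(m)}.
\end{aligned}
$$

Here v_1 is drawn exactly from the base model and each v_{k+1} is obtained from v_k by a Gibbs sweep that leaves p_{β_k} invariant. The last line is what `logZ0 + logmeanexp(..., axis = 0)` evaluates, assuming `mcmc` returns log w for one chain.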
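import multiprocessing

import numpy as np
from joblib import Parallel, delayed

# mcmc() and logmeanexp() are helpers defined elsewhere in this project (not shown).
# mcmc() runs one AIS chain and is assumed to return that chain's log importance weight.


def ais(rbm, step=100, M=100, parallel=False, seed=None):
    """Estimate log Z of an RBM with annealed importance sampling (AIS).

    step: number of intermediate distributions per chain
    M: number of independent AIS chains
    """
    # Transpose to (n_visible, n_hidden), assuming rbm.W is stored as (n_hidden, n_visible).
    W = rbm.W.data.numpy().T
    v_bias = rbm.v_bias.data.numpy()
    h_bias = rbm.h_bias.data.numpy()

    # log Z of the base model (W = 0, same biases): all units are independent, so
    # log Z0 = sum softplus(v_bias) + sum softplus(h_bias); logaddexp(0, x) avoids overflow.
    logZ0 = np.logaddexp(0, v_bias).sum() + np.logaddexp(0, h_bias).sum()

    # Give each chain its own seed: passing the same seed to every call would make
    # all M chains identical (assuming mcmc seeds its random generator from `seed`).
    seeds = [None if seed is None else seed + i for i in range(M)]

    if parallel:
        num_cores = multiprocessing.cpu_count()
        log_w = Parallel(n_jobs=num_cores)(
            delayed(mcmc)(step=step, seed=s, W=W, h_bias=h_bias, v_bias=v_bias)
            for s in seeds)
    else:
        log_w = [mcmc(step, seed=s, W=W, h_bias=h_bias, v_bias=v_bias) for s in seeds]

    # Z / Z0 is estimated by the mean importance weight, averaged in log space.
    log_w = np.array(log_w).reshape(len(log_w), 1)
    logZ = logZ0 + logmeanexp(log_w, axis=0)

    return logZ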
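The helpers `mcmc` and `logmeanexp` are not included in this file, so the following self-contained NumPy sketch reconstructs the whole procedure under stated assumptions, mainly to check the estimator on a model small enough to enumerate. All names here (`sigmoid`, `logmeanexp`, `log_p_star`, `ais_log_weights`, `ais_log_z`, `exact_log_z`) are hypothetical stand-ins, and the linear β schedule with one Gibbs sweep per step is an assumption; the project's own `mcmc` may anneal differently.

```python
import itertools

import numpy as np


def sigmoid(x):
    # Overflow-free logistic function: sigmoid(x) = (1 + tanh(x / 2)) / 2.
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def logmeanexp(x, axis=None):
    # Hypothetical helper: log(mean(exp(x))) computed stably by shifting by the max.
    x = np.asarray(x, dtype=float)
    m = np.max(x, axis=axis, keepdims=True)
    out = m + np.log(np.mean(np.exp(x - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def log_p_star(v, W, v_bias, h_bias, beta):
    # Unnormalized log marginal of v under p_beta(v, h) ∝ exp(beta v^T W h + b^T v + c^T h),
    # with h summed out analytically; v has shape (..., n_visible).
    return v @ v_bias + np.logaddexp(0.0, h_bias + beta * (v @ W)).sum(axis=-1)


def ais_log_weights(W, v_bias, h_bias, step=100, M=100, seed=None):
    # Run M AIS chains at once (vectorized) and return their log importance weights.
    # Assumption: linear beta schedule, one Gibbs sweep per intermediate model.
    rng = np.random.default_rng(seed)
    betas = np.linspace(0.0, 1.0, step + 1)
    n_v, n_h = W.shape
    # Exact samples from the base model (beta = 0): visible units are independent.
    v = (rng.random((M, n_v)) < sigmoid(v_bias)).astype(float)
    log_w = np.zeros(M)
    for k in range(1, step + 1):
        # Accumulate log p*_{beta_k}(v_k) - log p*_{beta_{k-1}}(v_k).
        log_w += (log_p_star(v, W, v_bias, h_bias, betas[k])
                  - log_p_star(v, W, v_bias, h_bias, betas[k - 1]))
        if k < step:
            # One Gibbs sweep v -> h -> v that leaves p_{beta_k} invariant.
            h = (rng.random((M, n_h)) < sigmoid(h_bias + betas[k] * (v @ W))).astype(float)
            v = (rng.random((M, n_v)) < sigmoid(v_bias + betas[k] * (h @ W.T))).astype(float)
    return log_w


def ais_log_z(W, v_bias, h_bias, step=100, M=100, seed=None):
    # log Z ≈ log Z0 + log mean(w), mirroring the structure of ais() above.
    log_z0 = np.logaddexp(0.0, v_bias).sum() + np.logaddexp(0.0, h_bias).sum()
    return float(log_z0 + logmeanexp(ais_log_weights(W, v_bias, h_bias, step, M, seed)))


def exact_log_z(W, v_bias, h_bias):
    # Brute force: enumerate all 2^n_visible visible states (tiny models only).
    vs = np.array(list(itertools.product([0.0, 1.0], repeat=W.shape[0])))
    return float(np.logaddexp.reduce(log_p_star(vs, W, v_bias, h_bias, 1.0)))


rng = np.random.default_rng(0)
W = 0.5 * rng.standard_normal((6, 4))  # (n_visible, n_hidden), like the transposed rbm.W
v_bias = 0.1 * rng.standard_normal(6)
h_bias = 0.1 * rng.standard_normal(4)
print("exact:", exact_log_z(W, v_bias, h_bias))
print("AIS:  ", ais_log_z(W, v_bias, h_bias, step=1000, M=200, seed=1))
```

Because the base model shares the RBM's biases, setting W to zero makes every AIS weight exactly 1, which gives a quick sanity check. On the small example the two printed values should be close, but AIS is stochastic, and averaging weights in log space biases the estimate slightly downward when M is small.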