rais_dbn.py source code

python

Project: restricted-boltzmann-machine-deep-belief-network-deep-boltzmann-machine-in-pytorch  Author: wmingwei
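import multiprocessing

import numpy as np
from joblib import Parallel, delayed

# important_sampling() and logmeanexp() are helpers defined elsewhere in
# rais_dbn.py; they are not part of this snippet.


def ulogprob(v_input, dbn, M=1000, parallel=False):
    """Estimate the unnormalized log-probability of each row of `v_input`
    under `dbn` by averaging M importance-sampling log-weights."""
    # logw[i, j] holds the i-th log-weight sample for the j-th input.
    logw = np.zeros([M, len(v_input)])
    if not parallel:
        for i in range(M):
            # Inlined form of important_sampling(), kept for reference:
            # samples = v_input
            # for l in range(dbn.n_layers-1):
            #     logw[i,:] += -dbn.rbm_layers[l].free_energy(samples,dbn.rbm_layers[l].W)[0]
            #     samples = dbn.rbm_layers[l].sample_h_given_v(samples,dbn.rbm_layers[l].W,dbn.rbm_layers[l].h_bias)[0]
            #     logw[i,:] -= -dbn.rbm_layers[l].free_energy_hidden(samples,dbn.rbm_layers[l].W)[0]
            # logw[i,:] += -dbn.rbm_layers[-1].free_energy(samples,dbn.rbm_layers[-1].W)[0]
            logw[i, :] += important_sampling(v_input, dbn)
    else:
        num_cores = multiprocessing.cpu_count()
        # One joblib task per sample; results is a list of M arrays.
        results = Parallel(n_jobs=num_cores)(
            delayed(important_sampling)(v_input=v_input, dbn=dbn) for _ in range(M))
        logw += np.asarray(results)  # shape (M, len(v_input))

    # log((1/M) * sum_i exp(logw[i, :])), reduced over the sample axis.
    return logmeanexp(logw, 0)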
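The last line of ulogprob averages the M importance weights in log space: if each log-weight is log w = log p*(x) - log q(x) for a sample x drawn from a proposal q, the mean of the w values estimates the normalizer of p*, and taking log-mean-exp avoids overflow when the log-weights are large. The sketch below illustrates that reduction under stated assumptions. `logmeanexp` is a hypothetical reimplementation (the project's own helper is not shown, and its `logmeanexp(x, 0)` call is assumed to mean log(mean(exp(x))) along axis 0), and `toy_log_weights` is an invented 1-D Gaussian stand-in for `important_sampling` whose true normalizers are known, so the estimate can be checked. It does not reproduce the DBN-specific weights.

```python
import numpy as np


def logmeanexp(x, axis=0):
    """Numerically stable log(mean(exp(x))) along `axis`.

    Hypothetical stand-in for the project's `logmeanexp` helper.
    """
    x = np.asarray(x, dtype=float)
    m = np.max(x, axis=axis, keepdims=True)
    # Shift by the max so np.exp never overflows, then add it back.
    out = m + np.log(np.mean(np.exp(x - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def toy_log_weights(sigmas, M, rng):
    """Hypothetical replacement for `important_sampling`: an (M, N) array of
    log importance weights, one column per "input".

    Target j is the unnormalized Gaussian p*_j(x) = exp(-x^2 / (2 sigma_j^2)),
    with true normalizer Z_j = sigma_j * sqrt(2 pi). All columns share the
    proposal q = N(0, 2^2).
    """
    q_std = 2.0
    x = rng.normal(0.0, q_std, size=(M, 1))                 # M proposal samples
    log_p_star = -x**2 / (2.0 * np.asarray(sigmas) ** 2)    # shape (M, N)
    log_q = -x**2 / (2.0 * q_std**2) - np.log(q_std * np.sqrt(2.0 * np.pi))
    return log_p_star - log_q                               # log w, shape (M, N)


rng = np.random.default_rng(0)
sigmas = np.array([0.5, 1.0])
logw = toy_log_weights(sigmas, M=100_000, rng=rng)
log_z_hat = logmeanexp(logw, 0)                 # same reduction as ulogprob
log_z_true = np.log(sigmas * np.sqrt(2.0 * np.pi))
print(log_z_hat, log_z_true)
```

Subtracting the per-column maximum before exponentiating is what makes the reduction safe: a plain np.log(np.mean(np.exp(logw), 0)) would overflow to inf once log-weights exceed roughly 709 in float64.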