rais.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:restricted-boltzmann-machine-deep-belief-network-deep-boltzmann-machine-in-pytorch 作者: wmingwei 项目源码 文件源码
def rais(self, data, step = 1000, M = 100, parallel = False, seed = None):
        num_data = data.shape[0]
        result = 0
        if not parallel:
            p = []
            for i in range(M):
                logw = self.mcmc_r(data, step, num_data)
                p.append(logw)

            p = np.array(p)
            logmeanp = logmeanexp(p, axis = 0)
        else:
            num_cores = multiprocessing.cpu_count()

            p = Parallel(n_jobs=num_cores)(delayed(self.mcmc_r)(v = data, step = step, num_data = num_data, seed = seed) for i in range(M))

            p = np.array(p)

            logmeanp = logmeanexp(p, axis = 0)

        result = logmeanp.mean()

        return result
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号