def ais(rbm, step=100, M=100, parallel=False, seed=None):
    """Estimate the log partition function of ``rbm`` from ``M`` ``mcmc`` runs.

    The estimate is ``logZ0 + logmeanexp(ratios)``, where ``logZ0`` is the
    sum of ``log(1 + exp(b))`` over all visible and hidden biases and
    ``ratios`` holds one value per ``mcmc`` call.  (The name suggests
    annealed importance sampling; the sampling itself lives in ``mcmc``,
    which is defined outside this block.)

    Args:
        rbm: Object exposing ``W``, ``v_bias`` and ``h_bias``, each of which
            supports ``.data.numpy()`` (presumably torch parameters on the
            CPU -- confirm against the caller).
        step: Forwarded to ``mcmc`` as ``step``.
        M: Number of ``mcmc`` calls whose results are averaged.
        parallel: If true, dispatch the ``M`` calls through
            ``Parallel``/``delayed`` using one job per CPU core; otherwise
            run them sequentially in this process.
        seed: Forwarded unchanged to every ``mcmc`` call.

    Returns:
        ``logZ0 + logmeanexp(ratios, axis=0)`` where ``ratios`` has shape
        ``(M, 1)``.

    NOTE(review): every run receives the *same* ``seed``.  If ``mcmc`` seeds
    its random generator from this argument, all ``M`` runs would be
    identical whenever ``seed`` is not ``None`` -- confirm against ``mcmc``
    and derive per-run seeds if so.
    """
    W = rbm.W.data.numpy().T
    v_bias = rbm.v_bias.data.numpy()
    h_bias = rbm.h_bias.data.numpy()

    # log(1 + exp(b)) written as logaddexp(0, b): same value, but it does not
    # overflow to inf when a bias is large and positive.
    logZ0 = np.logaddexp(0, v_bias).sum() + np.logaddexp(0, h_bias).sum()

    # Identical arguments for every run, in both execution modes.
    run_kwargs = dict(step=step, seed=seed, W=W, h_bias=h_bias, v_bias=v_bias)

    if parallel:
        num_cores = multiprocessing.cpu_count()
        ratios = Parallel(n_jobs=num_cores)(
            delayed(mcmc)(**run_kwargs) for _ in range(M)
        )
    else:
        ratios = [mcmc(**run_kwargs) for _ in range(M)]

    # Column vector so that logmeanexp reduces over the M runs (axis 0).
    ratios = np.array(ratios).reshape(len(ratios), 1)
    return logZ0 + logmeanexp(ratios, axis=0)
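# --- Residue from the web page this snippet was scraped from (not code) ---
# ais.py source file
# python
# Views: 21
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents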