nuts_sampler.py source code

python

Project: discontinuous-hmc  Author: aki-nishimura
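```python
import math

import numpy as np

# random_momentum, compute_hamiltonian and integrator are defined elsewhere
# in the project; they are not part of this snippet.


def HMC(f, epsilon, n_step, theta0, logp0, grad0):
    """One HMC transition: n_step integrator steps plus a Metropolis accept/reject.

    Returns (theta, logp, grad, acceptprob, nfevals_total).
    """
    # Resample momentum and record the initial joint log density -H(theta0, p).
    p = random_momentum(len(theta0))
    joint0 = - compute_hamiltonian(logp0, p)

    # Simulate n_step integrator steps, counting evaluations of f.
    nfevals_total = 0
    theta, p, grad, logp, nfevals = integrator(f, epsilon, theta0, p, grad0)
    nfevals_total += nfevals
    for _ in range(1, n_step):
        theta, p, grad, logp, nfevals = integrator(f, epsilon, theta, p, grad)
        nfevals_total += nfevals

    # Metropolis acceptance probability min(1, exp(-H_new + H_old)).
    # A proposal with infinite log density (e.g. -inf outside the support) is rejected.
    joint = - compute_hamiltonian(logp, p)
    if math.isinf(logp):
        acceptprob = 0
    else:
        acceptprob = min(1, np.exp(joint - joint0))

    # On rejection, stay at the initial state.
    if acceptprob < np.random.rand():
        theta = theta0
        logp = logp0
        grad = grad0

    return theta, logp, grad, acceptprob, nfevals_total
```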
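`HMC` relies on three helpers that the snippet does not show. The following is a minimal sketch of what they might look like, assuming a unit mass matrix (momentum p ~ N(0, I)), a Hamiltonian H(θ, p) = -log π(θ) + pᵀp/2 (consistent with `joint0 = - compute_hamiltonian(logp0, p)`), a standard leapfrog integrator, and an `f` that returns `(logp, grad)` at a point. These are hypothetical stand-ins, not the project's code; the actual discontinuous-hmc integrator may well differ, since the project targets discontinuous densities.

```python
import numpy as np

# Hypothetical helpers (not the project's implementations).
# Assumption: f(theta) returns (log density, gradient of log density).


def random_momentum(d):
    """Draw momentum p ~ N(0, I_d), i.e. assume an identity mass matrix."""
    return np.random.randn(d)


def compute_hamiltonian(logp, p):
    """H(theta, p) = -log pi(theta) + p.p / 2 (potential plus kinetic energy)."""
    return -logp + 0.5 * np.dot(p, p)


def integrator(f, epsilon, theta, p, grad):
    """One leapfrog step; returns (theta, p, grad, logp, nfevals)."""
    p = p + 0.5 * epsilon * grad      # half step for momentum
    theta = theta + epsilon * p       # full step for position
    logp, grad = f(theta)             # one evaluation of f
    p = p + 0.5 * epsilon * grad      # second half step for momentum
    return theta, p, grad, logp, 1
```

With these definitions in the same module, and a standard normal target such as `f = lambda th: (-0.5 * th @ th, -th)`, one transition would be `HMC(f, 0.1, 20, theta0, *f(theta0))`. Leapfrog is a natural choice here because it is time-reversible and volume-preserving, which is what makes the accept/reject step in `HMC` valid; its energy error grows with `epsilon`, so smaller steps raise `acceptprob` at the cost of more evaluations of `f`.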