hmc.py source code

python

Project: a-nice-mc  Author: ermongroup
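import tensorflow as tf  # uses graph-mode ops (tf.gradients, tf.while_loop)


def simulate_dynamics(initial_pos, initial_vel, stepsize, n_steps, energy_fn):
    """Leapfrog integration of Hamiltonian dynamics; returns the final (pos, vel)."""

    def leapfrog(pos, vel, step, i):
        # Full velocity step using the energy gradient at the current position.
        de_dp_ = tf.gradients(tf.reduce_sum(energy_fn(pos)), pos)[0]
        new_vel_ = vel - step * de_dp_
        # Full position step using the updated velocity.
        new_pos_ = pos + step * new_vel_
        return [new_pos_, new_vel_, step, tf.add(i, 1)]

    def condition(pos, vel, step, i):
        # Keep looping while the step counter is below n_steps.
        return tf.less(i, n_steps)

    # Initial half step for velocity, then a full step for position.
    de_dp = tf.gradients(tf.reduce_sum(energy_fn(initial_pos)), initial_pos)[0]
    vel_half_step = initial_vel - 0.5 * stepsize * de_dp
    pos_full_step = initial_pos + stepsize * vel_half_step

    # The loop body runs n_steps times, in addition to the position step above.
    i = tf.constant(0)
    final_pos, new_vel, _, _ = tf.while_loop(condition, leapfrog, [pos_full_step, vel_half_step, stepsize, i])
    # Final half step for velocity so that it lines up in time with final_pos.
    de_dp = tf.gradients(tf.reduce_sum(energy_fn(final_pos)), final_pos)[0]
    final_vel = new_vel - 0.5 * stepsize * de_dp
    return final_pos, final_vel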
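
The step ordering in simulate_dynamics can be checked without TensorFlow. The following is a minimal plain-Python sketch, not part of a-nice-mc: the helper names leapfrog_1d and hamiltonian are hypothetical, and the quadratic energy 0.5 * x**2 with unit mass is an assumption chosen only because its gradient is trivial. It mirrors the same sequence (half velocity step, n_steps full steps, final half velocity step) and looks at two properties a leapfrog integrator should have: small energy drift and time reversibility.

```python
def leapfrog_1d(pos, vel, stepsize, n_steps, grad_fn):
    """Scalar mirror of the update order in simulate_dynamics (illustrative only)."""
    # Initial half step for velocity, then a full step for position.
    vel = vel - 0.5 * stepsize * grad_fn(pos)
    pos = pos + stepsize * vel
    # n_steps iterations, each a full velocity step followed by a full position step.
    for _ in range(n_steps):
        vel = vel - stepsize * grad_fn(pos)
        pos = pos + stepsize * vel
    # Final half step for velocity so pos and vel refer to the same time.
    vel = vel - 0.5 * stepsize * grad_fn(pos)
    return pos, vel


def hamiltonian(pos, vel):
    # Assumed potential 0.5 * pos**2 plus kinetic energy 0.5 * vel**2 (unit mass).
    return 0.5 * pos ** 2 + 0.5 * vel ** 2


def grad(x):
    # Gradient of the assumed potential 0.5 * x**2.
    return x


pos0, vel0 = 1.0, 0.0
pos1, vel1 = leapfrog_1d(pos0, vel0, stepsize=0.1, n_steps=9, grad_fn=grad)

# Energy is not conserved exactly, but the drift should stay small.
energy_drift = abs(hamiltonian(pos1, vel1) - hamiltonian(pos0, vel0))

# Reversibility: negate the velocity, integrate again, and land back at the start.
pos_back, vel_back = leapfrog_1d(pos1, -vel1, stepsize=0.1, n_steps=9, grad_fn=grad)
```

With n_steps=9 the position is advanced ten times in total, because the first position step happens before the loop, which matches the structure of the TensorFlow version above.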