def simulate_dynamics(initial_pos, initial_vel, stepsize, n_steps, energy_fn):
    """Integrate Hamiltonian dynamics with the leapfrog scheme.

    Builds TensorFlow ops that advance ``(initial_pos, initial_vel)`` by
    ``n_steps`` leapfrog steps of size ``stepsize``: an opening half-step
    velocity update, alternating full position / velocity updates, and a
    closing half-step velocity update.

    Args:
        initial_pos: Tensor with the starting position(s).
        initial_vel: Tensor with the starting velocity(ies); it is combined
            element-wise with the gradient of the energy w.r.t. position.
        stepsize: Leapfrog step size (number or tensor).
        n_steps: Number of leapfrog steps (Python int or integer tensor).
            The first position update is taken before the loop, so values
            below 1 still perform exactly one step.
        energy_fn: Callable mapping a position tensor to an energy tensor.
            Its output is summed before differentiation.

    Returns:
        A ``(final_pos, final_vel)`` tuple of tensors.
    """

    def energy_grad(pos):
        # Summing reduces the energy to a scalar so a single tf.gradients
        # call yields dE/dpos.
        # NOTE(review): this presumes energy_fn treats batch entries
        # independently, so the sum does not mix their gradients - confirm.
        return tf.gradients(tf.reduce_sum(energy_fn(pos)), pos)[0]

    def leapfrog(pos, vel, step, i):
        # Full-step velocity update followed by a full-step position update.
        new_vel = vel - step * energy_grad(pos)
        new_pos = pos + step * new_vel
        return [new_pos, new_vel, step, tf.add(i, 1)]

    def condition(pos, vel, step, i):
        # One position update already happened before the loop, so only
        # n_steps - 1 further iterations are needed for n_steps in total.
        return tf.less(i, n_steps - 1)

    # Opening half-step for the velocity, then the first full position step.
    vel_half_step = initial_vel - 0.5 * stepsize * energy_grad(initial_pos)
    pos_full_step = initial_pos + stepsize * vel_half_step

    i = tf.constant(0)
    final_pos, new_vel, _, _ = tf.while_loop(
        condition,
        leapfrog,
        [pos_full_step, vel_half_step, stepsize, i],
    )

    # Closing half-step brings the velocity to the same time as final_pos.
    final_vel = new_vel - 0.5 * stepsize * energy_grad(final_pos)
    return final_pos, final_vel
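# Comment list  (web-page navigation label left over from scraping; not code)
# Table of contents  (web-page navigation label left over from scraping; not code)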