def hmc_updates(initial_pos, stepsize, avg_acceptance_rate, final_pos, accept,
                target_acceptance_rate, stepsize_inc, stepsize_dec,
                stepsize_min, stepsize_max, avg_acceptance_slowness):
    """Compute the next sampler state after one accept/reject decision.

    Builds three TensorFlow ops: the next position, an adapted step size,
    and an updated running acceptance rate.

    Args:
        initial_pos: Position(s) before the proposal; kept where the
            proposal is rejected.
        stepsize: Current step size.
        avg_acceptance_rate: Running average of the acceptance rate. The step
            size is adapted from this *incoming* value, not from the updated
            average returned below.
        final_pos: Proposed position(s); kept where the proposal is accepted.
        accept: Boolean tensor of accept decisions, used both as the
            ``tf.where`` condition and (averaged) as this step's acceptance
            fraction. NOTE(review): if this is a per-chain vector while the
            positions carry an extra trailing dimension, the row-selection
            behaviour relies on TF1 ``tf.where`` semantics; TF2 ``tf.where``
            broadcasts instead -- confirm shapes against the caller.
        target_acceptance_rate: Acceptance rate the step-size adaptation
            steers towards.
        stepsize_inc: Factor the step size is multiplied by when
            ``avg_acceptance_rate > target_acceptance_rate``.
        stepsize_dec: Factor the step size is multiplied by otherwise.
        stepsize_min: Lower clipping bound for the new step size.
        stepsize_max: Upper clipping bound for the new step size.
        avg_acceptance_slowness: Weight given to the old average in the
            exponential moving average of the acceptance rate.

    Returns:
        Tuple ``(new_pos, new_stepsize, new_acceptance_rate)``.
    """
    # Take the proposal where it was accepted, otherwise keep the old position.
    new_pos = tf.where(accept, final_pos, initial_pos)

    # Grow the step size while acceptance runs above target, shrink it
    # otherwise, then clip into [stepsize_min, stepsize_max].
    stepsize_factor = tf.where(
        tf.greater(avg_acceptance_rate, target_acceptance_rate),
        stepsize_inc,
        stepsize_dec,
    )
    unclipped_stepsize = tf.multiply(stepsize, stepsize_factor)
    new_stepsize = tf.maximum(tf.minimum(unclipped_stepsize, stepsize_max),
                              stepsize_min)

    # Exponential moving average of the fraction of accepted proposals.
    # tf.cast is used instead of tf.to_float, which is not available in TF 2.x.
    batch_acceptance = tf.reduce_mean(tf.cast(accept, tf.float32))
    new_acceptance_rate = tf.add(
        avg_acceptance_slowness * avg_acceptance_rate,
        (1.0 - avg_acceptance_slowness) * batch_acceptance,
    )
    return new_pos, new_stepsize, new_acceptance_rate
评论列表
文章目录