def _leapfrog_step(xs, ps, epsilon, max_iterations, logprob_grads_fn):
    """Run up to `max_iterations` leapfrog integration steps in TF graph mode.

    Builds a `tf.while_loop` that, per iteration, (1) advances positions `xs`
    in place by `epsilon * ps` and (2) advances momenta `ps` using fresh
    gradients of the target log-probability. The loop stops early as soon as
    any gradient becomes non-finite (NaN/Inf), signalled through `proceed`.

    Args:
        xs: collection of position variables. Presumably `tf.Variable`s —
            they are updated via `assign_add` (TODO confirm against callers).
        ps: collection of momentum tensors, same structure as `xs`.
        epsilon: leapfrog step size (scalar).
        max_iterations: maximum number of leapfrog steps.
        logprob_grads_fn: zero-argument callable returning
            `(logprob, grads)`, where `grads` aligns with `xs`/`ps`.
            NOTE(review): called with no arguments, so it must close over the
            same `xs` variables being mutated here — verify.

    Returns:
        Tuple `(proceed, xs_out, ps_out, logprob, grads)` where `proceed` is a
        boolean tensor that is False iff the loop aborted on non-finite
        gradients, and `logprob`/`grads` are re-evaluated after the loop.

    NOTE(review): `_map`, `_update_ps`, `_while_loop` and `_flat` are
    module-private helpers defined outside this chunk; documented behavior is
    inferred from usage only. `tf.is_finite` is the TF1 spelling
    (`tf.math.is_finite` in TF2) — this file appears to target TF1.
    """
    def update_xs(ps_values):
        # In-place position update: x += epsilon * p for each pair.
        return _map(lambda x, p: x.assign_add(epsilon * p), xs, ps_values)

    def whether_proceed(grads):
        # True only if every element of every gradient tensor is finite.
        finits = _map(lambda grad: tf.reduce_all(tf.is_finite(grad)), grads)
        return tf.reduce_all(finits)

    def cond(i, proceed, _ps, _xs):
        # Continue while gradients stayed finite and iteration budget remains.
        return tf.logical_and(proceed, i < max_iterations)

    def body(i, _proceed, ps, _xs):
        xs_new = update_xs(ps)
        # Force the assign_add ops to run before gradients are recomputed,
        # so grads reflect the just-updated positions.
        with tf.control_dependencies(xs_new):
            _, grads = logprob_grads_fn()
            proceed = whether_proceed(grads)
            def ps_step():
                # Momentum update using the fresh (finite) gradients.
                with tf.control_dependencies(grads):
                    return _update_ps(ps, grads, epsilon)
            def ps_no_step():
                # Non-finite grads: keep momenta unchanged; the dependency on
                # `grads` keeps both cond branches structurally symmetric.
                with tf.control_dependencies(grads):
                    return ps
        # strict=True: both branches must return the same structure.
        ps_new = tf.cond(proceed, ps_step, ps_no_step, strict=True)
        return i + 1, proceed, ps_new, xs_new

    result = _while_loop(cond, body, [0, True, ps, xs])

    _i, proceed_out, ps_out, xs_out = result
    # Re-evaluate logprob/grads only after all loop side effects have run.
    deps = _flat([proceed_out], ps_out, xs_out)
    with tf.control_dependencies(deps):
        logprob_out, grads_out = logprob_grads_fn()
        return proceed_out, xs_out, ps_out, logprob_out, grads_out
# NOTE(review): the two lines below were stray page text from the web source
# this snippet was copied from ("评论列表" = comment list, "文章目录" =
# article table of contents). They are not Python and broke parsing, so they
# are preserved here as a comment only.