hmc.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:GPflow 作者: GPflow 项目源码 文件源码
def _leapfrog_step(xs, ps, epsilon, max_iterations, logprob_grads_fn):
    """Build the ops for a loop that alternately updates `xs` and `ps`.

    Each iteration adds ``epsilon * p`` to every ``x`` in place, then
    re-evaluates ``logprob_grads_fn`` and, only if every gradient entry is
    finite, replaces ``ps`` with ``_update_ps(ps, grads, epsilon)``. The loop
    condition stops iterating once ``max_iterations`` is reached or once a
    non-finite gradient has been observed. After the loop, the log
    probability and gradients are evaluated one more time.

    NOTE(review): judging by the names this is the leapfrog integrator of an
    HMC sampler, with `xs` the positions and `ps` the momenta -- confirm
    against the caller.

    Args:
        xs: Collection of objects exposing ``assign_add`` (presumably
            ``tf.Variable`` instances -- TODO confirm). They are modified in
            place by the loop.
        ps: Collection handed to ``_map`` together with `xs` (presumably
            paired element-wise -- verify against ``_map``).
        epsilon: Factor multiplying each ``p`` in the `xs` update; also
            forwarded to ``_update_ps``.
        max_iterations: Upper bound on the loop counter.
        logprob_grads_fn: Zero-argument callable returning a
            ``(logprob, grads)`` pair.

    Returns:
        A tuple ``(proceed, xs, ps, logprob, grads)``: the final value of the
        finiteness flag, the loop's final `xs` and `ps` values, and the
        results of the last ``logprob_grads_fn`` call. Note that `xs` comes
        before `ps` here, whereas the loop variables carry `ps` before `xs`.
    """
    def update_xs(ps_values):
        # In-place update: x <- x + epsilon * p for each (x, p) handed to the
        # lambda by `_map`.
        return _map(lambda x, p: x.assign_add(epsilon * p), xs, ps_values)

    def whether_proceed(grads):
        # One boolean per gradient (all entries finite?), then reduced to a
        # single boolean covering every gradient.
        finits = _map(lambda grad: tf.reduce_all(tf.is_finite(grad)), grads)
        return tf.reduce_all(finits)

    def cond(i, proceed, _ps, _xs):
        # Keep looping while the last gradients were finite and the iteration
        # budget is not exhausted.
        return tf.logical_and(proceed, i < max_iterations)

    def body(i, _proceed, ps, _xs):
        # The incoming `_proceed` and `_xs` loop values are unused: the flag is
        # recomputed below and `xs` is updated through the closed-over objects.
        xs_new = update_xs(ps)
        # Force the gradient evaluation to run after the `xs` assignments so
        # that it sees the updated values.
        with tf.control_dependencies(xs_new):
            _, grads = logprob_grads_fn()
            proceed = whether_proceed(grads)
            def ps_step():
                with tf.control_dependencies(grads):
                    return _update_ps(ps, grads, epsilon)
            def ps_no_step():
                # Gradients are not finite: leave `ps` as it is.
                with tf.control_dependencies(grads):
                    return ps

            # Only update `ps` when the freshly computed gradients are finite.
            ps_new = tf.cond(proceed, ps_step, ps_no_step, strict=True)
            return i + 1, proceed, ps_new, xs_new

    # Loop variables: iteration counter, finiteness flag, `ps`, `xs`.
    # NOTE(review): `_while_loop` is presumably a thin wrapper around
    # `tf.while_loop` -- confirm.
    result = _while_loop(cond, body, [0, True, ps, xs])

    _i, proceed_out, ps_out, xs_out = result
    # Make the final evaluation wait for all loop outputs.
    deps = _flat([proceed_out], ps_out, xs_out)
    with tf.control_dependencies(deps):
        logprob_out, grads_out = logprob_grads_fn()
        return proceed_out, xs_out, ps_out, logprob_out, grads_out
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号