pondering_rnn.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def _body(self, x, cumul_out, prev_state, cumul_state,
          cumul_halting, iteration, remainder, halting_linear, x_ones):
    """Run one pondering step; used as the `body` of `tf.while_loop`.

    Applies the wrapped core once, derives a halting probability from the
    new state, and folds this step's output and state into running sums
    weighted by how much the (clamped) cumulative halting grew on this step.

    Args:
      x: Input passed to `self._core` on this step.
      cumul_out: Running halting-weighted sum of core outputs.
      prev_state: Core state from the previous step.
      cumul_state: Running halting-weighted sum of core states (combined
        via the nested helpers, so it may be a nested structure).
      cumul_halting: Running sum of halting probabilities. An entry equal to
        exactly 1 marks an example that is no longer running.
      iteration: Step counter, advanced only for examples still running.
      remainder: Remainder value carried over from earlier steps.
      halting_linear: Callable whose output is squashed by a sigmoid to give
        the halting probability.
      x_ones: Tensor handed back in place of `x` as the next step's input.

    Returns:
      Tuple `(x_ones, next_cumul_out, next_state, next_cumul_state,
      next_cumul_halting, next_iteration, next_remainder)`, in the same order
      as the first seven parameters, as `tf.while_loop` requires.
    """
    # NOTE(review): presumably cumul_halting / iteration / remainder are
    # shaped (batch_size, 1) to line up with this column of ones - confirm.
    ones_col = tf.constant(1, shape=(self._batch_size, 1), dtype=self._dtype)

    # An example whose cumulative halting is already clamped to exactly 1 has
    # stopped, so only the still-running examples advance their counter.
    already_halted = tf.equal(cumul_halting, ones_col)
    next_iteration = tf.where(already_halted, iteration, iteration + 1)

    out, next_state = self._core(x, prev_state)

    # The halting probability is computed from the (halting-relevant part of
    # the) freshly produced state.
    state_for_halting = self._get_state_for_halting(next_state)
    halting = tf.sigmoid(halting_linear(state_for_halting), name="halting")

    raw_total = cumul_halting + halting
    crossed = raw_total > self._threshold

    # Crossing the threshold clamps cumulative halting to 1 and freezes the
    # remainder at its previous value. Otherwise the remainder tracks
    # 1 - (unclamped cumulative halting).
    next_cumul_halting = tf.where(crossed, ones_col, raw_total)
    next_remainder = tf.where(crossed, remainder, 1 - raw_total)

    # This step's weight is the increase in clamped cumulative halting.
    # Assuming self._threshold <= 1, already-halted examples stay clamped at
    # 1 and so contribute a weight of 0.
    step_weight = next_cumul_halting - cumul_halting
    weighted_state = _nested_unary_mul(next_state, step_weight)
    next_cumul_state = _nested_add(cumul_state, weighted_state)
    next_cumul_out = cumul_out + step_weight * out

    # `x_ones` replaces `x` as the input for subsequent steps. How it differs
    # from the initial loop input is decided by the caller (not visible here).
    return (x_ones, next_cumul_out, next_state, next_cumul_state,
            next_cumul_halting, next_iteration, next_remainder)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号