def _body(self, x, cumul_out, prev_state, cumul_state,
          cumul_halting, iteration, remainder, halting_linear, x_ones):
    """Run a single step of the loop; serves as the `body` of `tf.while_loop`.

    One call runs `self._core` once, derives a sigmoid halting value from the
    resulting state, and folds this step's output and state into the running
    sums, each weighted by the increase in cumulative halting made this step.
    Where the raw cumulative halting exceeds `self._threshold`, the cumulative
    halting is set to exactly one and the previous remainder is kept.

    Args:
      x: Input handed to `self._core` for this step.
      cumul_out: Running weighted sum of core outputs.
      prev_state: State passed to `self._core` for this step.
      cumul_state: Running weighted sum of core states, combined through
        `_nested_add`.
      cumul_halting: Cumulative halting value so far. It is compared against a
        `(batch_size, 1)` tensor of ones, so it presumably has that shape --
        TODO confirm against the caller.
      iteration: Step counter, advanced only where `cumul_halting` is not yet
        equal to one.
      remainder: Remainder carried over from the previous step.
      halting_linear: Callable applied to the halting part of the new state;
        its result is passed through a sigmoid.
      x_ones: Value returned in the `x` position of the result.

    Returns:
      A 7-tuple `(x_ones, next_cumul_out, next_state, next_cumul_state,
      next_cumul_halting, next_iteration, next_remainder)`. `halting_linear`
      and `x_ones` have no slot of their own in the result, so they are
      presumably bound by the caller rather than threaded through the loop
      variables -- verify against the call site.
    """
    ones = tf.constant(1, shape=(self._batch_size, 1), dtype=self._dtype)

    # Elements whose cumulative halting already equals one keep their count;
    # every other element is charged one more step.
    already_done = tf.equal(cumul_halting, ones)
    next_iteration = tf.where(already_done, iteration, iteration + 1)

    out, next_state = self._core(x, prev_state)

    # Only part of the state feeds the halting unit.
    state_for_halting = self._get_state_for_halting(next_state)
    halting_logits = halting_linear(state_for_halting)
    halting = tf.sigmoid(halting_logits, name="halting")

    raw_total = cumul_halting + halting
    crossed = raw_total > self._threshold

    # Past the threshold the total is set to exactly one and the previous
    # remainder is kept; otherwise the remainder is what is left up to one.
    next_cumul_halting = tf.where(crossed, ones, raw_total)
    next_remainder = tf.where(crossed, remainder, 1 - raw_total)

    # Weight of this step: how much the cumulative halting actually grew.
    p = next_cumul_halting - cumul_halting

    weighted_state = _nested_unary_mul(next_state, p)
    next_cumul_state = _nested_add(cumul_state, weighted_state)

    weighted_out = p * out
    next_cumul_out = cumul_out + weighted_out

    return (
        x_ones,
        next_cumul_out,
        next_state,
        next_cumul_state,
        next_cumul_halting,
        next_iteration,
        next_remainder,
    )
评论列表
文章目录