def gather_prev_stack_state_index(pointer_vals, prev_index, transition_state,
                                  batch_size):
    """Gathers the new previous-stack-state index for every batch element.

    An 8-column lookup table is built with one row per batch element.
    Column 6 holds the (flattened) ``pointer_vals``; every other column
    (0-5 and 7) holds ``prev_index``. Row ``i`` is then indexed at column
    ``transition_state[i]``. The result for element ``i`` is therefore
    ``pointer_vals[i]`` when ``transition_state[i] == 6``, and
    ``prev_index`` for the other columns. Column 6 is the "reduce"
    transition according to the original comments; confirm against the
    transition definitions.

    Args:
      pointer_vals: Tensor flattened into a single column. It must hold
        ``batch_size`` elements so that the column concat lines up, and it
        must share ``prev_index``'s dtype.
      prev_index: Scalar value broadcast to ``batch_size`` rows via
        ``tf.fill``.
      transition_state: Per-element column selector. It is stacked with
        ``tf.range(batch_size)``, so it must be 1-D of length ``batch_size``
        with the same dtype as that range.
      batch_size: Number of rows in the lookup table.

    Returns:
      Tensor of shape ``[batch_size]`` with the selected index per element.
    """
    # Pointer values as a single column, shape [batch_size, 1].
    pointer_column = tf.reshape(pointer_vals, [-1, 1])

    # `batch_size` copies of prev_index as a single column, shape [batch_size, 1].
    fill_shape = tf.pack([batch_size])
    prev_column = tf.reshape(tf.fill(fill_shape, prev_index), [-1, 1])

    # (row, column) coordinates, shape [batch_size, 2]: row i pairs with
    # transition_state[i].
    row_ids = tf.range(batch_size)
    coords = tf.transpose(tf.pack([row_ids, transition_state]))

    # Lookup table with one column per transition state (8 in total, i.e.
    # NUM_TR_STATES). Only column 6 carries pointer values; the rest keep
    # prev_index.
    table_columns = [prev_column] * 6
    table_columns.append(pointer_column)
    table_columns.append(prev_column)
    lookup_table = tf.concat(1, table_columns)  # Legacy (axis, values) order.

    return tf.gather_nd(lookup_table, coords)
评论列表
文章目录