def gather_prev_stack_aux_state_index(pointer_vals, prev_index, transition_state,
                                      batch_size):
    """Gather the new previous-state index for the auxiliary RNN.

    Same selection scheme as for the main stack, except that the shift
    transition yields 0 instead of ``prev_index``.

    A [batch_size, 8] table of candidate indices is built (one column per
    transition state).  For each example ``i``, the entry in column
    ``transition_state[i]`` is picked.

    Args:
      pointer_vals: per-example pointer values. They are reshaped to a
        [-1, 1] column, so presumably one value per batch element
        (TODO confirm with callers).
      prev_index: scalar index used to fill every "keep previous" column.
        It must share the int32 dtype of the zero column for the concat
        to succeed.
      transition_state: per-example transition ids selecting a column of
        the candidate table. Presumably int32 of length batch_size, to
        match ``tf.range(batch_size)``.
      batch_size: number of examples in the batch.

    Returns:
      A 1-D tensor whose i-th entry is
      ``candidates[i, transition_state[i]]``.
    """
    batch_shape = tf.pack([batch_size])

    # Candidate columns, each of shape [batch_size, 1].
    keep_col = tf.reshape(tf.fill(batch_shape, prev_index), [-1, 1])
    zero_col = tf.reshape(tf.zeros(batch_shape, dtype=tf.int32), [-1, 1])
    pointer_col = tf.reshape(pointer_vals, [-1, 1])

    # Candidate table, one column per transition state (NUM_TR_STATES == 8):
    #   0: prev, 1: zero (the shift case), 2-5: prev, 6: new pointer, 7: prev.
    columns = [keep_col, zero_col, keep_col, keep_col, keep_col, keep_col,
               pointer_col, keep_col]
    candidates = tf.concat(1, columns)

    # Row i of the gather indices is the pair (i, transition_state[i]).
    row_ids = tf.range(batch_size)
    gather_inds = tf.transpose(tf.pack([row_ids, transition_state]))
    return tf.gather_nd(candidates, gather_inds)
评论列表
文章目录