def gather_nd_states(inputs, inds, batch_size, input_size, state_size):
    """Gathers one state vector per batch entry, selected by ``inds``.

    For each batch row ``b`` this returns ``inputs[b, inds[b], :]``. The
    selection is done by building a one-hot mask over the ``input_size``
    axis, multiplying it into ``inputs`` and summing that axis out.

    Args:
      inputs: Tensor (or tensor-convertible value) reshapeable to
        [batch_size, input_size, state_size].
      inds: Integer Tensor [batch_size]. Each entry must lie in
        [0, input_size).
      batch_size: Number of batch entries (Python int or scalar Tensor).
      input_size: Size of the axis indexed by ``inds``.
      state_size: Size of each gathered state vector.

    Returns:
      output: Tensor [batch_size, state_size] with the dtype of ``inputs``.
    """
    # tf.pack was renamed to tf.stack (and pack removed) in TensorFlow 1.0.
    # Prefer the new name, but keep working on releases that only have pack.
    stack = getattr(tf, "stack", None) or tf.pack

    inputs = tf.convert_to_tensor(inputs)
    rows = tf.range(batch_size)
    # Both columns of the coordinate matrix must share a dtype to be stacked.
    inds = tf.cast(inds, rows.dtype)

    # [batch_size, 2] coordinates (b, inds[b]). Rows are ascending, so the
    # coordinates are in the lexicographic order sparse_to_dense validates.
    sparse_inds = tf.transpose(stack([rows, inds]))
    # One-hot mask [batch_size, input_size], built in the dtype of inputs so
    # the multiply below also works for non-float32 inputs.
    mask = tf.sparse_to_dense(
        sparse_inds,
        stack([batch_size, input_size]),
        tf.ones(stack([batch_size]), dtype=inputs.dtype))

    # NOTE(review): mask-and-sum is presumably used instead of tf.gather_nd
    # for gradient support on old TF releases -- confirm before replacing.
    mask = tf.reshape(mask, [-1, input_size, 1])
    states = tf.reshape(inputs, [-1, input_size, state_size])
    # Broadcast the mask over the state axis, then sum out input_size.
    return tf.reduce_sum(mask * states, 1)
评论列表
文章目录