def write_thin_stack_vals(thin_stack, stack_pointers, new_vals, batch_size,
                          max_num_concepts):
    """Overwrites one entry per batch row of the thin stack with `new_vals`.

    For every row ``b`` in ``range(batch_size)``, the element
    ``thin_stack[b, stack_pointers[b]]`` is replaced by ``new_vals[b]``; all
    other elements are unchanged. The update is expressed as
    ``thin_stack + dense(delta)``, where ``delta`` is a sparse tensor holding
    ``new_vals - old_vals`` at the written positions, so a new tensor is
    returned rather than writing into ``thin_stack``.

    Args:
      thin_stack: 2-D tensor expected to have shape
        (batch_size, max_num_concepts); it is added to a dense tensor of
        exactly that shape.
      stack_pointers: 1-D integer tensor (int32 or int64) of length
        batch_size; entry ``b`` is the column written in row ``b``.
      new_vals: 1-D tensor of length batch_size with the same dtype as
        thin_stack; the values to write.
      batch_size: number of rows (Python int or scalar integer tensor).
      max_num_concepts: number of columns (Python int or scalar integer
        tensor).

    Returns:
      A tensor with the shape and dtype of thin_stack, with the pointed-to
      entries replaced by new_vals.
    """
    # SparseTensor indices and dense shape must be int64. Cast each component
    # separately so stack_pointers and the sizes may be int32 or int64.
    row_inds = tf.to_int64(tf.range(batch_size))
    col_inds = tf.to_int64(stack_pointers)
    # (batch_size, 2) matrix of [row, column] coordinates, one per row.
    stack_inds = tf.transpose(tf.pack([row_inds, col_inds]))
    current_vals = tf.gather_nd(thin_stack, stack_inds)
    dense_shape = tf.pack([tf.to_int64(batch_size),
                           tf.to_int64(max_num_concepts)])
    # Row coordinates are 0..batch_size-1, ascending and unique, so the
    # indices are already in the sorted, duplicate-free order that
    # sparse_tensor_to_dense validates, and no position is written twice.
    delta = tf.SparseTensor(stack_inds, new_vals - current_vals, dense_shape)
    # NOTE(review): for floating-point dtypes, old + (new - old) is not
    # guaranteed to equal new exactly, and a non-finite old value at a
    # written position produces NaN. Integer dtypes round-trip exactly.
    return thin_stack + tf.sparse_tensor_to_dense(delta)
评论列表
文章目录