def _allocation(self, usage):
  r"""Computes allocation by sorting `usage`.

  This corresponds to the value a = a_t[\phi_t[j]] in the paper.

  Args:
    usage: tensor of shape `[batch_size, memory_size]` indicating current
        memory usage. This is equal to u_t in the paper when we only have one
        write head, but for multiple write heads, one should update the usage
        while iterating through the write heads to take into account the
        allocation returned by this function.

  Returns:
    Tensor of shape `[batch_size, memory_size]` corresponding to allocation.
  """
  with tf.name_scope('allocation'):
    # Ensure values are not too small prior to cumprod.
    usage = _EPSILON + (1 - _EPSILON) * usage

    nonusage = 1 - usage
    sorted_nonusage, indices = tf.nn.top_k(
        nonusage, k=self._memory_size, name='sort')
    sorted_usage = 1 - sorted_nonusage
    prod_sorted_usage = tf.cumprod(sorted_usage, axis=1, exclusive=True)
    sorted_allocation = sorted_nonusage * prod_sorted_usage
    inverse_indices = util.batch_invert_permutation(indices)

    # This final line "unsorts" sorted_allocation, so that the indexing
    # corresponds to the original indexing of `usage`.
    return util.batch_gather(sorted_allocation, inverse_indices)
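
For intuition, here is a minimal NumPy sketch of the same allocation weighting for a single usage vector (the function name `allocation_np` and the example values are illustrative, not part of the repo). It sorts the memory slots from most free to least free, gives each slot an allocation of (1 - usage) times the product of the usages of all freer slots, and then unsorts the result back to the original slot order:

import numpy as np

_EPSILON = 1e-6

def allocation_np(usage):
  """Illustrative NumPy version of the allocation weighting, one batch row."""
  usage = _EPSILON + (1 - _EPSILON) * usage
  # phi orders slots from most free (lowest usage) to least free,
  # matching the descending top_k over nonusage in the TensorFlow code.
  phi = np.argsort(usage)
  sorted_usage = usage[phi]
  # Exclusive cumulative product: product of the usages of all freer slots.
  prod_sorted_usage = np.cumprod(np.concatenate(([1.0], sorted_usage[:-1])))
  sorted_allocation = (1 - sorted_usage) * prod_sorted_usage
  # Unsort so indexing matches the original `usage` vector.
  allocation = np.empty_like(sorted_allocation)
  allocation[phi] = sorted_allocation
  return allocation

print(allocation_np(np.array([0.9, 0.1, 0.5])))

Running this on usage = [0.9, 0.1, 0.5] yields roughly [0.005, 0.9, 0.05]: the freest slot (index 1) receives almost all of the allocation weight, which is exactly the behavior the exclusive cumprod produces in the TensorFlow code above.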