def _allocation(self, usage_vb, epsilon=1e-6):
"""
computes allocation by sorting usage, a = a_t[\phi_t[j]]
variables needed:
usage_vb: [batch_size x mem_hei]
-> indicating current memory usage, this is equal to u_t in
the paper when we only have one write head, but for
multiple write heads, one should update the usage while
iterating through the write heads to take into account the
allocation returned by this function
returns:
alloc_vb: [batch_size x num_write_heads x mem_hei]
"""
    # ensure values are not too small prior to cumprod
    usage_vb = epsilon + (1 - epsilon) * usage_vb
    # NOTE: we sort usage in ascending order
    sorted_usage_vb, indices_vb = torch.topk(usage_vb, k=self.mem_hei, dim=1, largest=False)
    # to imitate tf.cumprod(exclusive=True) https://discuss.pytorch.org/t/cumprod-exclusive-true-equivalences/2614/8
    cat_sorted_usage_vb = torch.cat((Variable(torch.ones(self.batch_size, 1)).type(self.dtype), sorted_usage_vb), 1)[:, :-1]
    # TODO: seems we have to wait for this PR: https://github.com/pytorch/pytorch/pull/1439
    prod_sorted_usage_vb = fake_cumprod(cat_sorted_usage_vb)
    # prod_sorted_usage_vb = torch.cumprod(cat_sorted_usage_vb, dim=1) # TODO: use this once the PR is ready
    # alloc_weight_vb = (1 - sorted_usage_vb) * prod_sorted_usage_vb # equ. (1) # 0.1.12
    alloc_weight_vb = (1 - sorted_usage_vb) * prod_sorted_usage_vb.squeeze() # equ. (1) # 0.2.0
    # sorting the sort indices recovers the inverse permutation, i.e. each
    # memory slot's rank in the ascending-usage ordering phi_t
    _, indices_vb = torch.topk(indices_vb, k=self.mem_hei, dim=1, largest=False)
    # map the allocation weights from sorted order back to original slot order
    alloc_weight_vb = alloc_weight_vb.gather(1, indices_vb)
    return alloc_weight_vb
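

# A minimal standalone sketch of the same allocation weighting (equ. (1)),
# assuming a recent PyTorch where torch.cumprod and torch.argsort are
# available, so neither fake_cumprod nor Variable is needed; the function and
# argument names below are illustrative, not part of the original module.
import torch

def allocation_weights(usage, epsilon=1e-6):
    """usage: [batch_size x mem_hei] tensor of usage values in [0, 1]."""
    usage = epsilon + (1 - epsilon) * usage
    # sort usage in ascending order; phi is the free-list ordering phi_t
    sorted_usage, phi = torch.sort(usage, dim=1, descending=False)
    # exclusive cumulative product: prod_{i<j} u_t[phi_t[i]]
    ones = torch.ones(usage.size(0), 1, dtype=usage.dtype, device=usage.device)
    prod_sorted = torch.cumprod(torch.cat([ones, sorted_usage], dim=1)[:, :-1], dim=1)
    # equ. (1): a_t[phi_t[j]] = (1 - u_t[phi_t[j]]) * prod_{i<j} u_t[phi_t[i]]
    sorted_alloc = (1 - sorted_usage) * prod_sorted
    # undo the sort: argsort of a permutation yields its inverse
    return sorted_alloc.gather(1, torch.argsort(phi, dim=1))

# e.g. allocation_weights(torch.rand(4, 16)) -> a [4 x 16] allocation weighting

# For the multi-write-head case mentioned in the docstring, one way to update
# usage between heads (an assumption modeled on DeepMind's reference DNC code,
# not taken from this module) is to reserve the slots each head is about to
# claim before computing the next head's allocation:
def multi_head_allocation(usage, write_gates, epsilon=1e-6):
    """write_gates: [batch_size x num_write_heads x 1] with values in [0, 1]."""
    allocations = []
    for i in range(write_gates.size(1)):
        alloc = allocation_weights(usage, epsilon)
        allocations.append(alloc)
        # usage rises where this head is about to write, before the next head
        usage = usage + (1 - usage) * write_gates[:, i, :] * alloc
    return torch.stack(allocations, dim=1)  # [batch_size x num_write_heads x mem_hei]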