def _update_usage(self, prev_usage_vb):
"""
calculates the new usage after writing to memory
variables needed:
prev_usage_vb: [batch_size x mem_hei]
wl_prev_vb: [batch_size x num_write_heads x mem_hei]
returns:
usage_vb: [batch_size x mem_hei]
"""
# calculate the aggregated effect of all write heads
# NOTE: how multiple write heads are delt w/ is not discussed in the paper
# NOTE: this part is only shown in the source code
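    # 1 - prod(1 - w) acts as a soft OR over heads: a slot counts as written
    # as soon as any single write head writes to it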
    write_weights_vb = 1. - torch.prod(1. - self.wl_prev_vb, 1)
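    # interpolate toward full usage: written locations move closer to 1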
    return prev_usage_vb + (1. - prev_usage_vb) * write_weights_vb
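
# The following is a minimal standalone sketch (not part of the original class)
# that runs the same two steps on toy tensors; the shapes follow the docstring
# above and the numbers are made up purely for illustration.
if __name__ == '__main__':
    import torch

    prev_usage_vb = torch.tensor([[0.0, 0.5, 1.0, 0.2]])   # [batch_size x mem_hei]
    wl_prev_vb = torch.tensor([[[0.9, 0.0, 0.0, 0.1],      # write head 1
                                [0.0, 0.5, 0.0, 0.1]]])    # write head 2
    # aggregate the heads, then push usage toward 1 wherever a write occurred
    write_weights_vb = 1. - torch.prod(1. - wl_prev_vb, 1)
    usage_vb = prev_usage_vb + (1. - prev_usage_vb) * write_weights_vb
    print(usage_vb)  # tensor([[0.9000, 0.7500, 1.0000, 0.3520]])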