def _access(self, memory_vb):  # write
    """
    Write to memory: erase first, then add.

    read from self:
        wl_curr_vb: [batch_size x num_heads x mem_hei]
        erase_vb:   [batch_size x num_heads x mem_wid]
                    -> expected to lie in (0, 1)
        add_vb:     [batch_size x num_heads x mem_wid]
                    -> no restriction on range
        mem_hei, mem_wid, num_heads: sizes used to reshape the tensors above
    argument:
        memory_vb:  [batch_size x mem_hei x mem_wid]
    returns:
        a new tensor of shape [batch_size x mem_hei x mem_wid]; the input
        tensor is not modified in place
    NOTE: IMPORTANT: https://github.com/deepmind/dnc/issues/10
    """
    # Erase step: for every (batch, head) pair build the outer product
    # weights (mem_hei x 1) @ erase (1 x mem_wid) with a single bmm call.
    weights_col = self.wl_curr_vb.contiguous().view(-1, self.mem_hei, 1)
    erase_row = self.erase_vb.contiguous().view(-1, 1, self.mem_wid)
    erase_per_head = torch.bmm(weights_col, erase_row)
    erase_per_head = erase_per_head.view(
        -1, self.num_heads, self.mem_hei, self.mem_wid
    )

    # Heads are combined multiplicatively: the kept fraction of each cell is
    # the product over heads of (1 - weighted erase).
    retained = (1. - erase_per_head).prod(dim=1)
    erased_memory = memory_vb * retained

    # Add step: (mem_hei x num_heads) @ (num_heads x mem_wid) sums the
    # weighted add vectors of all heads for each batch element.
    addition = torch.bmm(self.wl_curr_vb.transpose(1, 2), self.add_vb)
    return erased_memory + addition
评论列表
文章目录