def _update_usage(self, hidden_vb, prev_usage_vb):
    """
    Calculate the new usage after reading and freeing from memory.

    A per-head free gate is computed from ``hidden_vb``. The previous read
    weights are scaled by that gate. ``prev_usage_vb`` is then multiplied by
    ``psi``, the product over the heads dimension of ``1 - gate * wl_prev``.
    Only this multiplicative factor is applied here; no write-weight term is
    added to the usage in this method.

    variables needed:
        hidden_vb:     [batch_size x hidden_dim]
        prev_usage_vb: [batch_size x mem_hei]
        wl_prev_vb:    [batch_size x num_heads x mem_hei]  (read from self)
    side effects:
        sets self.free_gate_vb: [batch_size x num_heads x 1]
    returns:
        usage_vb:      [batch_size x mem_hei]
    """
    # One sigmoid gate per head. The view assumes hid_2_free_gate yields
    # num_heads values per batch row (matches the documented free_gate shape).
    # The trailing singleton dim lets each gate broadcast across every memory
    # slot of that head's weighting.
    self.free_gate_vb = torch.sigmoid(
        self.hid_2_free_gate(hidden_vb)
    ).view(-1, self.num_heads, 1)
    # [B, H, 1] * [B, H, M] -> [B, H, M]: plain broadcasting, no expand_as needed.
    free_read_weights_vb = self.free_gate_vb * self.wl_prev_vb
    # psi: product over the heads dim (dim 1) of the un-freed fraction -> [B, M].
    psi_vb = torch.prod(1. - free_read_weights_vb, 1)
    return prev_usage_vb * psi_vb
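# 评论列表 ("comment list" — web-page navigation residue, not code)
# 文章目录 ("table of contents" — web-page navigation residue, not code)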