def _content_focus(self, memory_vb):
    """Compute content-based attention weights over memory rows.

    Each head's key is compared against every memory row by cosine
    similarity; the similarities are sharpened by the key strength
    ``beta`` and normalized with a softmax over the memory rows.

    Instance attributes read:
        key_vb:  [batch_size x num_heads x mem_wid]
            -> similarity key vector, to compare to each row in memory
            -> by cosine similarity
        beta_vb: [batch_size x num_heads x 1]
            -> NOTE: refer here: https://github.com/deepmind/dnc/issues/9
            -> \in (1, +inf) after oneplus(); similarity key strength
            -> amplify or attenuate the precision of the focus

    Args:
        memory_vb: [batch_size x mem_hei x mem_wid]

    Writes:
        self.wc_vb: [batch_size x num_heads x mem_hei]
            -> the attention weight by content focus
    """
    K_vb = batch_cosine_sim(self.key_vb, memory_vb)   # [batch_size x num_heads x mem_hei]
    self.wc_vb = K_vb * self.beta_vb.expand_as(K_vb)  # [batch_size x num_heads x mem_hei]
    # Softmax over the memory rows (last dim). Equivalent to the legacy
    # `F.softmax(x.transpose(0, 2)).transpose(0, 2)` idiom (old implicit-dim
    # softmax used dim 0 for 3-D input), but explicit `dim=` avoids the
    # deprecated dim-less call and the two transposes.
    self.wc_vb = F.softmax(self.wc_vb, dim=-1)