def update_link_matrix(self, link_matrix_old, precedence_weighting_old, write_weighting):
"""
    Updates the temporal link matrix. Vectorizing this takes some effort: instead of
    the paper's original index-by-index operation, the whole matrix is updated at once.
:param link_matrix_old: from previous time step, shape [batch_size, memory_size, memory_size]
:param precedence_weighting_old: from previous time step, shape [batch_size, memory_size]
:param write_weighting: from current time step, shape [batch_size, memory_size]
:return: updated link matrix
"""
expanded = tf.expand_dims(write_weighting, axis=2)
# vectorizing the paper's original implementation
w = tf.tile(expanded, [1, 1, self.memory_size]) # shape [batch_size, memory_size, memory_size]
    # shape of w_transp is the same: [batch_size, memory_size, memory_size]
    w_transp = tf.tile(tf.transpose(expanded, [0, 2, 1]), [1, self.memory_size, 1])
    # in the einsum, m and n both index the memory dimension; two distinct subscript labels
    # are needed to express the outer product of precedence_weighting_old and write_weighting
lm = (1 - w - w_transp) * link_matrix_old + tf.einsum("bn,bm->bmn", precedence_weighting_old, write_weighting)
lm *= (1 - tf.eye(self.memory_size, batch_shape=[self.batch_size])) # making sure self links are off
return tf.identity(lm, name="Link_matrix")
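For reference, the element-wise rule this vectorizes is the link matrix update from the DNC paper: L_t[i, j] = (1 - w_t[i] - w_t[j]) * L_{t-1}[i, j] + w_t[i] * p_{t-1}[j], with the diagonal forced to zero. The sketch below is not part of the original implementation; it uses NumPy with illustrative names and simply checks the vectorized form against the paper-style loop, mirroring the shapes the method expects.

import numpy as np

batch_size, memory_size = 2, 4
rng = np.random.default_rng(0)
link_old = rng.random((batch_size, memory_size, memory_size))
precedence_old = rng.random((batch_size, memory_size))
write_w = rng.random((batch_size, memory_size))

# Paper-style loop: L[b, i, j] = (1 - w[b, i] - w[b, j]) * L_old[b, i, j] + w[b, i] * p_old[b, j]
link_loop = np.zeros_like(link_old)
for b in range(batch_size):
    for i in range(memory_size):
        for j in range(memory_size):
            if i != j:  # self-links stay zero
                link_loop[b, i, j] = ((1 - write_w[b, i] - write_w[b, j]) * link_old[b, i, j]
                                      + write_w[b, i] * precedence_old[b, j])

# Vectorized form, mirroring the TensorFlow code above
w_i = write_w[:, :, None]   # write weighting broadcast along columns (the role of `w`)
w_j = write_w[:, None, :]   # write weighting broadcast along rows (the role of `w_transp`)
link_vec = (1 - w_i - w_j) * link_old + np.einsum("bn,bm->bmn", precedence_old, write_w)
link_vec *= 1 - np.eye(memory_size)   # zero the diagonal, as in the tf.eye line

assert np.allclose(link_loop, link_vec)

Note that the NumPy sketch relies on broadcasting instead of the explicit tf.tile calls; the result is the same.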