def get_content_adressing(self, memory_matrix, keys, strengths):
    """
    retrieves a content-based addressing weighting given the keys

    Every key is compared with every memory location by cosine
    similarity; the similarities are scaled by that key's strength and
    turned into a weighting over memory locations with a softmax.

    Parameters:
    ----------
    memory_matrix: Tensor (batch_size, memory_locations, word_size)
        the memory matrix to look up in
    keys: Tensor (batch_size, word_size, number_of_keys)
        the keys to query the memory with
    strengths: Tensor (batch_size, number_of_keys)
        the list of strengths for each lookup key

    Returns: Tensor (batch_size, memory_locations, number_of_keys)
        The list of lookup weightings for each provided key
    """
    # Unit-normalize each memory word (axis 2 = word_size) and each key
    # (axis 1 = word_size) so the product below is a cosine similarity.
    normalized_memory = tf.nn.l2_normalize(memory_matrix, 2)
    normalized_keys = tf.nn.l2_normalize(keys, 1)

    # tf.batch_matmul was removed in TensorFlow 1.0, where tf.matmul gained
    # batched (rank >= 3) support. Use the old name only where it still
    # exists so both old and current TensorFlow versions work.
    batched_matmul = getattr(tf, "batch_matmul", tf.matmul)
    # (batch, locations, word) x (batch, word, keys) -> (batch, locations, keys)
    similarity = batched_matmul(normalized_memory, normalized_keys)

    # (batch, keys) -> (batch, 1, keys) so each key's strength broadcasts
    # across all memory locations.
    strengths = tf.expand_dims(strengths, 1)

    # Softmax over axis 1 (memory_locations): for every key the weighting
    # sums to 1 across locations.
    return tf.nn.softmax(similarity * strengths, 1)
评论列表
文章目录