def test_link_matrix(self):
    """Check Memory.update_link_matrix against an element-wise reference.

    The expected result is built entry by entry with the recurrence
    below. For i != j:

        L_new[k, i, j] = (1 - w[k, i] - w[k, j]) * L_old[k, i, j]
                         + w[k, i] * p[k, j]

    L_new[k, i, i] is left at 0. Here w is the write weighting and p is
    the precedence weighting from the previous time step.
    """
    b, n = 2, 5
    # Seeded generator so that a failing run can be reproduced exactly.
    rng = np.random.RandomState(0)
    write_weighting = rng.rand(b, n)
    # Precedence weighting from the previous time step.
    precedence_weighting = rng.rand(b, n)
    # Random previous link matrix with a zero diagonal. The identity mask
    # must be n x n (previously a hard-coded 5) to match the (b, n, n) shape.
    link_matrix_old = rng.rand(b, n, n) * (1 - np.tile(np.eye(n), [b, 1, 1]))

    link_matrix_correct = np.zeros((b, n, n))
    for k in range(b):
        for i in range(n):
            for j in range(n):
                if i == j:
                    continue  # diagonal entries stay zero
                link_matrix_correct[k, i, j] = (
                    (1 - write_weighting[k, i] - write_weighting[k, j])
                    * link_matrix_old[k, i, j]
                    + write_weighting[k, i] * precedence_weighting[k, j]
                )

    with self.test_session():
        tf.global_variables_initializer().run()
        # update_link_matrix is called on the class itself (Memory passed as
        # self), so the sizes it presumably reads are set as class attributes.
        # NOTE(review): these assignments are never undone and persist into
        # other tests that use Memory.
        Memory.batch_size = b
        Memory.memory_size = n
        new_link_matrix = Memory.update_link_matrix(
            Memory,
            tf.constant(link_matrix_old, dtype=tf.float32),
            tf.constant(precedence_weighting, dtype=tf.float32),
            tf.constant(write_weighting, dtype=tf.float32),
        )
        self.assertAllClose(link_matrix_correct, new_link_matrix.eval())
评论列表
文章目录