def testCompareProjectSumAndProject(self):
    """Check that ``project_sum`` and ``project`` produce the same result.

    A batch of random tensors and a single random tensor are created with
    the same first argument ``(2, 3, 4)``.  The batch is projected with
    respect to the single tensor twice: once via ``riemannian.project_sum``
    with ``tf.eye(4)`` as its third argument, and once via
    ``riemannian.project``.  Both results are converted with ``ops.full``,
    evaluated in one session run, and compared with ``assertAllClose``.
    """
    # NOTE(review): the second positional arguments (3 and 4) are presumably
    # TT-ranks, and (2, 3, 4) the tensor shape -- confirm against the
    # ``initializers`` module.
    tens = initializers.random_tensor_batch((2, 3, 4), 3, batch_size=4)
    tangent_tens = initializers.random_tensor((2, 3, 4), 4)

    # The identity matrix has the same size as ``batch_size`` above;
    # presumably this makes each summed output equal to the projection of
    # exactly one batch element, which is why the results should coincide.
    project_sum = riemannian.project_sum(tens, tangent_tens, tf.eye(4))
    project = riemannian.project(tens, tangent_tens)

    with self.test_session() as sess:
        # Evaluate both tensors in a single run call.
        project_sum_val, project_val = sess.run(
            (ops.full(project_sum), ops.full(project)))
        self.assertAllClose(project_sum_val, project_val)
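# NOTE(review): the two lines below were stray non-code text ("评论列表" /
# "文章目录"), likely page navigation picked up when this snippet was
# copied -- confirm they can be deleted.
# Comment list
# Table of contents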