def testSqrtMatmulSingleMatrix(self):
    """Checks OperatorPDCholesky.sqrt_matmul on a single, unbatched matrix.

    For each size k in (1, 4), builds a Cholesky factor of shape (k, k) via
    self._random_cholesky_array and a right-hand side of shape (k, 3) via
    self._rng.rand. It then asserts that operator.sqrt_matmul(rhs) matches
    math_ops.matmul(chol_factor, rhs) in both static shape and evaluated value.
    """
    with self.test_session():
        # Empty batch shape: operands are plain matrices with no leading dims.
        batch_shape = ()
        for k in (1, 4):
            # Draw the right-hand side before the factor. Both presumably
            # consume self._rng (TODO confirm for _random_cholesky_array), so
            # this order keeps the random draws identical.
            rhs = self._rng.rand(*(batch_shape + (k, 3)))
            chol_factor = self._random_cholesky_array(batch_shape + (k, k))

            operator = operator_pd_cholesky.OperatorPDCholesky(chol_factor)
            actual = operator.sqrt_matmul(rhs)
            # Reference product computed directly from the factor.
            expected = math_ops.matmul(chol_factor, rhs)

            self.assertEqual(expected.get_shape(), actual.get_shape())
            expected_value = expected.eval()
            actual_value = actual.eval()
            self.assertAllClose(expected_value, actual_value)
operator_pd_cholesky_test.py 文件源码
python
阅读 19
收藏 0
点赞 0
评论 0
评论列表
文章目录