def testSqrtMatmulBatchMatrixWithTranspose(self):
  """Checks batched `sqrt_matmul(x, transpose_x=True)` against a matmul reference.

  For each inner size k in [1, 4], builds a batch (batch_shape = (2, 3)) of
  (5, k) matrices `x` and a batch of (k, k) Cholesky factors `chol`. It then
  asserts that `OperatorPDCholesky(chol).sqrt_matmul(x, transpose_x=True)`
  has the same static shape and approximately the same values as
  `math_ops.matmul(chol, x, adjoint_b=True)`.
  """
  with self.test_session():
    batch_shape = (2, 3)
    for k in [1, 4]:
      x_shape = batch_shape + (5, k)
      x = self._rng.rand(*x_shape)
      chol_shape = batch_shape + (k, k)
      chol = self._random_cholesky_array(chol_shape)
      operator = operator_pd_cholesky.OperatorPDCholesky(chol)
      sqrt_operator_times_x = operator.sqrt_matmul(x, transpose_x=True)
      # The reference treats `chol` as the operator's square root and applies
      # it on the LEFT of x transposed over its last two dims, giving
      # (k, k) @ (k, 5) -> batch_shape + (k, 5). adjoint_b performs that
      # transpose; x is presumably real-valued (drawn via self._rng.rand), in
      # which case the adjoint is just the transpose.
      expected = math_ops.matmul(chol, x, adjoint_b=True)
      self.assertEqual(expected.get_shape(),
                       sqrt_operator_times_x.get_shape())
      self.assertAllClose(expected.eval(), sqrt_operator_times_x.eval())
operator_pd_cholesky_test.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录