def test_log_abs_det(self):
self._maybe_skip("log_abs_det")
for use_placeholder in False, True:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
if dtype.is_complex:
self.skipTest(
"tf.matrix_determinant does not work with complex, so this "
"test is being skipped.")
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
shape, dtype, use_placeholder=use_placeholder)
op_log_abs_det = operator.log_abs_determinant()
mat_log_abs_det = math_ops.log(
math_ops.abs(linalg_ops.matrix_determinant(mat)))
if not use_placeholder:
self.assertAllEqual(shape[:-2], op_log_abs_det.get_shape())
op_log_abs_det_v, mat_log_abs_det_v = sess.run(
[op_log_abs_det, mat_log_abs_det],
feed_dict=feed_dict)
self.assertAC(op_log_abs_det_v, mat_log_abs_det_v)
linear_operator_test_util.py 文件源码
python
阅读 19
收藏 0
点赞 0
评论 0
评论列表
文章目录