def determinant(kron_a):
    """Computes the determinant of a given Kronecker-factorized matrix.

    For A = A_1 x ... x A_d (Kronecker product) with square factors A_k of
    size n_k x n_k and N = n_1 * ... * n_d, the determinant is
    det(A) = prod_k det(A_k) ** (N / n_k). This function computes the
    determinant of every tt-core and combines them with that formula.

    Note, that this method can suffer from overflow: the exponents N / n_k
    grow with the total matrix size.

    Args:
      kron_a: `TensorTrain` object containing a matrix of size N x N,
        factorized into a Kronecker product of square matrices (all
        tt-ranks are 1 and all tt-cores are square).

    Returns:
      Number, the determinant of the given matrix.

    Raises:
      ValueError if the tt-ranks are not 1, or if the shapes are statically
        known and the tt-cores are not square. When the shapes are only
        known at run time, the squareness check is not performed.
    """
    if not _is_kron(kron_a):
        raise ValueError('The argument should be a Kronecker product (tt-ranks '
                         'should be 1)')

    shapes_defined = kron_a.get_shape().is_fully_defined()
    if shapes_defined:
        static_raw_shape = kron_a.get_raw_shape()
        i_shapes = static_raw_shape[0]
        j_shapes = static_raw_shape[1]
        if i_shapes != j_shapes:
            raise ValueError('The argument should be a Kronecker product of '
                             'square matrices (tt-cores must be square)')
    else:
        # Shapes only known at run time: `i_shapes` is an integer tensor of
        # per-core row sizes. Squareness cannot be checked statically here.
        i_shapes = ops.raw_shape(kron_a)[0]

    # N, the total number of rows, in the matrix dtype so it can be divided.
    pows = tf.cast(tf.reduce_prod(i_shapes), kron_a.dtype)
    cores = kron_a.tt_cores
    det = 1
    for core_idx in range(kron_a.ndims()):
        core = cores[core_idx]
        # tt-ranks are 1, so the core's middle two axes form the factor A_k.
        core_det = tf.matrix_determinant(core[0, :, :, 0])
        if shapes_defined:
            # Static dimension object; `.value` yields the Python int n_k.
            core_size = i_shapes[core_idx].value
        else:
            # Dynamic n_k is a scalar integer tensor, which has no `.value`;
            # cast it to the matrix dtype to match `pows`.
            core_size = tf.cast(i_shapes[core_idx], kron_a.dtype)
        core_pow = pows / core_size
        det *= tf.pow(core_det, core_pow)
    return det
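# Residue from the web page this code was copied from (translated from
# Chinese): "Comment list" / "Table of contents".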