def slog_determinant(kron_a):
    """Computes the sign and log-det of a given Kronecker-factorized matrix.

    Uses the identity
        det(A_1 kron ... kron A_d) = prod_k det(A_k) ** (N / n_k),
    where A_k is an n_k x n_k tt-core and N = n_1 * ... * n_d. Hence
        logdet = sum_k (N / n_k) * log|det(A_k)|,
        sign   = prod_k sign(det(A_k)) ** (N / n_k).

    Args:
        kron_a: `TensorTrain` object containing a matrix of size N x N,
            factorized into a Kronecker product of square matrices (all
            tt-ranks are 1 and all tt-cores are square).

    Returns:
        Two numbers, sign of the determinant and the log-determinant of the
        given matrix. If the determinant is zero, then sign will be 0 and
        logdet will be -Inf. In all cases, the determinant is equal to
        sign * np.exp(logdet).

    Raises:
        ValueError if the tt-ranks of the provided matrix are not 1, or if
        its shape is statically known and the tt-cores are not square. When
        the shape is only known at run time, squareness is not checked here.
    """
    if not _is_kron(kron_a):
        raise ValueError('The argument should be a Kronecker product '
                         '(tt-ranks should be 1)')

    shapes_defined = kron_a.get_shape().is_fully_defined()
    # Fetch the raw shape once: static shapes when available, otherwise the
    # run-time shape tensor produced by `ops.raw_shape`.
    if shapes_defined:
        raw_shape = kron_a.get_raw_shape()
    else:
        raw_shape = ops.raw_shape(kron_a)
    i_shapes = raw_shape[0]
    j_shapes = raw_shape[1]

    if shapes_defined and i_shapes != j_shapes:
        raise ValueError('The argument should be a Kronecker product of square '
                         'matrices (tt-cores must be square)')

    dtype = kron_a.dtype
    # N: total number of rows of the full (uncompressed) matrix.
    pows = tf.cast(tf.reduce_prod(i_shapes), dtype)
    logdet = 0.
    det_sign = 1.
    for core_idx in range(kron_a.ndims()):
        core = kron_a.tt_cores[core_idx]
        core_det = tf.matrix_determinant(core[0, :, :, 0])
        core_abs_det = tf.abs(core_det)
        core_det_sign = tf.sign(core_det)
        if shapes_defined:
            # int() accepts both a TF1 `Dimension` and a plain Python int
            # (TF2 shapes); `.value` only exists on `Dimension`.
            core_size = int(i_shapes[core_idx])
        else:
            # The run-time raw shape is an integer tensor: cast it so the
            # division below does not mix integer and floating dtypes.
            core_size = tf.cast(i_shapes[core_idx], dtype)
        # Exponent N / n_k with which det(A_k) enters the full determinant.
        core_pow = pows / core_size
        logdet += tf.log(core_abs_det) * core_pow
        det_sign *= core_det_sign ** core_pow
    return det_sign, logdet
评论列表
文章目录