def cholesky(kron_a):
    """Computes the Cholesky decomposition of a given Kronecker-factorized matrix.

    For a Kronecker product A = A_1 (x) ... (x) A_d of symmetric
    positive-definite matrices, the Cholesky factor is the Kronecker
    product of the factors' Cholesky factors, L = L_1 (x) ... (x) L_d.
    Each tt-core is therefore decomposed independently.

    Args:
      kron_a: `TensorTrain` object containing a matrix of size N x N,
        factorized into a Kronecker product of square matrices (all
        tt-ranks are 1 and all tt-cores are square). All the cores
        must be symmetric positive-definite.

    Returns:
      `TensorTrain` object, containing a TT-matrix of size N x N with the
      same raw shape and tt-ranks as `kron_a`.

    Raises:
      ValueError if the tt-ranks of the provided matrix are not 1, or if
      the tt-cores are not square. The squareness check is only performed
      when the shape of `kron_a` is fully defined at graph-construction
      time; otherwise it is skipped.
    """
    if not _is_kron(kron_a):
        raise ValueError('The argument should be a Kronecker product '
                         '(tt-ranks should be 1)')
    # Squareness can only be verified statically. With a partially unknown
    # shape there is nothing to compare, so the check is skipped.
    if kron_a.get_shape().is_fully_defined():
        raw_shape = kron_a.get_raw_shape()
        i_shapes = raw_shape[0]
        j_shapes = raw_shape[1]
        if i_shapes != j_shapes:
            raise ValueError('The argument should be a Kronecker product of square '
                             'matrices (tt-cores must be square)')
    cho_cores = []
    for core_idx in range(kron_a.ndims()):
        core = kron_a.tt_cores[core_idx]
        # tt-ranks are 1, so each core is expected to be 1 x m x m x 1.
        # Drop the rank axes, factorize the m x m matrix, then restore the
        # rank axes so the result is again a valid tt-core.
        core_cho = tf.cholesky(core[0, :, :, 0])
        cho_cores.append(tf.expand_dims(tf.expand_dims(core_cho, 0), -1))
    res_ranks = kron_a.get_tt_ranks()
    res_shape = kron_a.get_raw_shape()
    return TensorTrain(cho_cores, res_shape, res_ranks)
评论列表
文章目录