kronecker.py source code

Project: t3f, author: Bihaqo
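import tensorflow as tf

from t3f import ops
from t3f.tensor_train import TensorTrain

# `_is_kron` (used below) is a private helper defined in this same module, kronecker.py.


def cholesky(kron_a):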
  """Computes the Cholesky decomposition of a given Kronecker-factorized matrix.

  Args:
    kron_a: `TensorTrain` object containing a matrix of size N x N,
    factorized into a Kronecker product of square matrices (all
    tt-ranks are 1 and all tt-cores are square). All the cores
    must be symmetric positive-definite.

  Returns:
    `TensorTrain` object containing a TT-matrix of size N x N: the
    lower-triangular Cholesky factor L of `kron_a` (so that
    kron_a = L L^T), itself a Kronecker product of the cores'
    Cholesky factors.

  Raises:
    ValueError if the tt-cores of the provided matrix are not square,
    or the tt-ranks are not 1.
  """
  # A Kronecker product is a TT-matrix whose tt-ranks are all 1.
  if not _is_kron(kron_a):
    raise ValueError('The argument should be a Kronecker product '
                     '(tt-ranks should be 1)')

  # The squareness check below can only run when all shapes are known statically.
  shapes_defined = kron_a.get_shape().is_fully_defined()
  if shapes_defined:
    i_shapes = kron_a.get_raw_shape()[0]  # row sizes of the Kronecker factors
    j_shapes = kron_a.get_raw_shape()[1]  # column sizes of the Kronecker factors
  else:
    i_shapes = ops.raw_shape(kron_a)[0]
    j_shapes = ops.raw_shape(kron_a)[1]

  if shapes_defined:
    if i_shapes != j_shapes:
      raise ValueError('The argument should be a Kronecker product of square '
                       'matrices (tt-cores must be square)')
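
  # The Cholesky factor of a Kronecker product is the Kronecker product of the
  # factors' Cholesky factors, chol(A kron B) = chol(A) kron chol(B), so every
  # 1 x n x n x 1 core can be factorized independently.
  cho_cores = []
  for core_idx in range(kron_a.ndims()):
    core = kron_a.tt_cores[core_idx]
    # Drop the two size-1 rank dimensions to get the n x n factor matrix.
    # (`tf.linalg.cholesky` replaces `tf.cholesky`, which was removed in TF 2.)
    core_cho = tf.linalg.cholesky(core[0, :, :, 0])
    # Restore the rank dimensions: n x n -> 1 x n x n x 1.
    cho_cores.append(tf.expand_dims(tf.expand_dims(core_cho, 0), -1))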

  # The factor has the same raw shape and (unit) tt-ranks as the input.
  res_ranks = kron_a.get_tt_ranks()
  res_shape = kron_a.get_raw_shape()
  return TensorTrain(cho_cores, res_shape, res_ranks)