kronecker.py source code

Project: t3f  Author: Bihaqo
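
The loop in `determinant` appears to rely on the standard determinant identity for Kronecker products. Write the matrix as $A = A_1 \otimes \cdots \otimes A_d$. Each factor $A_k$ has size $n_k \times n_k$ and, with all tt-ranks equal to 1, is read from tt-core $k$ as `core[0, :, :, 0]`. Let $N = \prod_k n_k$. Then:

$$
\det(A_1 \otimes \cdots \otimes A_d) = \prod_{k=1}^{d} \det(A_k)^{N / n_k}
$$

In this reading, `pows` plays the role of $N$ and `core_pow` the role of $N / n_k$. For two factors, with $A$ of size $n \times n$ and $B$ of size $m \times m$, the identity reduces to $\det(A \otimes B) = \det(A)^m \det(B)^n$. The exponents $N / n_k$ grow with the size of the full matrix, which is plausibly why the docstring warns about overflow.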

```python
import tensorflow as tf

from t3f import ops

# Assumption: `_is_kron` is a private helper defined elsewhere in kronecker.py.


def determinant(kron_a):
  """Computes the determinant of a given Kronecker-factorized matrix.

  Note that this method can suffer from overflow.

  Args:
    kron_a: `TensorTrain` object containing a matrix of size N x N,
      factorized into a Kronecker product of square matrices (all
      tt-ranks are 1 and all tt-cores are square).

  Returns:
    A scalar `tf.Tensor`, the determinant of the given matrix.

  Raises:
    ValueError if the tt-cores of the provided matrix are not square
    (checked only when the shape is fully defined), or the tt-ranks are not 1.
  """
  if not _is_kron(kron_a):
    raise ValueError('The argument should be a Kronecker product (tt-ranks '
                     'should be 1)')

  shapes_defined = kron_a.get_shape().is_fully_defined()
  if shapes_defined:
    i_shapes = kron_a.get_raw_shape()[0]
    j_shapes = kron_a.get_raw_shape()[1]
    # Squareness can only be verified statically when the shapes are known.
    if i_shapes != j_shapes:
      raise ValueError('The argument should be a Kronecker product of square '
                       'matrices (tt-cores must be square)')
  else:
    i_shapes = ops.raw_shape(kron_a)[0]
    j_shapes = ops.raw_shape(kron_a)[1]

  # N: the number of rows of the full N x N matrix.
  pows = tf.cast(tf.reduce_prod(i_shapes), kron_a.dtype)
  cores = kron_a.tt_cores
  det = 1
  for core_idx in range(kron_a.ndims()):
    core = cores[core_idx]
    # With tt-rank 1, each core is 1 x n_k x n_k x 1 and holds the factor A_k.
    core_det = tf.matrix_determinant(core[0, :, :, 0])
    # det(kron(A_1, ..., A_d)) = prod_k det(A_k) ** (N / n_k).
    # tf.cast handles both a static dimension and a dynamic shape tensor;
    # `.value` exists only on a static Dimension and fails for the latter.
    core_pow = pows / tf.cast(i_shapes[core_idx], kron_a.dtype)
    det *= tf.pow(core_det, core_pow)
  return det
```
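
This is a minimal plain-Python sketch that uses neither t3f nor TensorFlow. It cross-checks the per-factor formula by building the full Kronecker product explicitly and comparing determinants. It also adds a log-domain variant that returns the sign and log|det|, which is one common way to sidestep the overflow the docstring mentions. All names here (`det`, `kron`, `kron_det`, `kron_slogdet`) are hypothetical helpers for illustration, not t3f functions. Matrices are assumed to be plain lists of rows.

```python
import math
from fractions import Fraction


def det(m):
    """Exact determinant of a square matrix (list of rows) by Gaussian elimination."""
    a = [[Fraction(x) for x in row] for row in m]
    n = len(a)
    result = Fraction(1)
    for col in range(n):
        # Find a row at or below `col` with a nonzero pivot.
        pivot = next((r for r in range(col, n) if a[r][col] != 0), None)
        if pivot is None:
            return Fraction(0)
        if pivot != col:
            a[col], a[pivot] = a[pivot], a[col]
            result = -result  # a row swap flips the sign
        result *= a[col][col]
        # Eliminate the entries below the pivot.
        for r in range(col + 1, n):
            factor = a[r][col] / a[col][col]
            for c in range(col, n):
                a[r][c] -= factor * a[col][c]
    return result


def kron(a, b):
    """Dense Kronecker product; for b of size p x q, entry (i*p + k, j*q + l) is a[i][j] * b[k][l]."""
    return [[x * y for x in row_a for y in row_b] for row_a in a for row_b in b]


def kron_det(factors):
    """det(kron(A_1, ..., A_d)) = prod_k det(A_k) ** (N // n_k), with N = prod_k n_k."""
    sizes = [len(f) for f in factors]
    total = math.prod(sizes)
    result = Fraction(1)
    for f, n_k in zip(factors, sizes):
        result *= det(f) ** (total // n_k)
    return result


def kron_slogdet(factors):
    """(sign, log|det|) of a Kronecker product, without forming the full determinant."""
    sizes = [len(f) for f in factors]
    total = math.prod(sizes)
    sign, log_abs = 1, 0.0
    for f, n_k in zip(factors, sizes):
        d = det(f)
        if d == 0:
            return 0, -math.inf
        power = total // n_k
        if d < 0 and power % 2 == 1:
            sign = -sign  # an odd power of a negative determinant flips the sign
        log_abs += power * math.log(abs(float(d)))
    return sign, log_abs


if __name__ == "__main__":
    A = [[2, 1], [1, 3]]                   # det(A) = 5
    B = [[1, 2, 0], [0, 1, 4], [5, 0, 1]]  # det(B) = 41
    print(det(kron(A, B)), kron_det([A, B]))
    big = [[[10, 0], [0, 1]]] * 10         # full det is 10 ** 5120, beyond float64
    print(kron_slogdet(big))
```

Exact `Fraction` arithmetic keeps the comparison free of rounding error. In the last example, ten 2 x 2 factors with determinant 10 give a full determinant of 10**5120, far beyond the float64 range. The log-domain version instead reports a log|det| of roughly 11789.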