kronecker.py source code

python

Project: t3f  Author: Bihaqo
def slog_determinant(kron_a):
  """Computes the sign and log-det of a given Kronecker-factorized matrix.

  Args:
    kron_a: `TensorTrain` object containing a matrix of size N x N,
      factorized into a Kronecker product of square matrices (all
      tt-ranks are 1 and all tt-cores are square).

  Returns:
    Two numbers, sign of the determinant and the log-determinant of the given 
    matrix. If the determinant is zero, then sign will be 0 and logdet will be
    -Inf. In all cases, the determinant is equal to sign * np.exp(logdet).

  Raises:
    ValueError if the tt-cores of the provided matrix are not square,
    or the tt-ranks are not 1.
  """
  if not _is_kron(kron_a):
    raise ValueError('The argument should be a Kronecker product ' 
                     '(tt-ranks should be 1)')

  shapes_defined = kron_a.get_shape().is_fully_defined()
  if shapes_defined:
    i_shapes = kron_a.get_raw_shape()[0]
    j_shapes = kron_a.get_raw_shape()[1]
    # Squareness can only be checked when the static shapes are known.
    if i_shapes != j_shapes:
      raise ValueError('The argument should be a Kronecker product of square '
                       'matrices (tt-cores must be square)')
  else:
    i_shapes = ops.raw_shape(kron_a)[0]
    j_shapes = ops.raw_shape(kron_a)[1]

  # N, the side length of the full matrix: the product of all factor sizes.
  pows = tf.cast(tf.reduce_prod(i_shapes), kron_a.dtype)

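  # For square factors A_k of size n_k and N = prod(n_k):
  #   det(A_1 (x) ... (x) A_d) = prod_k det(A_k) ** (N / n_k)
  # so the log-determinants add up, weighted by N / n_k, and the signs multiply.
  logdet = 0.
  det_sign = 1.
  for core_idx in range(kron_a.ndims()):
    core = kron_a.tt_cores[core_idx]
    # With tt-ranks equal to 1, each core holds a single n_k x n_k matrix.
    core_det = tf.matrix_determinant(core[0, :, :, 0])
    core_abs_det = tf.abs(core_det)
    core_det_sign = tf.sign(core_det)
    # Note: `.value` assumes a static `Dimension`, i.e. fully defined shapes.
    core_pow = pows / i_shapes[core_idx].value
    logdet += tf.log(core_abs_det) * core_pow
    det_sign *= core_det_sign**(core_pow)

  return det_sign, logdet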
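
The loop above relies on the identity det(A_1 ⊗ ... ⊗ A_d) = ∏_k det(A_k)^(N / n_k), where n_k is the size of the k-th square factor and N is the product of all n_k. The following is a minimal NumPy sketch of the same computation, not part of t3f: the helper name `kron_slogdet` is hypothetical, and it takes a plain list of square arrays instead of a `TensorTrain` object. It is checked against `np.linalg.slogdet` applied to the explicitly formed Kronecker product.

```python
import numpy as np


def kron_slogdet(factors):
    """Sign and log|det| of kron(factors[0], factors[1], ...).

    Hypothetical helper for illustration; every factor must be a square 2-D array.
    """
    sizes = [f.shape[0] for f in factors]
    total = int(np.prod(sizes))  # N, the side length of the full matrix
    sign, logdet = 1.0, 0.0
    for f, n in zip(factors, sizes):
        f_sign, f_logdet = np.linalg.slogdet(f)
        power = total // n  # det(A_k) enters the product with exponent N / n_k
        sign *= f_sign ** power
        logdet += f_logdet * power
    return sign, logdet


# Compare with the dense computation on small random factors.
rng = np.random.default_rng(0)
a = rng.standard_normal((2, 2))
b = rng.standard_normal((3, 3))
sign, logdet = kron_slogdet([a, b])
ref_sign, ref_logdet = np.linalg.slogdet(np.kron(a, b))
assert sign == ref_sign
assert np.isclose(logdet, ref_logdet)
```

Working with the factors keeps the cost at one small determinant per factor, whereas forming `np.kron(a, b)` first requires a determinant of the full N x N matrix.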