python类circulant()的实例源码

wishart_test.py 文件源码 项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def make_pd(start, n):
  """Build a deterministic symmetric matrix of the form ``L @ L.T``.

  ``L`` is the lower-triangular part of the circulant matrix generated by
  the integers ``start, start + 1, ..., start + n - 1``.

  Args:
    start: First value of the integer range used to build the circulant.
    n: Size of the resulting square matrix.

  Returns:
    An ``(n, n)`` array equal to ``L @ L.T``.
  """
  # NOTE(review): every diagonal entry of L equals `start`, so the product is
  # strictly positive definite only for a nonzero `start`.
  first_column = np.arange(start, start + n)
  lower = np.tril(linalg.circulant(first_column))
  return np.dot(lower, lower.T)
update_z.py 文件源码 项目:alphacsc 作者: alphacsc 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def gram_block_circulant(ds, n_times_valid, method='full',
                         sample_weights=None):
    """Return the Gram matrix ``D.T @ D`` of the block-circulant dictionary.

    ``D`` has shape ``(n_times, n_atoms * n_times_valid)``. The block of
    columns belonging to atom ``k`` holds the first ``n_times_valid`` columns
    of the circulant matrix built from that atom zero-padded to ``n_times``
    samples, i.e. successive shifts of the atom.

    Parameters
    ----------
    ds : array, shape (n_atoms, n_times_atom)
        The atoms
    n_times_valid : int
        n_times - n_times_atom + 1
    method : string
        If 'full', returns full circulant matrix.
        If 'scipy', returns scipy linear operator.
        If 'custom', returns custom linear operator.
    sample_weights : array, shape (n_times, )
        The sample weights for one trial

    Returns
    -------
    gram : array or linear operator
        With ``method='full'``, a dense array of shape
        ``(n_atoms * n_times_valid, n_atoms * n_times_valid)`` equal to
        ``D.T @ D``, or to ``D.T @ (sample_weights[:, None] * D)`` when
        ``sample_weights`` is given.
        With ``method='scipy'``, a ``scipy.sparse.linalg.LinearOperator`` of
        the same shape whose ``matvec`` delegates to ``_fprime`` (defined
        elsewhere in this module).
        With ``method='custom'``, a ``CustomLinearOperator`` instance
        (defined elsewhere in this module).

    Raises
    ------
    ValueError
        If ``method`` is not one of 'full', 'scipy' or 'custom'.
    """
    n_atoms, n_times_atom = ds.shape
    n_times = n_times_valid + n_times_atom - 1

    if method == 'full':
        D = np.zeros((n_times, n_atoms * n_times_valid))
        for k_idx, atom in enumerate(ds):
            # Zero-pad the atom to the full signal length so that the
            # circulant columns are its successive shifts.
            d_padded = np.zeros((n_times, ))
            d_padded[:n_times_atom] = atom
            start = k_idx * n_times_valid
            stop = start + n_times_valid
            # Only the first n_times_valid shifts are kept; they do not wrap
            # around because the last n_times_atom - 1 entries are zero.
            D[:, start:stop] = linalg.circulant(d_padded)[:, :n_times_valid]
        if sample_weights is None:
            return np.dot(D.T, D)
        return np.dot(D.T, sample_weights[:, None] * D)

    if method == 'scipy':
        # Imported here because only this branch needs it.
        from scipy.sparse.linalg import LinearOperator

        def matvec(v):
            # `ds` and `sample_weights` are captured from the enclosing call.
            assert v.shape[0] % n_atoms == 0
            return _fprime(ds, v, Xi=None, reg=None,
                           sample_weights=sample_weights)

        size = n_atoms * n_times_valid
        return LinearOperator((size, size), matvec=matvec)

    if method == 'custom':
        return CustomLinearOperator(ds, n_times_valid, sample_weights)

    raise ValueError('Unknown method %s.' % method)


问题


面经


文章

微信
公众号

扫码关注公众号