test_learn_d_z.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:alphacsc 作者: alphacsc 项目源码 文件源码
def test_linear_operator():
    """Test that the Gram-matrix operators agree with the explicit matrix.

    For both unweighted (``sample_weights=None``) and weighted samples, the
    Gram matrix is built with the ``'full'``, ``'scipy'`` and ``'custom'``
    methods of ``gram_block_circulant``. The test checks that:

    * every representation gives the same matrix-vector product as the
      ``'full'`` one;
    * ``power_iteration`` on every representation recovers the largest
      eigenvalue of the ``'full'`` matrix to within 1% relative tolerance.

    Wall-clock comparisons between the methods are deliberately not asserted
    here. At this problem size they are dominated by timer noise and Python
    overhead. They were also only checked for the last (weighted) loop
    iteration. Benchmark the methods separately instead.
    """
    n_times, n_atoms, n_times_atom = 128, 32, 32
    n_times_valid = n_times - n_times_atom + 1

    rng = check_random_state(42)
    ds = rng.randn(n_atoms, n_times_atom)
    # abs() makes the random sample weights non-negative.
    some_sample_weights = np.abs(rng.randn(n_times))

    for sample_weights in [None, some_sample_weights]:
        gbc = partial(gram_block_circulant, ds=ds, n_times_valid=n_times_valid,
                      sample_weights=sample_weights)
        DTD_full = gbc(method='full')
        DTD_scipy = gbc(method='scipy')
        DTD_custom = gbc(method='custom')

        # All three representations must apply the same linear map.
        z = rng.rand(DTD_full.shape[1])
        DTDz_full = DTD_full.dot(z)
        assert_allclose(DTDz_full, DTD_scipy.dot(z))
        assert_allclose(DTDz_full, DTD_custom.dot(z))

        # Reference: largest eigenvalue of the explicit symmetric matrix.
        # Only eigenvalues are needed, so eigvalsh avoids computing the
        # eigenvectors that eigh would also return.
        mu_max = np.max(linalg.eigvalsh(DTD_full))

        # Power iteration must work on each representation, including the
        # matrix-free linear operators.
        for DTD in [DTD_full, DTD_scipy, DTD_custom]:
            mu_hat = power_iteration(DTD)
            assert_allclose(mu_max, mu_hat, rtol=1e-2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号