def test_linear_operator():
    """Test that the Gram-operator variants agree with the dense Gram matrix.

    The test runs twice: once with ``sample_weights=None`` and once with a
    random non-negative weight vector. Each time it builds the operator
    returned by ``gram_block_circulant`` with ``method='full'``,
    ``'scipy'`` and ``'custom'``, and checks that:

    * all three give the same matrix-vector product on a random vector;
    * ``power_iteration`` on each one recovers the largest eigenvalue of
      the dense (``'full'``) matrix to within 1% relative tolerance.

    This test deliberately does not assert that one method is faster than
    another. Wall-clock comparisons depend on the machine and its current
    load, which makes such assertions flaky. Benchmark the methods
    separately instead.
    """
    n_times, n_atoms, n_times_atom = 128, 32, 32
    n_times_valid = n_times - n_times_atom + 1
    rng = check_random_state(42)
    ds = rng.randn(n_atoms, n_times_atom)
    some_sample_weights = np.abs(rng.randn(n_times))

    for sample_weights in [None, some_sample_weights]:
        gbc = partial(gram_block_circulant, ds=ds, n_times_valid=n_times_valid,
                      sample_weights=sample_weights)
        DTD_full = gbc(method='full')
        DTD_scipy = gbc(method='scipy')
        DTD_custom = gbc(method='custom')

        # Every operator variant must reproduce the dense matrix product.
        z = rng.rand(DTD_full.shape[1])
        expected = DTD_full.dot(z)
        assert_allclose(DTD_scipy.dot(z), expected)
        assert_allclose(DTD_custom.dot(z), expected)

        # Test power iterations with each linear operator against the
        # largest eigenvalue of the dense matrix. Only eigenvalues are
        # needed, so eigvalsh avoids computing the unused eigenvectors.
        mu_max = np.max(linalg.eigvalsh(DTD_full))
        for DTD in [DTD_full, DTD_scipy, DTD_custom]:
            mu_hat = power_iteration(DTD)
            assert_allclose(mu_hat, mu_max, rtol=1e-2)
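# NOTE(review): two lines of scraped web-page residue ("comment list",
# "table of contents") stood here; they were not Python code.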