def test_dict_fact(method, memory):
    """Check that fMRIDictFact recovers the components behind the test data.

    Data generated by ``_make_test_data`` (which also returns the generating
    components) is factorized into 4 components. Learned maps and
    ground-truth components are flattened and L2-normalized, and their
    absolute cosine similarities are compared against a 0.95 threshold.

    Parameters
    ----------
    method : str
        Forwarded to ``fMRIDictFact(method=...)``. When ``memory`` is falsy,
        only ``'masked'`` is exercised; any other value is skipped.
    memory : bool
        If truthy, fit with a cache directory ``get_cache_dirs()[0]`` and
        ``memory_level=2``. Otherwise use ``Memory(cachedir=None)`` and
        ``memory_level=0``. (Presumably both arguments are supplied by
        pytest parametrization outside this view.)
    """
    # NOTE(review): newer joblib versions deprecate ``cachedir`` in favour of
    # ``location``; confirm against the joblib version this project pins.
    if memory:
        memory = Memory(cachedir=get_cache_dirs()[0])
        memory_level = 2
    else:
        # Without a cache, only the 'masked' method is run.
        if method != 'masked':
            pytest.skip()
        memory = Memory(cachedir=None)
        memory_level = 0

    data, mask_img, components_img, init = _make_test_data(n_subjects=10)
    dict_fact = fMRIDictFact(n_components=4, random_state=0,
                             memory=memory,
                             memory_level=memory_level,
                             mask=mask_img,
                             dict_init=init,
                             method=method,
                             reduction=2,
                             smoothing_fwhm=0., n_epochs=2, alpha=1)
    dict_fact.fit(data)

    def _unit_rows(img):
        """Return a 4D image as (n_maps, n_voxels) rows of unit L2 norm."""
        # NOTE(review): ``get_data()`` is deprecated in recent nibabel
        # (``get_fdata()`` is the replacement); confirm the pinned version.
        # Move the 4th (map) axis first, then flatten each map to one row.
        rows = np.rollaxis(img.get_data(), 3, 0)
        rows = rows.reshape((rows.shape[0], -1))
        norms = np.linalg.norm(rows, axis=1)
        norms[norms == 0] = 1  # leave all-zero maps as zeros
        # Out-of-place division: ``rows`` may be a view of the image's data,
        # and ``/=`` would silently mutate it (or fail on integer dtypes).
        return rows / norms[:, np.newaxis]

    maps = _unit_rows(dict_fact.components_img_)
    components = _unit_rows(components_img)

    # G[i, j] = |cos| between ground-truth component i and learned map j.
    G = np.abs(components.dot(maps.T))
    matches = G > 0.95
    # Count components matched by some map AND maps matching some component,
    # so several duplicate maps on one component cannot satisfy the check.
    recovered_components = np.sum(np.any(matches, axis=1))
    matching_maps = np.sum(np.any(matches, axis=0))
    assert recovered_components >= 4
    assert matching_maps >= 4
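# Comment list (web-page navigation residue, not code)
# Table of contents (web-page navigation residue, not code)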