def dictionary_learn_ex(patch_shape=(18, 18), n_atoms=225, n_nonzero_coefs=2,
                        n_plot_atoms=225, n_jobs=8, n_patches=100000):
    """Learn a patch dictionary on LFW face images and plot its atoms.

    Loads the LFW faces via ``fetch_lfw_people(min_faces_per_person=70,
    resize=0.4, color=False)``, extracts ``n_patches`` patches, fits an
    ``online_dictionary_coder`` (1 epoch, batches of 1000) whose sparse codes
    come from a 'bomp' ``sparse_encoder``, and displays the learned atoms in a
    square grid of grayscale images.

    All parameters default to the values the original script hard-coded.

    Parameters
    ----------
    patch_shape : tuple of int
        (rows, cols) of a patch. Must be square: only ``patch_shape[0]`` is
        passed to ``extract_patches``, but atoms are reshaped to
        ``patch_shape`` for display.
    n_atoms : int
        Number of dictionary atoms to learn.
    n_nonzero_coefs : int
        Sparsity level handed to the 'bomp' sparse encoder.
    n_plot_atoms : int
        Maximum number of atoms to display; clipped to the number of
        columns actually present in the learned dictionary.
    n_jobs : int
        Parallel jobs for patch extraction, sparse coding and fitting.
    n_patches : int
        Number of patches requested from ``extract_patches``.

    Raises
    ------
    ValueError
        If ``patch_shape`` is not square.
    """
    if patch_shape[0] != patch_shape[1]:
        # extract_patches receives a single side length, so atoms can only be
        # reshaped back to patch_shape when both sides agree.
        raise ValueError("patch_shape must be square, got %r" % (patch_shape,))

    images = _load_lfw_images()

    print("Extracting reference patches...")
    X = extract_patches(images, patch_size=patch_shape[0], scale=False,
                        n_patches=n_patches, verbose=True, n_jobs=n_jobs)
    # Patches are laid out one per column of X.
    print("number of patches: %d" % X.shape[1])

    se = sparse_encoder(algorithm='bomp',
                        params={'n_nonzero_coefs': n_nonzero_coefs},
                        n_jobs=n_jobs)
    odc = online_dictionary_coder(n_atoms=n_atoms, sparse_coder=se,
                                  n_epochs=1, batch_size=1000, non_neg=False,
                                  verbose=True, n_jobs=n_jobs)
    odc.fit(X)

    _plot_atoms(odc.D, patch_shape, n_plot_atoms)


def _load_lfw_images():
    """Return the grayscale LFW face images, divided by 255, as a list of 2-D arrays."""
    lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4,
                                  color=False)
    # Divide out of place: an in-place ``/=`` on a slice of
    # lfw_people.images would silently rewrite the fetched dataset.
    # NOTE(review): depending on the scikit-learn version, pixels may already
    # be scaled to [0, 1]; confirm this extra division by 255 is intended.
    return list(lfw_people.images / 255.)


def _plot_atoms(D, patch_shape, n_plot_atoms):
    """Show up to ``n_plot_atoms`` columns of ``D`` as grayscale patches in a square grid."""
    import math

    # Never index past the atoms that were actually learned.
    n_shown = min(n_plot_atoms, D.shape[1])
    # Smallest square grid holding every shown atom (15x15 for 225 atoms).
    n_side = int(math.ceil(math.sqrt(n_shown)))

    plt.figure(figsize=(4.2, 4))
    for i in range(n_shown):
        plt.subplot(n_side, n_side, i + 1)
        plt.imshow(D[:, i].reshape(patch_shape), cmap=plt.cm.gray)
        plt.xticks(())
        plt.yticks(())
    plt.show()
评论列表
文章目录