def whitened_rgb_atoms():
    """Learn and display a dictionary of whitened RGB image-patch atoms.

    Pipeline: extract 8x8 color patches from a small image dataset, apply
    local contrast normalization followed by whitening, learn a 100-atom
    dictionary with KSVD (sparse coding via Batch Orthogonal Matching
    Pursuit), then rescale each atom to [0, 1] and plot the atoms in a
    10x10 grid with matplotlib.

    No parameters, no return value; shows a matplotlib figure as a side
    effect.
    """
    # A small dataset of images.
    imgs = get_images(colored=True)
    # Alternatively we could use the LFW dataset:
    """
    lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4, color=True)
    faces = lfw_people.data
    n_imgs, h, w, n_channels = lfw_people.images.shape
    imgs = []
    for i in range(n_imgs):
        img = lfw_people.images[i, :, :, :].reshape((h, w, n_channels))
        imgs.append(img)
    """
    patch_shape = (8, 8)
    n_atoms = 100
    n_plot_atoms = 100
    n_nonzero_coefs = 1
    # print(...) with a single argument / %-formatting works identically
    # under Python 2 and 3.
    print('Extracting reference patches...')
    X = extract_patches(imgs, patch_size=patch_shape[0], scale=False,
                        n_patches=int(5e5), mem="low")
    print("number of patches: %d" % X.shape[1])
    wn = preproc("whitening")
    from lyssa.feature_extract.preproc import local_contrast_normalization
    # Apply LCN and then whiten the patches.
    X = wn(local_contrast_normalization(X))
    # Learn the dictionary using Batch Orthogonal Matching Pursuit and KSVD.
    se = sparse_encoder(algorithm='bomp',
                        params={'n_nonzero_coefs': n_nonzero_coefs}, n_jobs=8)
    kc = ksvd_coder(n_atoms=n_atoms, sparse_coder=se, init_dict="data",
                    max_iter=10, verbose=True, approx=False, n_jobs=8)
    kc.fit(X)
    D = kc.D
    # Min-max rescale each atom to [0, 1] for display; guard against a
    # constant atom (max == min) which would otherwise divide by zero.
    for i in range(n_atoms):
        atom_min = D[:, i].min()
        atom_range = D[:, i].max() - atom_min
        if atom_range > 0:
            D[:, i] = (D[:, i] - atom_min) / float(atom_range)
        else:
            D[:, i] = 0.0
    # Plot the learned dictionary as 8x8x3 RGB tiles.
    plt.figure(figsize=(4.2, 4))
    for i in range(n_plot_atoms):
        plt.subplot(10, 10, i + 1)
        plt.imshow(D[:, i].reshape((patch_shape[0], patch_shape[1], 3)))
        plt.xticks(())
        plt.yticks(())
    plt.show()
# (removed trailing web-page navigation residue that was not Python code)